123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
import os
import urllib.parse

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError

from crossborder.utils.constants import CUSTOMS_CITY_MAPPING
from crossborder.utils.crypto_utils import AESCryptor
from crossborder.utils.log import get_logger
log = get_logger(__name__)

# Database connection settings.
# Each entry can be overridden by the environment variable named in it, so
# real hosts and credentials need not be committed to source control.
# The literals are the UAT fallbacks used when a variable is unset or empty.
DB_CONFIG = {
    'host': os.environ.get('CROSSBORDER_DB_HOST') or '10.130.75.149',
    'port': int(os.environ.get('CROSSBORDER_DB_PORT') or 3307),
    'user': os.environ.get('CROSSBORDER_DB_USER') or 'yto_crm',
    # Either a plaintext password or an "ENC(...)" value, which
    # get_decrypted_password() decrypts with `cryptor`.
    'password': (
        os.environ.get('CROSSBORDER_DB_PASSWORD')
        or 'ENC(Fl9g4899OmVYddM42Rt2fA==:sDy1QG/7bmx/iHo4xEOBGQ==)'
    ),
    'database': os.environ.get('CROSSBORDER_DB_NAME') or 'crm_uat',
    'charset': os.environ.get('CROSSBORDER_DB_CHARSET') or 'utf8mb4',
}

# Decrypts "ENC(...)" passwords.
# The key can be supplied via CROSSBORDER_AES_KEY instead of relying on the
# committed UAT key.
cryptor = AESCryptor(
    os.environ.get('CROSSBORDER_AES_KEY') or "uat_ff419620e7047a3c372e2513c5a2b9a5"
)
def get_decrypted_password():
    """
    Return the database password from ``DB_CONFIG`` in plaintext.

    A value wrapped as ``ENC(...)`` is decrypted with the module-level
    ``cryptor``; any other value is returned unchanged.

    :raises Exception: whatever ``cryptor.decrypt`` raises, after it is logged.
    """
    stored = DB_CONFIG['password']
    # Plaintext passwords need no decryption.
    if not stored.startswith("ENC("):
        return stored
    try:
        plain = cryptor.decrypt(stored)
    except Exception as err:
        log.error(f"密码解密失败: {str(err)}")
        raise
    return plain
class DBHelper:
    """
    Helper around a pooled SQLAlchemy engine for the cross-border trade tables.

    The engine is built from the module-level ``DB_CONFIG``, with the password
    resolved by ``get_decrypted_password()`` (pool_size=5, max_overflow=10).
    """

    def __init__(self):
        db_config = DB_CONFIG.copy()
        db_config['password'] = get_decrypted_password()
        self.engine = create_engine(
            self._build_url(db_config),
            pool_size=5,
            max_overflow=10,
        )

    @staticmethod
    def _build_url(db_config):
        """
        Build the ``mysql+pymysql`` connection URL from a config dict.

        The user name and password are percent-encoded, so characters such as
        ``@``, ``:``, ``/`` or ``#`` cannot break the URL.
        Plain alphanumeric credentials come out unchanged.

        :param db_config: dict with user/password/host/port/database/charset.
        :return: the SQLAlchemy connection URL string.
        """
        user = urllib.parse.quote(str(db_config["user"]), safe="")
        password = urllib.parse.quote(str(db_config["password"]), safe="")
        return (
            f'mysql+pymysql://{user}:{password}'
            f'@{db_config["host"]}:{db_config["port"]}/{db_config["database"]}'
            f'?charset={db_config["charset"]}'
        )

    def get_commodity_id(self, name):
        """
        Return the category id whose ``commodity_name`` equals *name*.

        :param name: commodity name to look up.
        :return: ``id`` of the first row returned, or None when nothing matches.
            There is no ORDER BY, so the row is arbitrary if several match.
        """
        with self.engine.connect() as conn:
            return conn.execute(
                text("SELECT id FROM t_yujin_crossborder_prov_commodity_category WHERE commodity_name = :name"),
                {'name': name},
            ).scalar()

    @staticmethod
    def _df_to_records(df):
        """
        Convert *df* into a list of row dicts ready for DBAPI binding.

        Missing scalar values (NaN, NaT, None) become ``None`` so they are
        written as SQL NULL. PyMySQL refuses to bind a float NaN.
        """
        records = df.to_dict(orient='records')
        return [
            {
                col: (None if pd.api.types.is_scalar(val) and pd.isna(val) else val)
                for col, val in row.items()
            }
            for row in records
        ]

    def bulk_insert(self, df, table_name, conflict_columns=None, update_columns=None):
        """
        Batch-insert a DataFrame, optionally as an upsert.

        :param df: DataFrame whose column names are the target table's columns.
        :param table_name: target table; it is interpolated into the SQL, so
            it must come from trusted code.
        :param conflict_columns: conflict-detection columns. MySQL's
            ``ON DUPLICATE KEY UPDATE`` always uses the table's PRIMARY/UNIQUE
            keys, so this list only acts as a switch. The upsert clause is
            added when both it and ``update_columns`` are non-empty.
        :param update_columns: columns overwritten from the incoming row on a
            duplicate key. ``create_time`` is additionally set to ``NOW()``.
        :raises Exception: any database error is logged and re-raised.
        """
        if df.empty:
            log.info("空数据集,跳过插入")
            return

        columns = ', '.join(df.columns)
        placeholders = ', '.join(f":{col}" for col in df.columns)
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"

        if conflict_columns and update_columns:
            # NOTE: VALUES(col) in ON DUPLICATE KEY UPDATE is deprecated since
            # MySQL 8.0.20 but still accepted.
            update_clauses = [f"{col}=VALUES({col})" for col in update_columns]
            # Always refresh create_time when an existing row is updated.
            update_clauses.append("create_time = NOW()")
            sql += f" ON DUPLICATE KEY UPDATE {', '.join(update_clauses)}"

        # NaN/NaT -> None so they are stored as NULL instead of failing to bind.
        data = self._df_to_records(df)
        try:
            # engine.begin() commits on success and rolls back on error.
            with self.engine.begin() as conn:
                conn.execute(text(sql), data)
            # Logged only after the commit succeeded.
            log.info(f"成功插入/更新 {len(df)} 行到 {table_name}")
        except Exception as e:
            log.error(f"数据库操作失败: {str(e)}")
            raise

    def update_prov_yoy(self, prov_name, cutoff_month='2024-01'):
        """
        Refresh the year-over-year (YoY) metrics of one province.

        Step 1 clears (sets to NULL) the YoY columns for months before
        *cutoff_month*. Step 2 recomputes them for *cutoff_month* onwards.
        The two steps run in separate transactions.

        :param prov_name: value of ``prov_name`` to process.
        :param cutoff_month: first month (``YYYY-MM``) whose YoY is computed.
        :return: ``{'cleared': <rows nulled>, 'updated': <rows recomputed>}``.
        :raises Exception: any failure is logged with traceback and re-raised.
        """
        try:
            cleared = self.clear_old_prov_yoy(prov_name, cutoff_month)
            updated = self._update_prov_new_yoy(prov_name, cutoff_month)
            log.info(f"{prov_name}同比处理完成 | 清零:{cleared} 更新:{updated}")
            return {'cleared': cleared, 'updated': updated}
        except Exception:
            log.error(f"{prov_name}数据处理失败", exc_info=True)
            raise

    def clear_old_prov_yoy(self, prov_name, cutoff_month='2024-01'):
        """
        Set the YoY columns to NULL for a province's months before *cutoff_month*.

        Only rows that still have at least one non-NULL YoY value are touched.
        Re-running is therefore a no-op, and the returned count reflects real
        changes.

        :param prov_name: value of ``prov_name`` to clear.
        :param cutoff_month: months strictly before this ``YYYY-MM`` are cleared.
        :return: number of rows updated.
        """
        # The filter must test IS NOT NULL: rows are cleared to NULL, so a
        # "!= 0" test would leave all-zero rows untouched.
        clear_sql = text("""
            UPDATE t_yujin_crossborder_prov_region_trade
            SET yoy_import_export = NULL,
                yoy_export = NULL,
                yoy_import = NULL
            WHERE prov_name = :prov_name
              AND crossborder_year_month < :cutoff_month
              AND (yoy_import_export IS NOT NULL
                   OR yoy_export IS NOT NULL
                   OR yoy_import IS NOT NULL)
        """)
        try:
            with self.engine.begin() as conn:
                result = conn.execute(
                    clear_sql, {'prov_name': prov_name, 'cutoff_month': cutoff_month}
                )
                log.info(f"{prov_name}旧数据清零记录数: {result.rowcount}")
                return result.rowcount
        except Exception as e:
            log.error(f"旧数据清零失败: {str(e)}")
            raise

    def _update_prov_new_yoy(self, prov_name, cutoff_month='2024-01'):
        """
        Recompute YoY percentages for a province's months from *cutoff_month* on.

        Each row is joined to the row with the same ``city_code`` twelve months
        earlier. YoY = (curr - prev) / prev * 100, truncated (not rounded) to
        4 decimals. It is stored as 0 when the previous value is 0 or the
        division yields NULL.

        :param prov_name: value of ``prov_name`` to update.
        :param cutoff_month: first ``YYYY-MM`` month to recompute.
        :return: number of rows updated.
        """
        # NOTE(review): prev is matched on city_code only, not prov_name.
        # This assumes city_code is unique across provinces; confirm.
        update_sql = text("""
            UPDATE t_yujin_crossborder_prov_region_trade AS curr
            INNER JOIN t_yujin_crossborder_prov_region_trade AS prev
                ON curr.city_code = prev.city_code
                AND prev.crossborder_year_month = DATE_FORMAT(
                    DATE_SUB(
                        STR_TO_DATE(CONCAT(curr.crossborder_year_month, '-01'), '%Y-%m-%d'),
                        INTERVAL 1 YEAR
                    ),
                    '%Y-%m'
                )
            SET
                curr.yoy_import_export = COALESCE (
                    TRUNCATE((curr.monthly_total - prev.monthly_total) / NULLIF (prev.monthly_total, 0) * 100, 4),
                    0.0000
                ),
                curr.yoy_import = COALESCE (
                    TRUNCATE((curr.monthly_import - prev.monthly_import) / NULLIF (prev.monthly_import, 0) * 100, 4),
                    0.0000
                ),
                curr.yoy_export = COALESCE (
                    TRUNCATE((curr.monthly_export - prev.monthly_export) / NULLIF (prev.monthly_export, 0) * 100, 4),
                    0.0000
                )
            WHERE
                curr.prov_name = :prov_name
                AND curr.crossborder_year_month >= :cutoff_month
                AND prev.monthly_total IS NOT NULL
        """)
        with self.engine.begin() as conn:
            result = conn.execute(
                update_sql, {'prov_name': prov_name, 'cutoff_month': cutoff_month}
            )
            log.info(f"{prov_name}新数据更新数: {result.rowcount}")
            return result.rowcount

    def query(self, sql, params=None, return_df=True):
        """
        Run a row-returning SQL statement.

        :param sql: SQL text with ``:name`` placeholders.
        :param params: dict of bind values (None means no parameters).
        :param return_df: True returns a DataFrame; False returns the raw rows.
        :return: DataFrame (columns from the result keys) or list of rows.
        :raises Exception: logged together with the SQL and params, then re-raised.
        """
        try:
            with self.engine.connect() as conn:
                result = conn.execute(text(sql), params or {})
                # Read the column names before fetching, via the public API.
                columns = list(result.keys())
                rows = result.fetchall()
                if return_df:
                    df = pd.DataFrame(rows, columns=columns)
                    log.info(f"查询成功,返回 {len(df)} 条记录")
                    return df
                log.info(f"查询成功,返回 {len(rows)} 条记录")
                return rows
        except Exception as e:
            log.error(f"查询失败: {str(e)}")
            log.error(f"SQL: {sql}")
            log.error(f"Params: {params}")
            raise

    def execute_sql_with_params(self, sql: str, params_list: list):
        """
        Execute a parameterized statement in one transaction (batch insert/update).

        :param sql: SQL using named ``:name`` placeholders. The statement is
            wrapped in ``text()``, so DBAPI-style ``%s`` markers are not bound.
        :param params_list: list of dicts, one per execution (executemany).
            A single dict executes the statement once.
        :return: affected row count reported by the driver. Returns 0 without
            touching the database when *params_list* is empty or None.
        :raises Exception: any database error is logged and re-raised.
        """
        # Empty batch: nothing to execute (previously crashed on params_list[0]).
        if not params_list:
            return 0
        try:
            with self.engine.begin() as conn:
                result = conn.execute(text(sql), params_list)
                affected_rows = result.rowcount
            log.info(f"成功执行SQL,受影响行数:{affected_rows}")
            return affected_rows
        except Exception as e:
            log.error(f"SQL执行失败: {str(e)}")
            raise

    def get_code_exist(self, year_month, prov_code, is_city=True, customs_name=None, city_names=None):
        """
        Count the rows for a given month and region.

        Region filter, in priority order:
        1. *customs_name* found in ``CUSTOMS_CITY_MAPPING``: cities mapped to it;
        2. non-empty *city_names*: those cities;
        3. *is_city* true: ``prov_code``;
        4. otherwise: only the month is filtered.

        :param year_month: month string ``YYYY-MM``.
        :param prov_code: province code; it also selects the table.
        :param is_city: city-level (True) or commodity-level data.
        :param customs_name: optional customs name.
        :param city_names: optional iterable of city names.
        :return: matching row count, or -1 on a database error.
        :raises ValueError: if *prov_code* is not supported.
        """
        # The project already depends on SQLAlchemy; imported here to keep the
        # change local to this method.
        from sqlalchemy import bindparam

        table_name = self.get_table_name_by_province(prov_code, is_city)
        clauses = ["crossborder_year_month = :year_month"]
        params = {"year_month": year_month}

        cities = None
        if customs_name and customs_name in CUSTOMS_CITY_MAPPING:
            cities = CUSTOMS_CITY_MAPPING[customs_name]
        elif city_names:
            cities = city_names
        elif is_city:
            clauses.append("prov_code = :prov_code")
            params["prov_code"] = prov_code

        if cities is not None:
            clauses.append("city_name IN :cities")
            params["cities"] = list(cities)

        query = text(f"SELECT COUNT(*) FROM `{table_name}` WHERE " + " AND ".join(clauses))
        if cities is not None:
            # An expanding bind renders one placeholder per city (and a valid
            # empty-set expression for an empty list) on every dialect.
            query = query.bindparams(bindparam("cities", expanding=True))

        try:
            with self.engine.connect() as connection:
                result = connection.execute(query, params).scalar()
                return result or 0
        except SQLAlchemyError as e:
            log.error(f"查询错误: {str(e)}")
            return -1

    def get_table_name_by_province(self, prov_code, is_city=True):
        """
        Return the table that stores data for *prov_code*.

        350000 (Fujian) and 370000 (Shandong) map to the region table.
        410000 (Henan) maps to the commodity table.
        440000 (Guangdong) depends on *is_city*; see ``get_guangdong_table``.

        :param prov_code: province code as str or int.
        :param is_city: only relevant for Guangdong.
        :raises ValueError: for an unsupported province code.
        """
        table_mapping = {
            "350000": "t_yujin_crossborder_prov_region_trade",
            "370000": "t_yujin_crossborder_prov_region_trade",
            "410000": "t_yujin_crossborder_prov_commodity_trade",
            "440000": self.get_guangdong_table(is_city),
        }
        # Accept int codes (e.g. 370000) as well as strings.
        key = str(prov_code)
        if key not in table_mapping:
            raise ValueError(f"不支持省份代码: {prov_code}")
        return table_mapping[key]

    def get_guangdong_table(self, is_city):
        """Return Guangdong's table: region table if *is_city*, else commodity table."""
        if is_city:
            return "t_yujin_crossborder_prov_region_trade"
        else:
            return "t_yujin_crossborder_prov_commodity_trade"

    def get_total_info_exist(self, year_month):
        """
        Count rows of ``t_yujin_crossborder_region_trade`` for *year_month*.

        :return: row count, or -1 on a database error.
        """
        query = text("""
            SELECT COUNT(*) FROM `t_yujin_crossborder_region_trade`
            WHERE `year_month` = :year_month
        """)
        try:
            with self.engine.connect() as connection:
                result = connection.execute(
                    query,
                    {"year_month": year_month}
                ).scalar()
                return result or 0
        except SQLAlchemyError as e:
            log.error(f"查询错误: {str(e)}")
            return -1
|