db_helper.py

import pandas as pd
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError

from crossborder.utils.crypto_utils import AESCryptor
from crossborder.utils.log import get_logger

log = get_logger(__name__)

DB_CONFIG = {
    'host': '10.130.75.149',
    'port': 3307,
    'user': 'yto_crm',
    'password': 'ENC(Fl9g4899OmVYddM42Rt2fA==:sDy1QG/7bmx/iHo4xEOBGQ==)',
    'database': 'crm_uat',
    'charset': 'utf8mb4'
}

cryptor = AESCryptor("uat_ff419620e7047a3c372e2513c5a2b9a5")


def get_decrypted_password():
    """Return the plaintext DB password, decrypting it if it is ENC(...)-wrapped."""
    encrypted_pass = DB_CONFIG['password']
    if encrypted_pass.startswith("ENC("):
        try:
            return cryptor.decrypt(encrypted_pass)
        except Exception as e:
            log.error(f"Password decryption failed: {str(e)}")
            raise
    return encrypted_pass


class DBHelper:
    def __init__(self):
        db_config = DB_CONFIG.copy()
        db_config['password'] = get_decrypted_password()
        self.engine = create_engine(
            f'mysql+pymysql://{db_config["user"]}:{db_config["password"]}'
            f'@{db_config["host"]}:{db_config["port"]}/{db_config["database"]}'
            f'?charset={db_config["charset"]}',
            pool_size=5,
            max_overflow=10
        )

    def get_commodity_id(self, name):
        """Return the category id for the given commodity name, or None if not found."""
        with self.engine.connect() as conn:
            result = conn.execute(
                text("SELECT id FROM t_yujin_crossborder_prov_commodity_category"
                     " WHERE commodity_name = :name"),
                {'name': name}
            ).fetchone()
            return result[0] if result else None

    def bulk_insert(self, df, table_name, conflict_columns=None, update_columns=None):
        """
        Enhanced bulk insert (supports overwrite-on-conflict updates).
        :param df: DataFrame to insert
        :param table_name: target table name
        :param conflict_columns: columns used for conflict detection
        :param update_columns: columns to update on conflict
        """
        if df.empty:
            log.info("Empty dataset, skipping insert")
            return
        # Build the parameterised SQL template
        columns = ', '.join(df.columns)
        placeholders = ', '.join([f":{col}" for col in df.columns])
        sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        # Append ON DUPLICATE KEY UPDATE (MySQL syntax)
        if conflict_columns and update_columns:
            # 1. Update clauses for the caller-specified columns
            update_clauses = [f"{col}=VALUES({col})" for col in update_columns]
            # 2. Always refresh create_time
            update_clauses.append("create_time = NOW()")
            # 3. Combine all update clauses
            update_set = ', '.join(update_clauses)
            sql += f" ON DUPLICATE KEY UPDATE {update_set}"
        # Convert the DataFrame into a list of dicts for executemany-style execution
        data = df.to_dict(orient='records')
        try:
            with self.engine.connect() as conn:
                # Explicit transaction
                with conn.begin():
                    # Wrap the raw SQL with text()
                    stmt = text(sql)
                    # Batch execute
                    conn.execute(stmt, data)
            log.info(f"Inserted/updated {len(df)} rows into {table_name}")
        except Exception as e:
            log.error(f"Database operation failed: {str(e)}")
            raise

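    # Usage sketch for bulk_insert (illustrative only; the DataFrame columns and
    # target table below are assumptions, not taken from this module). Passing
    # conflict_columns and update_columns turns the insert into a MySQL upsert via
    # ON DUPLICATE KEY UPDATE, provided the table has a suitable unique key.
    #
    #   db = DBHelper()
    #   df = pd.DataFrame([
    #       {"commodity_name": "example_item", "monthly_total": 1200.0},  # hypothetical data
    #   ])
    #   db.bulk_insert(
    #       df,
    #       "t_example_table",                    # hypothetical table name
    #       conflict_columns=["commodity_name"],  # assumes a unique key on this column
    #       update_columns=["monthly_total"],
    #   )
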
    def update_prov_yoy(self, prov_name):
        """
        Run the full year-over-year refresh for the given province (clears old data, then recomputes).
        """
        try:
            # Step 1: clear old data
            cleared = self.clear_old_prov_yoy(prov_name)
            # Step 2: compute new data
            updated = self._update_prov_new_yoy(prov_name)
            log.info(f"{prov_name} YoY processing finished | cleared: {cleared}, updated: {updated}")
            return {'cleared': cleared, 'updated': updated}
        except Exception as e:
            log.error(f"{prov_name} data processing failed", exc_info=True)
            raise

    def clear_old_prov_yoy(self, prov_name):
        """
        Null out the YoY metrics on the given province's pre-2024 rows.
        """
        clear_sql = text("""
            UPDATE t_yujin_crossborder_prov_region_trade
            SET yoy_import_export = NULL,
                yoy_export = NULL,
                yoy_import = NULL
            WHERE prov_name = :prov_name
              AND crossborder_year_month < '2024-01'
              AND (yoy_import_export != 0
                   OR yoy_export != 0
                   OR yoy_import != 0)  -- optimisation: only touch non-zero rows
        """)
        try:
            with self.engine.begin() as conn:
                result = conn.execute(clear_sql, {'prov_name': prov_name})
                log.info(f"{prov_name} old rows cleared: {result.rowcount}")
                return result.rowcount
        except Exception as e:
            log.error(f"Failed to clear old data: {str(e)}")
            raise

    def _update_prov_new_yoy(self, prov_name):
        """
        Recompute YoY metrics for the province's city rows from 2024 onward.
        """
        update_sql = text("""
            UPDATE t_yujin_crossborder_prov_region_trade AS curr
            INNER JOIN t_yujin_crossborder_prov_region_trade AS prev
                ON curr.city_code = prev.city_code
                AND prev.crossborder_year_month = DATE_FORMAT(
                    DATE_SUB(
                        STR_TO_DATE(CONCAT(curr.crossborder_year_month, '-01'), '%Y-%m-%d'),
                        INTERVAL 1 YEAR
                    ),
                    '%Y-%m'
                )
            SET
                curr.yoy_import_export = COALESCE(
                    TRUNCATE((curr.monthly_total - prev.monthly_total) / NULLIF(prev.monthly_total, 0) * 100, 4),
                    0.0000
                ),
                curr.yoy_import = COALESCE(
                    TRUNCATE((curr.monthly_import - prev.monthly_import) / NULLIF(prev.monthly_import, 0) * 100, 4),
                    0.0000
                ),
                curr.yoy_export = COALESCE(
                    TRUNCATE((curr.monthly_export - prev.monthly_export) / NULLIF(prev.monthly_export, 0) * 100, 4),
                    0.0000
                )
            WHERE
                curr.prov_name = :prov_name
                AND curr.crossborder_year_month >= '2024-01'
                AND prev.monthly_total IS NOT NULL
        """)
        with self.engine.begin() as conn:
            result = conn.execute(update_sql, {'prov_name': prov_name})
            log.info(f"{prov_name} new rows updated: {result.rowcount}")
            return result.rowcount

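    # Usage sketch for the YoY refresh flow (illustrative; the province value is
    # an assumption about what is stored in prov_name). update_prov_yoy() first
    # nulls out pre-2024 metrics via clear_old_prov_yoy(), then recomputes the
    # 2024+ metrics against the same month one year earlier via
    # _update_prov_new_yoy().
    #
    #   db = DBHelper()
    #   stats = db.update_prov_yoy("山东省")   # hypothetical stored province name
    #   log.info(f"cleared={stats['cleared']} updated={stats['updated']}")
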
    def query(self, sql, params=None, return_df=True):
        """Run a parameterised SELECT and return a DataFrame or the raw rows."""
        try:
            with self.engine.connect() as conn:
                if return_df:
                    # Build the DataFrame directly from the SQLAlchemy result
                    result_proxy = conn.execute(text(sql), params or {})
                    # Column names straight from the result metadata
                    columns = list(result_proxy.keys())
                    # Fetch all rows
                    data = result_proxy.fetchall()
                    # Assemble the DataFrame manually
                    df = pd.DataFrame(data, columns=columns)
                    log.info(f"Query succeeded, returned {len(df)} rows")
                    return df
                else:
                    result = conn.execute(text(sql), params or {}).fetchall()
                    log.info(f"Query succeeded, returned {len(result)} rows")
                    return result
        except Exception as e:
            log.error(f"Query failed: {str(e)}")
            # Log details for troubleshooting
            log.error(f"SQL: {sql}")
            log.error(f"Params: {params}")
            raise

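    # Usage sketch for query() (illustrative; the table and column names are
    # assumptions). Named bind parameters (:name style) are supplied via the
    # params dict; return_df=False returns the raw rows instead of a DataFrame.
    #
    #   db = DBHelper()
    #   df = db.query(
    #       "SELECT * FROM t_example_table WHERE prov_code = :code",  # hypothetical table
    #       params={"code": "370000"},
    #       return_df=True,
    #   )
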
    def execute_sql_with_params(self, sql: str, params_list: list):
        """
        Execute a parameterised SQL statement (supports batch insert/update).
        :param sql: SQL using named bind parameters (:name style, as required by text())
        :param params_list: list of parameter dicts, one per row (or a single dict)
        :return: number of affected rows
        """
        if not params_list:
            log.info("Empty parameter list, nothing to execute")
            return 0
        try:
            with self.engine.connect() as conn:
                with conn.begin():
                    # Wrap the raw SQL with text()
                    stmt = text(sql)
                    # A list of dicts is executed as an executemany-style batch
                    result = conn.execute(stmt, params_list)
                    affected_rows = result.rowcount
                    log.info(f"SQL executed successfully, affected rows: {affected_rows}")
                    return affected_rows
        except Exception as e:
            log.error(f"SQL execution failed: {str(e)}")
            raise

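    # Usage sketch for execute_sql_with_params() (illustrative; the table and
    # columns are assumptions). Each dict in the list binds one set of :name
    # parameters, so the statement runs once per dict within one transaction.
    #
    #   db = DBHelper()
    #   rows = db.execute_sql_with_params(
    #       "UPDATE t_example_table SET status = :status WHERE id = :id",  # hypothetical table
    #       [{"status": 1, "id": 10}, {"status": 0, "id": 11}],
    #   )
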
    def get_code_exist(self, year_month, prov_code, is_city=True):
        """
        Check whether records exist for the given month and province.
        Args:
            year_month: year-month string (format: YYYY-MM)
            prov_code: province code
        Returns:
            number of matching records, or -1 on query error
        """
        # Resolve the table name from the province code
        # (example logic; replace with your own table-naming rules)
        table_name = self.get_table_name_by_province(prov_code, is_city)
        # Safety check: guard against SQL injection via the table name
        if not table_name.isidentifier():
            raise ValueError(f"Illegal table name: {table_name}")
        # Parameterised query to prevent SQL injection
        query = text(f"""
            SELECT COUNT(*) FROM `{table_name}`
            WHERE crossborder_year_month = :year_month
              AND prov_code = :prov_code
        """)
        try:
            with self.engine.connect() as connection:
                result = connection.execute(
                    query,
                    {"year_month": year_month, "prov_code": prov_code}
                ).scalar()
                return result or 0
        except SQLAlchemyError as e:
            # In a real project, log with more detail
            log.error(f"Query error: {str(e)}")
            return -1  # indicates a query error

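    # Usage sketch for get_code_exist() (illustrative; the month and province
    # code are assumptions). It returns the number of matching rows, 0 if none,
    # or -1 when the query itself fails.
    #
    #   db = DBHelper()
    #   count = db.get_code_exist("2024-03", "370000", is_city=True)
    #   if count <= 0:
    #       log.info("No rows yet for this month/province")
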
    # Helper: map a province code to its table name (example implementation; adjust as needed)
    def get_table_name_by_province(self, prov_code, is_city=True):
        """
        Return the table name for the given province code and data type.
        """
        # Primary table-name mapping rules
        # 350000 = Fujian, 370000 = Shandong, 410000 = Henan, 440000 = Guangdong
        table_mapping = {
            "350000": "t_yujin_crossborder_prov_region_trade",
            "370000": "t_yujin_crossborder_prov_region_trade",
            "410000": "t_yujin_crossborder_prov_commodity_trade",
            "440000": self.get_guangdong_table(is_city)  # Guangdong is handled specially
        }
        if prov_code not in table_mapping:
            raise ValueError(f"Unsupported province code: {prov_code}")
        return table_mapping[prov_code]

    # Handle Guangdong's special case
    def get_guangdong_table(self, is_city):
        """
        Return the Guangdong table name for the given data type.
        """
        if is_city:
            return "t_yujin_crossborder_prov_region_trade"
        else:
            return "t_yujin_crossborder_prov_commodity_trade"
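

# Minimal end-to-end sketch (illustrative only; it requires network access to
# the configured MySQL instance and the crossborder.utils packages, and simply
# counts rows in one of the tables referenced above).
if __name__ == "__main__":
    db = DBHelper()
    sample = db.query(
        "SELECT COUNT(*) AS n FROM t_yujin_crossborder_prov_region_trade"
    )
    log.info(f"Row count: {sample.iloc[0]['n']}")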