db_helper.py 11 KB


  import os
  from urllib.parse import quote

  import pandas as pd
  from sqlalchemy import create_engine, text

  from crossborder.utils.crypto_utils import AESCryptor
  from crossborder.utils.log import log
  # Connection settings for the CRM UAT MySQL instance.
  # Every value may be overridden from the environment, so hosts and
  # credentials do not have to be edited in source. The fallbacks are exactly
  # the values that used to be hard-coded, so behaviour is unchanged when no
  # variable is set.
  DB_CONFIG = {
      'host': os.environ.get('CRM_DB_HOST', '10.130.75.149'),
      'port': int(os.environ.get('CRM_DB_PORT', '3307')),
      'user': os.environ.get('CRM_DB_USER', 'yto_crm'),
      # get_decrypted_password() decrypts values that start with "ENC(";
      # any other value is used verbatim.
      'password': os.environ.get(
          'CRM_DB_PASSWORD',
          'ENC(Fl9g4899OmVYddM42Rt2fA==:sDy1QG/7bmx/iHo4xEOBGQ==)',
      ),
      'database': os.environ.get('CRM_DB_NAME', 'crm_uat'),
      'charset': os.environ.get('CRM_DB_CHARSET', 'utf8mb4'),
  }

  # NOTE(review): the AES key lives next to the ciphertext it protects, which
  # offers little real protection. Supply CRM_DB_AES_KEY from the deployment
  # environment and consider dropping the in-source fallback.
  cryptor = AESCryptor(
      os.environ.get('CRM_DB_AES_KEY', 'uat_ff419620e7047a3c372e2513c5a2b9a5')
  )
  def get_decrypted_password():
      """Return the database password from ``DB_CONFIG['password']`` ready for use.

      A value wrapped as ``ENC(...)`` is passed to the module-level ``cryptor``
      for decryption. Any other value is returned unchanged.

      :return: the plain-text password.
      :raises Exception: whatever ``cryptor.decrypt`` raises; it is logged first.
      """
      raw_password = DB_CONFIG['password']
      # Guard clause: nothing to decrypt for a plain-text value.
      if not raw_password.startswith("ENC("):
          return raw_password
      try:
          return cryptor.decrypt(raw_password)
      except Exception as exc:
          log.error(f"密码解密失败: {str(exc)}")
          raise
  class DBHelper:
      """SQLAlchemy helper for the CRM MySQL database.

      Builds one pooled engine from ``DB_CONFIG``. It provides helpers for
      queries, batch inserts/upserts, and year-over-year (YoY) recalculation on
      ``t_yujin_crossborder_prov_region_trade``.
      """

      # First ``crossborder_year_month`` treated as "new" data by the YoY
      # pipeline. Older months have their YoY cleared; later ones are recomputed.
      _YOY_CUTOFF_MONTH = '2024-01'

      def __init__(self):
          """Create the connection pool (pool_size=5, max_overflow=10)."""
          db_config = DB_CONFIG.copy()
          db_config['password'] = get_decrypted_password()
          # Percent-encode the credentials. A raw '@', ':', '/', '#' or '%' in
          # the password would otherwise corrupt the connection URL.
          user = quote(db_config['user'], safe='')
          password = quote(db_config['password'], safe='')
          self.engine = create_engine(
              f"mysql+pymysql://{user}:{password}"
              f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
              f"?charset={db_config['charset']}",
              pool_size=5,
              max_overflow=10,
              # Check pooled connections before handing them out, so that
              # connections the server has dropped (e.g. after an idle timeout)
              # are transparently replaced.
              pool_pre_ping=True,
          )

      def get_commodity_id(self, name):
          """Return the category ``id`` whose ``commodity_name`` equals *name*.

          :param name: commodity name to look up.
          :return: ``id`` of the first matching row, or None when no row matches.
          """
          with self.engine.connect() as conn:
              return conn.execute(
                  text(
                      "SELECT id FROM t_yujin_crossborder_prov_commodity_category "
                      "WHERE commodity_name = :name"
                  ),
                  {'name': name},
              ).scalar()

      @staticmethod
      def _quote_ident(name):
          """Backtick-quote a MySQL identifier, doubling any embedded backticks."""
          return "`" + str(name).replace("`", "``") + "`"

      @staticmethod
      def _records_without_missing(df):
          """Convert *df* to a list of row dicts, with missing values set to None.

          ``DataFrame.to_dict`` keeps NaN / NaT / pd.NA as-is, and the driver
          cannot send those as SQL values. None binds as SQL NULL.
          """
          is_scalar = pd.api.types.is_scalar
          return [
              {col: (None if is_scalar(val) and pd.isna(val) else val) for col, val in row.items()}
              for row in df.to_dict(orient='records')
          ]

      def bulk_insert(self, df, table_name, conflict_columns=None, update_columns=None):
          """Batch-insert a DataFrame, optionally as a MySQL upsert.

          :param df: rows to write. Column names must equal the table's column
              names, because they are also used as bind-parameter names.
          :param table_name: target table. It is interpolated verbatim into the
              SQL, so it must come from trusted code, never from user input.
          :param conflict_columns: used only as an on/off switch. MySQL's
              ``ON DUPLICATE KEY UPDATE`` fires on whichever PRIMARY/UNIQUE key
              of the table collides, not on these specific columns.
          :param update_columns: columns overwritten with the incoming values on a
              key collision. ``create_time`` is additionally reset to ``NOW()``.
              The upsert clause is added only when both this and
              *conflict_columns* are non-empty.
          :raises Exception: any database error is logged and re-raised. The
              whole batch runs in one transaction and is rolled back on failure.
          """
          if df.empty:
              log.info("空数据集,跳过插入")
              return

          # Quote column names so reserved words (e.g. `desc`, `rank`) work.
          columns = ', '.join(self._quote_ident(col) for col in df.columns)
          placeholders = ', '.join(f":{col}" for col in df.columns)
          sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"

          if conflict_columns and update_columns:
              update_clauses = [
                  f"{self._quote_ident(col)}=VALUES({self._quote_ident(col)})"
                  for col in update_columns
              ]
              # Intentional existing behaviour: every upsert also restamps create_time.
              update_clauses.append("create_time = NOW()")
              sql += f" ON DUPLICATE KEY UPDATE {', '.join(update_clauses)}"

          data = self._records_without_missing(df)
          try:
              # engine.begin() commits on success and rolls back on exception.
              with self.engine.begin() as conn:
                  conn.execute(text(sql), data)
              log.info(f"成功插入/更新 {len(df)} 行到 {table_name}")
          except Exception as e:
              log.error(f"数据库操作失败: {str(e)}")
              raise

      @staticmethod
      def _build_yoy_update(prev_month_format, where_sql):
          """Return the self-join UPDATE that recomputes the three YoY columns.

          Each ``curr`` row is joined to the ``prev`` row of the same
          ``city_code``. The ``prev`` row's ``crossborder_year_month`` must equal
          DATE_FORMAT(curr month minus 1 year, *prev_month_format*). Each YoY
          value is the percentage change, truncated to 4 decimals. A zero
          previous value, or any NULL operand, yields 0.0000.

          :param prev_month_format: MySQL DATE_FORMAT pattern, e.g. '%Y-%m'.
          :param where_sql: trusted SQL fragment placed after WHERE. It may use
              named bind parameters such as ``:prov_name``.
          """
          return text(f"""
              UPDATE t_yujin_crossborder_prov_region_trade AS curr
              INNER JOIN t_yujin_crossborder_prov_region_trade AS prev
                  ON curr.city_code = prev.city_code
                  AND prev.crossborder_year_month = DATE_FORMAT(
                      DATE_SUB(
                          STR_TO_DATE(CONCAT(curr.crossborder_year_month, '-01'), '%Y-%m-%d'),
                          INTERVAL 1 YEAR
                      ),
                      '{prev_month_format}'
                  )
              SET
                  curr.yoy_import_export = COALESCE (
                      TRUNCATE((curr.monthly_total - prev.monthly_total) / NULLIF (prev.monthly_total, 0) * 100, 4),
                      0.0000
                  ),
                  curr.yoy_import = COALESCE (
                      TRUNCATE((curr.monthly_import - prev.monthly_import) / NULLIF (prev.monthly_import, 0) * 100, 4),
                      0.0000
                  ),
                  curr.yoy_export = COALESCE (
                      TRUNCATE((curr.monthly_export - prev.monthly_export) / NULLIF (prev.monthly_export, 0) * 100, 4),
                      0.0000
                  )
              WHERE {where_sql}
          """)

      def update_january_yoy(self, prov_name='福建省', after_month='2023-01'):
          """Recompute January YoY metrics against the previous year's January.

          :param prov_name: province to update. Defaults to 福建省 (Fujian).
          :param after_month: only January rows whose ``crossborder_year_month``
              is strictly greater than this ``YYYY-MM`` value are updated. The
              default '2023-01' reproduces the previous hard-coded threshold.
          :return: the UPDATE's row count.
          :raises RuntimeError: wraps any database error.
          """
          update_sql = self._build_yoy_update(
              prev_month_format='%Y-01',
              where_sql=(
                  "curr.prov_name = :prov_name "
                  "AND curr.crossborder_year_month LIKE '%-01' "
                  "AND curr.crossborder_year_month > :after_month"
              ),
          )
          try:
              with self.engine.begin() as conn:
                  result = conn.execute(
                      update_sql, {'prov_name': prov_name, 'after_month': after_month}
                  )
                  log.info(f"Updated {result.rowcount} rows for {prov_name}")
                  return result.rowcount
          except Exception as e:
              log.error(f"Update failed: {str(e)}")
              raise RuntimeError(f"同比数据更新失败: {str(e)}") from e

      def update_prov_yoy(self, prov_name):
          """Refresh a province's YoY metrics: clear old months, recompute new ones.

          Step 1 NULLs the YoY columns of months before ``_YOY_CUTOFF_MONTH``.
          Step 2 recomputes months from ``_YOY_CUTOFF_MONTH`` on. The steps commit
          in separate transactions, so a failure in step 2 leaves step 1's
          changes committed (the two month ranges do not overlap).

          :param prov_name: province name as stored in ``prov_name``.
          :return: ``{'cleared': <rows cleared>, 'updated': <rows updated>}``.
          """
          try:
              cleared = self.clear_old_prov_yoy(prov_name)
              updated = self._update_prov_new_yoy(prov_name)
          except Exception:
              log.error(f"{prov_name}数据处理失败", exc_info=True)
              raise
          log.info(f"{prov_name}同比处理完成 | 清零:{cleared} 更新:{updated}")
          return {'cleared': cleared, 'updated': updated}

      def clear_old_prov_yoy(self, prov_name):
          """Set the YoY columns to NULL for a province's months before the cutoff.

          Only rows where at least one YoY column is still non-NULL are matched.
          Re-running is therefore cheap, and the count reflects rows that changed.
          (The earlier ``!= 0`` filter dated from when the columns were zeroed.
          It wrongly skipped all-zero rows, leaving them at 0 instead of NULL.)

          :return: the UPDATE's row count.
          """
          clear_sql = text("""
              UPDATE t_yujin_crossborder_prov_region_trade
              SET yoy_import_export = NULL,
                  yoy_export = NULL,
                  yoy_import = NULL
              WHERE prov_name = :prov_name
                AND crossborder_year_month < :cutoff_month
                AND (yoy_import_export IS NOT NULL
                     OR yoy_export IS NOT NULL
                     OR yoy_import IS NOT NULL)
          """)
          try:
              with self.engine.begin() as conn:
                  result = conn.execute(
                      clear_sql,
                      {'prov_name': prov_name, 'cutoff_month': self._YOY_CUTOFF_MONTH},
                  )
                  log.info(f"{prov_name}旧数据清零记录数: {result.rowcount}")
                  return result.rowcount
          except Exception as e:
              log.error(f"旧数据清零失败: {str(e)}")
              raise

      def _update_prov_new_yoy(self, prov_name):
          """Recompute YoY metrics for a province's months from the cutoff on.

          Each month is compared with the same month one year earlier. Rows whose
          previous-year ``monthly_total`` is NULL are left untouched.

          :return: the UPDATE's row count.
          """
          update_sql = self._build_yoy_update(
              prev_month_format='%Y-%m',
              where_sql=(
                  "curr.prov_name = :prov_name "
                  "AND curr.crossborder_year_month >= :cutoff_month "
                  "AND prev.monthly_total IS NOT NULL"
              ),
          )
          with self.engine.begin() as conn:
              result = conn.execute(
                  update_sql,
                  {'prov_name': prov_name, 'cutoff_month': self._YOY_CUTOFF_MONTH},
              )
              log.info(f"{prov_name}新数据更新数: {result.rowcount}")
              return result.rowcount

      def query(self, sql, params=None, return_df=True):
          """Run a read query and return its rows.

          :param sql: a SQLAlchemy executable (e.g. ``text("... :name")``) or a
              plain SQL string. Plain strings are executed as driver-level SQL,
              so they use PyMySQL placeholders (``%s`` / ``%(name)s``). This
              matches how ``pd.read_sql`` treats strings.
          :param params: values for the placeholders (dict, tuple, or None).
          :param return_df: True returns a pandas DataFrame; False returns a
              list of Row objects.
          :return: DataFrame or list of rows.
          :raises Exception: any error is logged and re-raised.
          """
          try:
              with self.engine.connect() as conn:
                  if return_df:
                      result = pd.read_sql(sql, conn, params=params)
                  elif isinstance(sql, str):
                      # Connection.execute() rejects plain strings in
                      # SQLAlchemy 2.0. exec_driver_sql (SQLAlchemy >= 1.4)
                      # runs them as driver-level SQL.
                      result = conn.exec_driver_sql(sql, params).fetchall()
                  else:
                      result = conn.execute(sql, params or {}).fetchall()
                  log.info(f"查询成功,返回 {len(result)} 条记录")
                  return result
          except Exception as e:
              log.error(f"查询失败: {str(e)}")
              raise

      def execute_sql_with_params(self, sql: str, params_list: list):
          """Execute one statement for a batch of parameter sets, in one transaction.

          The placeholder style is chosen by the type of the first parameter set:

          * dicts: *sql* is wrapped in ``text()`` and uses ``:name`` placeholders;
          * tuples/lists: *sql* is sent as driver-level SQL and uses PyMySQL
            ``%s`` placeholders. Any literal ``%`` must then be written ``%%``.

          :param sql: the parameterised SQL statement.
          :param params_list: list of parameter sets, all of the same style.
              An empty list is a no-op.
          :return: affected-row count reported by the driver (0 for an empty list).
          :raises Exception: any database error is logged and re-raised. The
              whole batch is rolled back.
          """
          if not params_list:
              log.info("参数列表为空,跳过执行")
              return 0
          try:
              with self.engine.begin() as conn:
                  if isinstance(params_list[0], (list, tuple)):
                      # text() binds only named parameters. Positional sets go
                      # to the driver, which requires tuples for executemany.
                      result = conn.exec_driver_sql(sql, [tuple(p) for p in params_list])
                  else:
                      result = conn.execute(text(sql), params_list)
                  affected_rows = result.rowcount
              log.info(f"成功执行SQL,受影响行数:{affected_rows}")
              return affected_rows
          except Exception as e:
              log.error(f"SQL执行失败: {str(e)}")
              raise