db_helper.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. import pandas as pd
  2. from sqlalchemy import create_engine, text
  3. from sqlalchemy.exc import SQLAlchemyError
  4. from crossborder.utils.constants import CUSTOMS_CITY_MAPPING
  5. from crossborder.utils.crypto_utils import AESCryptor
  6. from crossborder.utils.log import get_logger
  7. log = get_logger(__name__)
  8. DB_CONFIG = {
  9. 'host': '10.130.75.149',
  10. 'port': 3307,
  11. 'user': 'yto_crm',
  12. 'password': 'ENC(Fl9g4899OmVYddM42Rt2fA==:sDy1QG/7bmx/iHo4xEOBGQ==)',
  13. 'database': 'crm_uat',
  14. 'charset': 'utf8mb4'
  15. }
  16. cryptor = AESCryptor("uat_ff419620e7047a3c372e2513c5a2b9a5")
  17. def get_decrypted_password():
  18. encrypted_pass = DB_CONFIG['password']
  19. if encrypted_pass.startswith("ENC("):
  20. try:
  21. return cryptor.decrypt(encrypted_pass)
  22. except Exception as e:
  23. log.error(f"密码解密失败: {str(e)}")
  24. raise
  25. return encrypted_pass
  26. class DBHelper:
  27. def __init__(self):
  28. db_config = DB_CONFIG.copy()
  29. db_config['password'] = get_decrypted_password()
  30. self.engine = create_engine(
  31. f'mysql+pymysql://{db_config["user"]}:{db_config["password"]}@{db_config["host"]}:{db_config["port"]}/{db_config["database"]}?charset={db_config["charset"]}',
  32. pool_size=5,
  33. max_overflow=10
  34. )
  35. def get_commodity_id(self, name):
  36. """获取商品编码对应的分类ID[1,3](@ref)"""
  37. with self.engine.connect() as conn:
  38. result = conn.execute(
  39. text("SELECT id FROM t_yujin_crossborder_prov_commodity_category WHERE commodity_name = :name"),
  40. {'name': name}
  41. ).fetchone()
  42. return result[0] if result else None
  43. def bulk_insert(self, df, table_name, conflict_columns=None, update_columns=None):
  44. """
  45. 增强版批量插入(支持覆盖更新)
  46. :param df: 要插入的DataFrame
  47. :param table_name: 目标表名
  48. :param conflict_columns: 冲突检测字段列表
  49. :param update_columns: 需要更新的字段列表
  50. """
  51. if df.empty:
  52. log.info("空数据集,跳过插入")
  53. return
  54. # 生成带参数的SQL模板
  55. columns = ', '.join(df.columns)
  56. placeholders = ', '.join([f":{col}" for col in df.columns])
  57. sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
  58. # 添加ON DUPLICATE KEY UPDATE(MySQL语法)
  59. if conflict_columns and update_columns:
  60. # 1. 处理用户指定的更新字段
  61. update_clauses = [f"{col}=VALUES({col})" for col in update_columns]
  62. # 2. 强制添加create_time=NOW()
  63. update_clauses.append("create_time = NOW()") # 新增
  64. # 3. 合并所有更新条件
  65. update_set = ', '.join(update_clauses)
  66. sql += f" ON DUPLICATE KEY UPDATE {update_set}"
  67. # 转换数据为字典列表格式
  68. data = df.to_dict(orient='records')
  69. # print("data:", data)
  70. try:
  71. with self.engine.connect() as conn:
  72. # 显式开启事务
  73. with conn.begin():
  74. # 使用text()包装SQL语句
  75. stmt = text(sql)
  76. # 批量执行
  77. conn.execute(stmt, data)
  78. log.info(f"成功插入/更新 {len(df)} 行到 {table_name}")
  79. except Exception as e:
  80. log.error(f"数据库操作失败: {str(e)}")
  81. raise
  82. def update_prov_yoy(self, prov_name):
  83. """
  84. 完整更新山东省同比数据(包含新旧数据处理)
  85. """
  86. try:
  87. # 步骤1:清理旧数据
  88. cleared = self.clear_old_prov_yoy(prov_name)
  89. # 步骤2:计算新数据
  90. updated = self._update_prov_new_yoy(prov_name)
  91. log.info(f"{prov_name}同比处理完成 | 清零:{cleared} 更新:{updated}")
  92. return {'cleared': cleared, 'updated': updated}
  93. except Exception as e:
  94. log.error(f"{prov_name}数据处理失败", exc_info=True)
  95. raise
  96. def clear_old_prov_yoy(self, prov_name):
  97. """
  98. 清理指定省份2024年前数据的同比指标
  99. """
  100. clear_sql = text("""
  101. UPDATE t_yujin_crossborder_prov_region_trade
  102. SET yoy_import_export = null,
  103. yoy_export = null,
  104. yoy_import = null
  105. WHERE prov_name = :prov_name
  106. AND crossborder_year_month < '2024-01'
  107. AND (yoy_import_export != 0
  108. OR yoy_export != 0
  109. OR yoy_import != 0) -- 优化:仅更新非零记录
  110. """)
  111. try:
  112. with self.engine.begin() as conn:
  113. result = conn.execute(clear_sql, {'prov_name': prov_name})
  114. log.info(f"{prov_name}旧数据清零记录数: {result.rowcount}")
  115. return result.rowcount
  116. except Exception as e:
  117. log.error(f"旧数据清零失败: {str(e)}")
  118. raise
  119. def _update_prov_new_yoy(self,prov_name):
  120. """
  121. 更新2024年及之后的省份城市同比数据
  122. """
  123. update_sql = text("""
  124. UPDATE t_yujin_crossborder_prov_region_trade AS curr
  125. INNER JOIN t_yujin_crossborder_prov_region_trade AS prev
  126. ON curr.city_code = prev.city_code
  127. AND prev.crossborder_year_month = DATE_FORMAT(
  128. DATE_SUB(
  129. STR_TO_DATE(CONCAT(curr.crossborder_year_month, '-01'), '%Y-%m-%d'),
  130. INTERVAL 1 YEAR
  131. ),
  132. '%Y-%m'
  133. )
  134. SET
  135. curr.yoy_import_export = COALESCE (
  136. TRUNCATE((curr.monthly_total - prev.monthly_total) / NULLIF (prev.monthly_total, 0) * 100, 4),
  137. 0.0000
  138. ),
  139. curr.yoy_import = COALESCE (
  140. TRUNCATE((curr.monthly_import - prev.monthly_import) / NULLIF (prev.monthly_import, 0) * 100, 4),
  141. 0.0000
  142. ),
  143. curr.yoy_export = COALESCE (
  144. TRUNCATE((curr.monthly_export - prev.monthly_export) / NULLIF (prev.monthly_export, 0) * 100, 4),
  145. 0.0000
  146. )
  147. WHERE
  148. curr.prov_name = :prov_name
  149. AND curr.crossborder_year_month >= '2024-01'
  150. AND prev.monthly_total IS NOT NULL
  151. """)
  152. with self.engine.begin() as conn:
  153. result = conn.execute(update_sql, {'prov_name': prov_name})
  154. log.info(f"{prov_name}新数据更新数: {result.rowcount}")
  155. return result.rowcount
  156. def query(self, sql, params=None, return_df=True):
  157. try:
  158. with self.engine.connect() as conn:
  159. if return_df:
  160. # 替代方法:使用 SQLAlchemy 结果代理直接创建 DataFrame
  161. result_proxy = conn.execute(text(sql), params or {})
  162. # 更健壮的方式获取列名
  163. columns = [col_desc[0] for col_desc in result_proxy.cursor.description]
  164. # 获取所有数据
  165. data = result_proxy.fetchall()
  166. # 手动创建 DataFrame
  167. df = pd.DataFrame(data, columns=columns)
  168. log.info(f"查询成功,返回 {len(df)} 条记录")
  169. return df
  170. else:
  171. result = conn.execute(text(sql), params or {}).fetchall()
  172. log.info(f"查询成功,返回 {len(result)} 条记录")
  173. return result
  174. except Exception as e:
  175. log.error(f"查询失败: {str(e)}")
  176. # 添加详细信息日志
  177. log.error(f"SQL: {sql}")
  178. log.error(f"Params: {params}")
  179. raise
  180. def execute_sql_with_params(self, sql: str, params_list: list):
  181. """
  182. 执行带参数的SQL语句(支持批量插入/更新)
  183. :param sql: 参数化的SQL语句(如含%s、%s等)
  184. :param params_list: 参数列表,每个元素是一个tuple或dict(根据SQL风格而定)
  185. :return: 受影响行数
  186. """
  187. try:
  188. with self.engine.connect() as conn:
  189. with conn.begin():
  190. # 使用text()包装原始SQL
  191. stmt = text(sql)
  192. # 判断是否为多组参数(批量插入)
  193. if isinstance(params_list[0], (list, tuple)):
  194. result = conn.execute(stmt, params_list)
  195. else:
  196. result = conn.execute(stmt, params_list)
  197. affected_rows = result.rowcount
  198. log.info(f"成功执行SQL,受影响行数:{affected_rows}")
  199. return affected_rows
  200. except Exception as e:
  201. log.error(f"SQL执行失败: {str(e)}")
  202. raise
  203. def get_code_exist(self, year_month, prov_code, is_city=True, customs_name=None, city_names=None):
  204. """
  205. 检查指定月份和地区在表中是否存在记录
  206. 参数:
  207. year_month: 年月字符串 (格式: YYYY-MM)
  208. prov_code: 省份代码
  209. is_city: 是否为城市级数据
  210. customs_name: 海关名称(可选)
  211. city_names: 城市名称列表(可选)
  212. 返回:
  213. 匹配的记录数量
  214. """
  215. # 获取表名
  216. table_name = self.get_table_name_by_province(prov_code, is_city)
  217. # 构建基础查询
  218. base_query = f"SELECT COUNT(*) FROM `{table_name}` WHERE crossborder_year_month = :year_month"
  219. params = {"year_month": year_month}
  220. # 添加地区条件
  221. conditions = []
  222. if customs_name and customs_name in CUSTOMS_CITY_MAPPING:
  223. # 根据海关获取对应的城市
  224. cities = CUSTOMS_CITY_MAPPING[customs_name]
  225. # 添加城市条件
  226. conditions.append(f"city_name IN :cities")
  227. params["cities"] = tuple(cities)
  228. elif city_names:
  229. # 直接使用提供的城市列表
  230. conditions.append(f"city_name IN :cities")
  231. params["cities"] = tuple(city_names)
  232. elif is_city:
  233. # 默认添加省份代码条件
  234. conditions.append(f"prov_code = :prov_code")
  235. params["prov_code"] = prov_code
  236. # 组合查询条件
  237. if conditions:
  238. base_query += " AND " + " AND ".join(conditions)
  239. try:
  240. query = text(base_query)
  241. with self.engine.connect() as connection:
  242. result = connection.execute(query, params).scalar()
  243. return result or 0
  244. except SQLAlchemyError as e:
  245. log.error(f"查询错误: {str(e)}")
  246. return -1
    # Helper: map a province code to its table name (sample implementation; adjust as needed)
  248. def get_table_name_by_province(self, prov_code, is_city=True):
  249. """
  250. 根据省份代码和数据类型返回对应表名
  251. """
  252. # 主要表名映射规则
  253. # 350000=福建,370000=山东,410000=河南,440000=广东
  254. table_mapping = {
  255. "350000": "t_yujin_crossborder_prov_region_trade",
  256. "370000": "t_yujin_crossborder_prov_region_trade",
  257. "410000": "t_yujin_crossborder_prov_commodity_trade",
  258. "440000": self.get_guangdong_table(is_city) # 特殊处理广东省
  259. }
  260. if prov_code not in table_mapping:
  261. raise ValueError(f"不支持省份代码: {prov_code}")
  262. return table_mapping[prov_code]
    # Added method: handles the special case of Guangdong province
  264. def get_guangdong_table(self, is_city):
  265. """
  266. 根据数据类型返回广东省对应的表名
  267. """
  268. if is_city:
  269. return "t_yujin_crossborder_prov_region_trade"
  270. else:
  271. return "t_yujin_crossborder_prov_commodity_trade"
  272. def get_total_info_exist(self, year_month):
  273. query = text(f"""
  274. SELECT COUNT(*) FROM `t_yujin_crossborder_region_trade`
  275. WHERE `year_month` = :year_month
  276. """)
  277. try:
  278. with self.engine.connect() as connection:
  279. result = connection.execute(
  280. query,
  281. {"year_month": year_month}
  282. ).scalar()
  283. return result or 0
  284. except SQLAlchemyError as e:
  285. log.error(f"查询错误: {str(e)}")
  286. return -1 # 表示查询出错