zhangfan пре 1 месец
родитељ
комит
4be56fc798

+ 59 - 25
crossborder/utils/base_mysql.py

@@ -1,8 +1,10 @@
+from urllib.parse import quote_plus
+
 import pymysql
 from sqlalchemy import create_engine, text
-from urllib.parse import quote_plus
 
 from crossborder.utils.log import log
+from crossborder.utils.crypto_utils import AESCryptor
 
 provinces = [
     "北京市", "天津市", "上海市", "重庆市",
@@ -16,16 +18,30 @@ provinces = [
     "宁夏回族自治区", "新疆维吾尔自治区"
 ]
 
+cryptor = AESCryptor("uat_ff419620e7047a3c372e2513c5a2b9a5")
+
 # 数据库配置
 DB_CONFIG = {
     'host': '10.130.75.149',
     'port': 3307,
     'user': 'yto_crm',
-    'password': '%3sFUlsolaRI',
+    'password': 'ENC(Fl9g4899OmVYddM42Rt2fA==:sDy1QG/7bmx/iHo4xEOBGQ==)',
+    # 'password': '%3sFUlsolaRI',
     'database': 'crm_uat',
     'charset': 'utf8mb4'
 }
 
+# 修改解密函数
+def get_decrypted_password():
+    encrypted_pass = DB_CONFIG['password']
+    if encrypted_pass.startswith("ENC("):
+        try:
+            return cryptor.decrypt(encrypted_pass)
+        except Exception as e:
+            log.error(f"密码解密失败: {str(e)}")
+            raise
+    return encrypted_pass
+
 def get_commodity_id(commodity_name):
     """根据商品名称查询数据库,获取商品 ID 和商品名称"""
     fix_commodity_name = commodity_name
@@ -33,9 +49,12 @@ def get_commodity_id(commodity_name):
         fix_commodity_name = commodity_name.rsplit("(")[0] or commodity_name.rsplit("(")[0]
     fix_commodity_name = fix_commodity_name.replace('*', '').replace('#', '').replace('“', '').replace('”', '').replace('。', '')
 
+    connection = None
     try:
         # 连接数据库
-        connection = pymysql.connect(**DB_CONFIG)
+        db_config = DB_CONFIG.copy()
+        db_config['password'] = get_decrypted_password()
+        connection = pymysql.connect(**db_config)
         with connection.cursor() as cursor:
             # 执行查询
             sql = "SELECT e.id, e.commodity_name FROM t_yujin_crossborder_prov_commodity_category e WHERE e.commodity_name like %s"
@@ -75,7 +94,9 @@ def get_commodity_id(commodity_name):
 def get_hs_all():
     try:
         # 连接数据库
-        connection = pymysql.connect(**DB_CONFIG)
+        db_config = DB_CONFIG.copy()
+        db_config['password'] = get_decrypted_password()
+        connection = pymysql.connect(**db_config)
         with connection.cursor() as cursor:
             # 执行查询
             sql = "SELECT e.id,e.category_name FROM t_yujin_crossborder_hs_category e"
@@ -95,7 +116,9 @@ def get_hs_all():
 def get_code_exist(crossborder_year_month, prov_code):
     try:
         # 使用 with 自动管理连接生命周期
-        with pymysql.connect(**DB_CONFIG) as connection:
+        db_config = DB_CONFIG.copy()
+        db_config['password'] = get_decrypted_password()
+        with pymysql.connect(**db_config) as connection:
             with connection.cursor() as cursor:
                 # 执行查询
                 sql = """
@@ -111,17 +134,6 @@ def get_code_exist(crossborder_year_month, prov_code):
         log.info(f"[数据库查询异常] 查询条件: {crossborder_year_month}, {prov_code} | 错误详情: {str(e)}")
         return 0
 
-
-# 对密码进行 URL 编码
-encoded_password = quote_plus(DB_CONFIG["password"])
-
-# 构建 SQLAlchemy 引擎
-engine = create_engine(
-    f"mysql+pymysql://{DB_CONFIG['user']}:{encoded_password}@{DB_CONFIG['host']}:{DB_CONFIG['port']}/{DB_CONFIG['database']}?charset={DB_CONFIG['charset']}",
-    pool_size=5,
-    max_overflow=10
-)
-
 def bulk_insert(sql_statements):
     """
     批量执行 SQL 插入语句
@@ -131,16 +143,36 @@ def bulk_insert(sql_statements):
         log.info("未提供有效的 SQL 插入语句,跳过操作")
         return
 
+    connection = None
     try:
-        with engine.connect() as conn:
-            with conn.begin():
-                for sql in sql_statements:
-                    stmt = text(sql.strip())
-                    conn.execute(stmt)
-                log.info(f"成功执行 {len(sql_statements)} 条 SQL 插入语句")
+        # 使用解密后的密码创建连接
+        db_config = DB_CONFIG.copy()
+        db_config['password'] = get_decrypted_password()
+
+        # 创建连接并开启事务
+        connection = pymysql.connect(**db_config)
+        connection.begin()  # 显式开始事务
+
+        with connection.cursor() as cursor:
+            # 遍历执行所有 SQL 语句
+            for sql in sql_statements:
+                # 移除 SQL 两端空白并执行
+                cursor.execute(sql.strip())
+
+            # 提交事务
+            connection.commit()
+            log.info(f"成功执行 {len(sql_statements)} 条 SQL 插入语句")
+
     except Exception as e:
+        # 回滚事务并记录错误
+        if connection:
+            connection.rollback()
         log.info(f"数据库操作失败: {str(e)}")
         raise
+    finally:
+        # 确保连接关闭
+        if connection:
+            connection.close()
 
 def update_january_yoy(prov_name):
     """
@@ -397,14 +429,16 @@ def _update_shandong_new_yoy_origin(region_name):
         return result.rowcount
 
 if __name__ == '__main__':
+    commodity_code, commodity_name_fix = get_commodity_id('农产品')
+    print(commodity_code, commodity_name_fix)
     # check_year, check_month = 2024, 4
     # count = get_code_exist(f'{check_year}-{check_month:02d}', "340000")
     # print(count)
 
     # 新表更新地级市同比
-    for province in provinces:
-        update_january_yoy(province)
-        update_shandong_yoy(province)
+    # for province in provinces:
+    #     update_january_yoy(province)
+    #     update_shandong_yoy(province)
 
     # 旧表更新省份同比
     # for province in provinces:

+ 45 - 0
crossborder/utils/crypto_utils.py

@@ -0,0 +1,45 @@
+import base64
+import os
+
+from Crypto.Cipher import AES
+from Crypto.Protocol.KDF import PBKDF2
+from Crypto.Util.Padding import pad, unpad
+
+
+class AESCryptor:
+    def __init__(self, password, salt=None, iterations=1000):
+        """
+        初始化 AES 加密器
+        :param password: 主密码(用于生成密钥)
+        :param salt: 盐值(建议固定或通过环境变量注入)
+        :param iterations: 密钥派生迭代次数
+        """
+        self.password = password
+        self.salt = salt or os.getenv("AES_SALT", "default_salt")  # 生产环境应替换为安全盐值
+        self.iterations = iterations
+        self.key = PBKDF2(self.password, self.salt.encode(), dkLen=32, count=self.iterations)
+
+    def encrypt(self, raw_text):
+        """AES加密(CBC模式)"""
+        cipher = AES.new(self.key, AES.MODE_CBC)
+        ct_bytes = cipher.encrypt(pad(raw_text.encode(), AES.block_size))
+        iv = base64.b64encode(cipher.iv).decode('utf-8')
+        ct = base64.b64encode(ct_bytes).decode('utf-8')
+        return f"ENC({iv}:{ct})"
+
+    def decrypt(self, encrypted_text):
+        """AES解密(CBC模式)"""
+        if not encrypted_text.startswith("ENC("):
+            return encrypted_text
+
+        try:
+            encrypted_data = encrypted_text[4:-1]  # 去除 ENC() 包裹
+            iv_str, ct_str = encrypted_data.split(":")
+            iv = base64.b64decode(iv_str)
+            ct = base64.b64decode(ct_str)
+
+            cipher = AES.new(self.key, AES.MODE_CBC, iv)
+            pt = unpad(cipher.decrypt(ct), AES.block_size)
+            return pt.decode('utf-8')
+        except Exception as e:
+            raise ValueError(f"解密失败: {str(e)}")

+ 15 - 12
crossborder/utils/db_helper.py

@@ -1,7 +1,6 @@
-
-
 import pandas as pd
 from sqlalchemy import create_engine, text
+from crossborder.utils.crypto_utils import AESCryptor
 
 from crossborder.utils.log import log
 
@@ -9,26 +8,30 @@ DB_CONFIG = {
     'host': '10.130.75.149',
     'port': 3307,
     'user': 'yto_crm',
-    'password': '%3sFUlsolaRI',
+    'password': 'ENC(Fl9g4899OmVYddM42Rt2fA==:sDy1QG/7bmx/iHo4xEOBGQ==)',
     'database': 'crm_uat',
     'charset': 'utf8mb4'
 }
 
-# DB_CONFIG = {
-#     'host': '10.130.36.185',
-#     'port': 3306,
-#     'user': 'user_ytexp',
-#     'password': 'Rn9ib3L1C4b4%40123',
-#     'database': 'yto_crm',
-#     'charset': 'utf8mb4'
-# }
+cryptor = AESCryptor("uat_ff419620e7047a3c372e2513c5a2b9a5")
 
 
+def get_decrypted_password():
+    encrypted_pass = DB_CONFIG['password']
+    if encrypted_pass.startswith("ENC("):
+        try:
+            return cryptor.decrypt(encrypted_pass)
+        except Exception as e:
+            log.error(f"密码解密失败: {str(e)}")
+            raise
+    return encrypted_pass
 
 class DBHelper:
     def __init__(self):
+        db_config = DB_CONFIG.copy()
+        db_config['password'] = get_decrypted_password()
         self.engine = create_engine(
-            f'mysql+pymysql://{DB_CONFIG["user"]}:{DB_CONFIG["password"]}@{DB_CONFIG["host"]}:{DB_CONFIG["port"]}/{DB_CONFIG["database"]}?charset={DB_CONFIG["charset"]}',
+            f'mysql+pymysql://{db_config["user"]}:{db_config["password"]}@{db_config["host"]}:{db_config["port"]}/{db_config["database"]}?charset={db_config["charset"]}',
             pool_size=5,
             max_overflow=10
         )

+ 12 - 0
crossborder/utils/generate_encrypted_password.py

@@ -0,0 +1,12 @@
+# generate_encrypted_password.py
+from crypto_utils import AESCryptor
+import os
+
+# 使用与 base_mysql.py 相同的主密码和盐值
+AES_PASSWORD = os.getenv("AES_PASSWORD", "uat_ff419620e7047a3c372e2513c5a2b9a5")
+cryptor = AESCryptor(AES_PASSWORD)
+
+# 替换为你的实际密码
+raw_password = '%3sFUlsolaRI'
+encrypted = cryptor.encrypt(raw_password)
+print(f"加密后的密码: {encrypted}")