from sqlalchemy import create_engine
|
from sqlalchemy.orm import sessionmaker
|
from models import Base, ChatGroup, User, GroupUser, GroupSMSAccount, MessageHistory
|
from redis_manager import RedisManager
|
from config_manager import global_config
|
|
|
class DatabaseManager:
    """Hold the SQLAlchemy engine, its session factory and a Redis manager.

    The MySQL connection settings are read from ``global_config`` when the
    manager is constructed.
    """

    def __init__(self):
        """Create the engine, the session factory and the Redis manager."""
        config = global_config.get_config()
        self.engine = create_engine(self._build_url(config))
        self.Session = sessionmaker(bind=self.engine)
        self.redis = RedisManager()

    @staticmethod
    def _build_url(config):
        """Return the ``mysql+pymysql`` connection URL built from *config*.

        Args:
            config: Object exposing ``MYSQL_USER``, ``MYSQL_PASSWORD``,
                ``MYSQL_HOST``, ``MYSQL_PORT`` and ``MYSQL_DB``.

        Returns:
            The connection URL as a string.
        """
        from urllib.parse import quote

        # Percent-encode the credentials: a raw "@", ":", "/" or "#" in the
        # user name or password would otherwise be read as URL structure.
        # ``safe=""`` makes sure "/" is encoded too.
        # NOTE(review): assumes the configured credentials are stored raw,
        # not already percent-encoded - confirm against the config source.
        user = quote(str(config.MYSQL_USER), safe="")
        password = quote(str(config.MYSQL_PASSWORD), safe="")
        return (
            f"mysql+pymysql://{user}:{password}"
            f"@{config.MYSQL_HOST}:{config.MYSQL_PORT}/{config.MYSQL_DB}"
        )

    def init_db(self):
        """Initialize the database tables via ``Base.metadata.create_all``."""
        Base.metadata.create_all(self.engine)

    # NOTE: disabled code, kept for reference. It calls ``json.dumps`` but
    # ``json`` is not among the imports visible at the top of this file.
    #
    # def cache_all_data(self):
    #     """Cache all data in Redis."""
    #     session = self.Session()
    #
    #     try:
    #         # Cache every group
    #         groups = session.query(ChatGroup).all()
    #         for group in groups:
    #             self.redis.cache_group_data(group.id, "info", {"id": group.id, "name": group.name})
    #
    #             # Cache the group's members
    #             user_ids = [gu.user_id for gu in group.members]
    #             self.redis.cache_group_data(group.id, "users", user_ids)
    #
    #             # Cache the group's SMS accounts
    #             sms_accounts = [ga.sms_account for ga in group.sms_accounts]
    #             self.redis.cache_group_data(group.id, "sms_accounts", sms_accounts)
    #
    #         # Cache every user
    #         users = session.query(User).all()
    #         for user in users:
    #             self.redis.set(f"user:{user.id}", json.dumps({"id": user.id, "name": user.name}))
    #
    #     finally:
    #         session.close()

    def save_message_history(self, user_id, messages, conversation_id):
        """Save message history to the database in a single transaction.

        Args:
            user_id: Value stored as ``user_id`` on every row.
            messages: Iterable of mappings; each one is read through its
                ``"role"`` and ``"content"`` keys.
            conversation_id: Value stored as ``conversation_id`` on every row.

        Raises:
            Exception: Any error raised while building, adding or committing
                the rows is re-raised after the session has been rolled back.
        """
        session = self.Session()
        try:
            for msg in messages:
                session.add(
                    MessageHistory(
                        user_id=user_id,
                        conversation_id=conversation_id,
                        role=msg["role"],
                        content=msg["content"],
                    )
                )
            session.commit()
        except Exception:
            session.rollback()
            # Bare ``raise`` re-raises the active exception unchanged.
            raise
        finally:
            # Always release the session, on success and on failure.
            session.close()
|