from flask import Flask, request, abort, jsonify
|
import asyncio
|
from database import DatabaseManager
|
from wecom_utils import WeComUtils
|
import time
|
import threading
|
import logging
|
import requests
|
import json
|
from redis_manager import RedisManager
|
import os
|
from config_manager import global_config
|
|
|
# Flask application object; routes below are registered on it.
app = Flask(__name__)

# NOTE(review): the Celery settings below are disabled. Periodic work is
# currently run on a daemon thread instead (see start_background_tasks).
# app.config['CELERY_BROKER_URL'] = 'redis://localhost:6379/0'

# app.config['CELERY_RESULT_BACKEND'] = 'redis://localhost:6379/0'

# celery = Celery(app.name, broker=app.config['CELERY_BROKER_URL'])

# celery.conf.update(app.config)

# Shared database and Redis helpers, constructed once at import time and used
# by the handlers and the background persister in this module.
db_manager = DatabaseManager()

redis_manager = RedisManager()


# Load the configuration from the global config manager and mirror it into
# Flask's config. `config` is also read directly throughout this module
# (e.g. config.REDIS_KEY_PREFIX, config.DIFY_API_KEY).
config = global_config.get_config()

app.config.from_object(config)
|
|
|
def load_config(path='config.json'):
    """Load and return a JSON configuration file.

    Args:
        path: Location of the JSON file. Defaults to ``config.json`` in the
            current working directory (the only location checked before).

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    try:
        # Decode explicitly (and tolerate a UTF-8 BOM) so non-ASCII values
        # load the same way regardless of the platform's locale encoding.
        with open(path, encoding='utf-8-sig') as f:
            return json.load(f)
    except FileNotFoundError:
        raise FileNotFoundError(f"配置文件 {path} 不存在") from None
|
|
|
# Configure root logging for the process: console output plus the file
# `app.log` in the working directory.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: several log messages in this module contain Chinese
        # text, which the platform default encoding may be unable to encode.
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)

logger = logging.getLogger(__name__)
|
|
|
# Application initialization
def initialize_app():
    """Initialize application-level resources.

    Currently this only calls ``db_manager.init_db()``; the bulk data cache
    call (``db_manager.cache_all_data()``) is disabled.

    Raises:
        Exception: Anything raised during initialization is logged together
            with its traceback and then re-raised to the caller.
    """
    try:
        logger.info("Initializing application...")
        db_manager.init_db()
        # db_manager.cache_all_data()
        logger.info("Database initialized and data cached")
    except Exception as e:
        # logger.exception records the traceback. Both callers only log
        # str(e), so without it the root cause of a startup failure is lost.
        logger.exception(f"Initialization error: {e}")
        raise
|
|
|
# Background tasks
def start_background_tasks(interval=60):
    """Start the daemon thread that periodically persists expired messages.

    Args:
        interval: Seconds to sleep between persistence passes (default 60).

    Returns:
        The started ``threading.Thread``; existing callers may ignore it.
    """

    def message_persister():
        logger.info("Message persister started")
        while True:
            try:
                persist_expired_messages()
            except Exception as e:
                # Keep the loop alive, but keep the traceback: this thread has
                # no caller to report failures to.
                logger.exception(f"Message persister error: {e}")
            time.sleep(interval)

    worker = threading.Thread(
        target=message_persister, name="message-persister", daemon=True
    )
    worker.start()
    return worker
|
|
|
def persist_expired_messages():
    """Move idle users' cached messages from Redis into the database.

    Scans Redis for ``{config.REDIS_KEY_PREFIX}:user_messages:*`` keys. For
    each user whose most recent cached message is older than
    ``config.MYSQL_PERSIST_EXPIRE`` seconds, it saves the messages (with the
    user's conversation id) through ``db_manager.save_message_history`` and
    then clears them from Redis.

    Failures are isolated per user: an error for one user (bad data, Redis or
    DB failure) is logged with its traceback and the others are still
    processed.
    """
    user_ids = _scan_message_user_ids()
    current_time = time.time()
    logger.debug(f"执行持久化, keys:{user_ids}")

    for user_id in user_ids:
        try:
            _persist_user_if_expired(user_id, current_time)
        except Exception as e:
            logger.exception(f"Failed to persist messages for user {user_id}: {e}")


def _scan_message_user_ids():
    """Return the distinct user ids that currently have cached messages."""
    prefix = f"{config.REDIS_KEY_PREFIX}:user_messages:"
    user_ids = set()  # SCAN may yield the same key more than once
    cursor = 0
    while True:
        cursor, keys = redis_manager.redis.scan(cursor=cursor, match=prefix + "*")
        for key in keys:
            key_str = key.decode('utf-8') if isinstance(key, bytes) else key
            if not key_str.startswith(prefix):
                continue
            # Strip the known prefix instead of splitting the whole key on ':'
            # so a prefix that itself contains ':' still yields the right id.
            # As before, only the first segment after the prefix is the id.
            user_id = key_str[len(prefix):].split(':', 1)[0]
            if user_id:
                user_ids.add(user_id)
        if cursor == 0:
            break
    return list(user_ids)


def _persist_user_if_expired(user_id, current_time):
    """Persist and clear one user's cached messages once they have gone idle."""
    messages = redis_manager.get_user_messages(user_id)
    # Log only the count: dumping full chat content every pass floods the log.
    logger.debug(f"user_id:{user_id} messages:{len(messages) if messages else 0}")
    if not messages:
        return

    # Assumes "timestamp" is epoch seconds (it is compared with time.time()).
    last_message_time = messages[-1]["timestamp"]
    if current_time - last_message_time <= config.MYSQL_PERSIST_EXPIRE:
        return

    conversation_id = redis_manager.get_user_conversation_id(user_id)
    # NOTE(review): if clear_user_messages deletes the whole list (presumably),
    # messages appended between the read above and this clear are dropped
    # without being persisted; an atomic read-and-clear would avoid that.
    db_manager.save_message_history(user_id, messages, conversation_id)
    redis_manager.clear_user_messages(user_id)
    logger.info(f"Persisted messages for user: {user_id}")
|
|
|
@app.route('/wecom/callback', methods=['GET', 'POST'])
async def wecom_callback():
    """WeCom bot callback endpoint.

    GET: URL-verification handshake. Returns the decrypted ``echostr`` when
    the signature is valid, otherwise aborts with 403.

    POST: decrypts the JSON message and dispatches on ``msgtype``:
      * ``text``: starts AI processing in the background and immediately
        replies with an encrypted empty first stream frame;
      * ``stream``: a stream refresh, answered by handler_stream_request;
      * anything else: ignored with a plain ``'success'`` body.
    A body that cannot be decoded, decrypted or parsed aborts with 400.
    """
    logger.info(f"Received {request.method} request at /wecom/callback")
    _ensure_initialized()

    # Query parameters shared by the GET handshake and POST callbacks.
    msg_signature = request.args.get('msg_signature', '')
    timestamp = request.args.get('timestamp', '')
    nonce = request.args.get('nonce', '')

    if request.method == 'GET':
        # URL verification
        echostr = request.args.get('echostr', '')
        ret, reply_echostr = WeComUtils.verify_url(timestamp, nonce, echostr, msg_signature)
        if ret == 0:
            logger.info("Signature verification successful")
            return reply_echostr
        logger.error("Signature verification failed")
        abort(403, "Signature verification failed")

    # POST: encrypted message callback.
    logger.info(f"POST request params: msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}")

    try:
        post_data = request.data.decode('utf-8')
        logger.info(f"post data: {post_data}")
        _, decrypted = WeComUtils.decrypt_msg(post_data, msg_signature, timestamp, nonce)
        msg_data = json.loads(decrypted.decode('utf-8'))
    except Exception as e:
        logger.error(f"Error processing message: {e}")
        abort(400, "Failed to process message")

    msg_type = msg_data.get('msgtype') if isinstance(msg_data, dict) else None

    if msg_type == 'text':
        stream_id = WeComUtils.generate_stream_id(msg_data['from']['userid'], msg_data['msgid'])
        # handler_text_request never awaits: its Redis and Dify HTTP calls all
        # block. Flask runs an async view to completion on a per-request event
        # loop, so a task created there either runs on that loop before the
        # response is sent (delaying the WeCom reply for the whole Dify call)
        # or is cancelled when the loop is torn down. A dedicated thread lets
        # the callback answer immediately while the reply is produced.
        _start_text_request_worker(msg_data, nonce, timestamp)
        # Reply with an empty first stream frame.
        response_text = WeComUtils.create_response_message(stream_id, first=True)
        _, response = WeComUtils.encrypt_msg(response_text, nonce, timestamp)
        return jsonify(response)

    if msg_type == 'stream':
        # Stream refresh
        return jsonify(handler_stream_request(msg_data, nonce, timestamp))

    logger.info(f"Ignoring unsupported message type: {msg_type}")
    return 'success'


# Serializes the lazy first-request initialization in _ensure_initialized.
_init_lock = threading.Lock()


def _ensure_initialized():
    """Run initialize_app() and start_background_tasks() at most once.

    When started via ``__main__`` the app is initialized eagerly and
    ``app.initialized`` is already set. Otherwise the first request does it.
    The lock stops concurrent first requests from initializing twice and
    starting duplicate persister threads.
    """
    if getattr(app, 'initialized', False):
        return
    with _init_lock:
        if getattr(app, 'initialized', False):
            return
        try:
            initialize_app()
            start_background_tasks()
            app.initialized = True
            logger.info("Application initialized")
        except Exception as e:
            logger.error(f"Initialization failed: {e}")
            abort(500, "Server initialization failed")


def _start_text_request_worker(msg_data, nonce, timestamp):
    """Run handler_text_request to completion on its own daemon thread."""

    def run():
        try:
            asyncio.run(handler_text_request(msg_data, nonce, timestamp))
        except Exception as e:
            logger.exception(f"Text request worker failed: {e}")

    threading.Thread(target=run, name="wecom-text-request", daemon=True).start()
|
|
|
async def handler_text_request(msg_data, nonce, timestamp):
    """Process one decrypted WeCom text message and produce the AI reply.

    Args:
        msg_data: Decrypted callback payload; reads ``from.userid``,
            ``msgid`` and ``text.content``.
        nonce: Callback nonce (not used here).
        timestamp: Callback timestamp (not used here).

    Nothing is returned. send_chat_message writes the reply into Redis under
    the message's stream id, and later ``stream`` refresh callbacks read it
    from there. A message whose stream id already exists in Redis (presumably
    a repeated delivery) is ignored. Failures while storing the message or
    calling Dify are logged rather than raised.

    NOTE(review): despite being ``async`` this coroutine never awaits; all
    Redis and HTTP calls block.
    """
    user_id = msg_data['from']['userid']
    msg_id = msg_data['msgid']
    stream_id = WeComUtils.generate_stream_id(user_id, msg_id)

    # Remove the bot mention (config.WECOM_BOT_NAME) and surrounding spaces.
    msg_content = msg_data['text']['content'].replace(config.WECOM_BOT_NAME, "").strip()
    print(f"{stream_id}-user: {msg_content}")

    if redis_manager.exists_stream(stream_id):
        # This stream is already being handled; nothing to store or send.
        return

    try:
        redis_manager.add_user_message(user_id, 'user', msg_content)
        # Register the stream id so repeats of this message are skipped above.
        redis_manager.add_stream_lock(stream_id, '0')
        logger.info("Stored user message in Redis")
    except Exception as e:
        logger.exception(f"Failed to store user message: {e}")

    conversation_id = redis_manager.get_user_conversation_id(user_id)

    # Call the Dify API; the result is published to the stream in Redis.
    try:
        send_chat_message(
            query=msg_content,
            msgid=msg_id,
            user=user_id,
            inputs={},
            files={},
            conversation_id=conversation_id,
            stream_id=stream_id,
        )
    except Exception as e:
        logger.exception(f"Failed to call Dify API: {e}")
|
|
|
def handler_stream_request(msg_data, nonce, timestamp):
    """Build the encrypted reply to a WeCom ``stream`` refresh callback.

    Looks up ``msg_data['stream']['id']`` in Redis. For a known stream the
    frame comes from ``WeComUtils.create_response_message(stream_id)``. For an
    unknown one, an empty first frame (``first=True``) is sent instead. The
    frame is encrypted with ``nonce``/``timestamp`` and the encrypted payload
    is returned; the encryption status code is discarded.
    """
    stream_id = msg_data['stream']['id']

    if not redis_manager.exists_stream(stream_id):
        # Stream not registered (possibly the text handler has not stored it
        # yet): answer with an empty first frame.
        first_frame = WeComUtils.create_response_message(stream_id, first=True)
        _, encrypted = WeComUtils.encrypt_msg(first_frame, nonce, timestamp)
        logger.debug(f"流式 json: {encrypted}")
        return encrypted

    current_frame = WeComUtils.create_response_message(stream_id)
    _, encrypted = WeComUtils.encrypt_msg(current_frame, nonce, timestamp)
    return encrypted
|
|
|
def send_chat_message(
    query,
    msgid,
    stream_id,
    conversation_id=None,
    user="default_user",
    api_key=config.DIFY_API_KEY,
    base_url=config.DIFY_API_BASE_URL,
    inputs=None,
    auto_generate_name=True,
    files=None,
    timeout=120,
):
    """Send a chat message to Dify (blocking mode) and publish the outcome.

    On success, the assistant answer is appended to the user's Redis history
    and the returned ``conversation_id`` is stored for the user. The answer is
    then split into chunks and written under ``stream_id`` with status 200.

    On failure (timeout, connection error, Dify error response, ...), the
    user-facing error text and its status code are written to the same stream
    so the WeCom client is not left waiting. The error is then raised.

    Args:
        query (str): User input; must be non-empty.
        msgid: WeCom message id (currently unused).
        stream_id: Stream id under which chunks and status are stored.
        conversation_id (str, optional): Dify conversation to continue.
        user (str): End-user identifier; must be non-empty.
        api_key (str): Dify API key (default captured from config at import).
        base_url (str): Dify API base URL (default captured from config at import).
        inputs (dict, optional): App variable values; ``None`` sends ``{}``.
        auto_generate_name (bool): Whether Dify should auto-generate a title.
        files (list, optional): Dify file objects; omitted when empty.
        timeout (float): HTTP timeout in seconds (default 120).

    Returns:
        None.

    Raises:
        ValueError: If ``query``, ``user`` or ``api_key`` is invalid.
        requests.exceptions.RequestException: If the Dify call fails; raised
            after the error has been published to the stream.
    """
    if not query:
        raise ValueError("query 参数不能为空")
    if not user:
        raise ValueError("user 参数不能为空")
    if not api_key or api_key == "YOUR_API_KEY":
        raise ValueError("请设置有效的 API 密钥")

    url = f"{base_url.rstrip('/')}/chat-messages"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "query": query,
        "response_mode": "blocking",
        "user": user,
        "auto_generate_name": auto_generate_name,
        "inputs": inputs if inputs is not None else {},
    }
    if conversation_id:
        payload["conversation_id"] = conversation_id
    if files:
        payload["files"] = files

    res, error = _request_dify(url, headers, payload, conversation_id, timeout)
    _publish_dify_result(user, stream_id, res)
    print(f"{stream_id}-assistant: {res['answer']}")
    if error is not None:
        raise error


def _request_dify(url, headers, payload, conversation_id, timeout):
    """POST the chat request to Dify and classify the outcome.

    Returns ``(res, error)``. ``res`` is a dict with ``answer``,
    ``conversation_id`` and ``status`` keys. ``error`` is ``None`` on HTTP 200,
    otherwise it is the exception send_chat_message should raise.
    """
    try:
        response = requests.post(url=url, headers=headers, json=payload, timeout=timeout)
        if response.status_code == 200:
            body = response.json()
            return {
                "answer": body.get("answer", ""),
                "conversation_id": body.get("conversation_id", ""),
                "status": 200,
            }, None
        error_msg = _dify_error_message(response)
    except requests.exceptions.Timeout as exc:
        return _dify_failure(408, "请求超时,请检查网络连接或稍后重试", conversation_id, exc)
    except requests.exceptions.ConnectionError as exc:
        return _dify_failure(400, "连接失败,请检查网络连接和 API 地址", conversation_id, exc)
    except requests.exceptions.RequestException as exc:
        # Any other requests failure (in requests >= 2.27 this includes an
        # invalid JSON body on a 200 response) is re-raised unchanged.
        return {"answer": "请求错误", "conversation_id": conversation_id, "status": 998}, exc
    except Exception as exc:
        return _dify_failure(999, "未知错误", conversation_id, exc, f"未知错误: {str(exc)}")
    # Non-200 from Dify: show Dify's own error text to the user.
    return _dify_failure(response.status_code, error_msg, conversation_id)


def _dify_error_message(response):
    """Format a non-200 Dify response as "API 错误: <status> - <detail>"."""
    try:
        detail = response.json().get('message', '未知错误')
    except (ValueError, AttributeError):
        # Body is not JSON, or not a JSON object: fall back to the raw text.
        detail = response.text
    return f"API 错误: {response.status_code} - {detail}"


def _dify_failure(status, answer, conversation_id, cause=None, message=None):
    """Build the ``(res, error)`` pair describing a failed Dify call.

    ``answer`` is the user-facing text stored in the stream. The returned
    RequestException carries ``message`` (defaulting to ``answer``) and is
    chained to ``cause`` exactly as ``raise ... from cause`` would do.
    """
    error = requests.exceptions.RequestException(message if message is not None else answer)
    error.__cause__ = cause
    return {"answer": answer, "conversation_id": conversation_id, "status": status}, error


def _publish_dify_result(user, stream_id, res):
    """Store a Dify result in Redis for the user's history and the WeCom stream."""
    if res['status'] == 200:
        # Only real answers become chat history and update the conversation;
        # on failure conversation_id may be None and must not be stored.
        redis_manager.add_user_message(user, 'assistant', res['answer'])
        redis_manager.add_conversation_id(user, res['conversation_id'])
    chunks = WeComUtils.split_string_safely(res['answer'])
    redis_manager.add_stream_chunks(stream_id, chunks)
    redis_manager.add_stream_status(stream_id, res['status'])
|
|
if __name__ == '__main__':
    import sys

    # Initialize eagerly. Setting app.initialized makes the callback's lazy
    # first-request initialization a no-op.
    try:
        initialize_app()
        start_background_tasks()
        app.initialized = True
    except Exception as e:
        # exc_info=True keeps the traceback of the startup failure.
        logger.critical(f"Failed to initialize application: {e}", exc_info=True)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or frozen executables).
        sys.exit(1)

    # Start the development server.
    logger.info("Starting Flask application")
    app.run(host='0.0.0.0', port=5959)
|