"""
|
测试Dify流式模式功能
|
"""
|
|
import pytest
|
import json
|
from unittest.mock import Mock, patch, MagicMock
|
from app.services.dify_client import DifyClient
|
from config import settings
|
|
|
class TestDifyStreaming:
|
"""测试Dify流式模式"""
|
|
def setup_method(self):
|
"""测试前设置"""
|
self.client = DifyClient()
|
|
def test_process_stream_response_success(self):
|
"""测试成功处理流式响应"""
|
# 模拟流式响应数据
|
stream_data = [
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \"Hello\", \"created_at\": 1705398420}",
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \" World\", \"created_at\": 1705398420}",
|
"data: {\"event\": \"message_end\", \"metadata\": {\"usage\": {\"total_tokens\": 10}}, \"usage\": {\"total_tokens\": 10}}",
|
]
|
|
# 创建模拟响应对象
|
mock_response = Mock()
|
mock_response.iter_lines.return_value = stream_data
|
|
# 测试处理流式响应
|
result = self.client._process_stream_response(mock_response, "test_user")
|
|
# 验证结果
|
assert result is not None
|
assert result["answer"] == "Hello World"
|
assert result["conversation_id"] == "test-conv"
|
assert result["task_id"] == "test-task"
|
assert result["usage"]["total_tokens"] == 10
|
|
def test_process_stream_response_error(self):
|
"""测试处理流式响应错误"""
|
# 模拟错误响应数据
|
stream_data = [
|
"data: {\"event\": \"error\", \"message\": \"API调用失败\", \"code\": \"500\"}",
|
]
|
|
# 创建模拟响应对象
|
mock_response = Mock()
|
mock_response.iter_lines.return_value = stream_data
|
|
# 测试处理流式响应
|
result = self.client._process_stream_response(mock_response, "test_user")
|
|
# 验证结果
|
assert result is None
|
|
def test_process_stream_response_incomplete(self):
|
"""测试处理不完整的流式响应"""
|
# 模拟不完整响应数据(缺少message_end事件)
|
stream_data = [
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \"Hello\", \"created_at\": 1705398420}",
|
]
|
|
# 创建模拟响应对象
|
mock_response = Mock()
|
mock_response.iter_lines.return_value = stream_data
|
|
# 测试处理流式响应
|
result = self.client._process_stream_response(mock_response, "test_user")
|
|
# 验证结果 - 即使没有message_end事件,只要有内容和conversation_id也应该返回结果
|
assert result is not None
|
assert result["answer"] == "Hello"
|
assert result["conversation_id"] == "test-conv"
|
|
@patch('app.services.dify_client.settings')
|
def test_send_message_uses_streaming_when_enabled(self, mock_settings):
|
"""测试当启用流式模式时使用流式发送"""
|
# 设置配置为启用流式模式
|
mock_settings.dify_streaming_enabled = True
|
|
# 模拟流式发送方法
|
with patch.object(self.client, 'send_chat_message_stream') as mock_stream:
|
mock_stream.return_value = {"answer": "test response", "conversation_id": "test-conv"}
|
|
result = self.client.send_message("test query", "test_user")
|
|
# 验证调用了流式方法
|
mock_stream.assert_called_once_with("test query", "test_user", None, None)
|
assert result["answer"] == "test response"
|
|
@patch('app.services.dify_client.settings')
|
def test_send_message_uses_blocking_when_disabled(self, mock_settings):
|
"""测试当禁用流式模式时使用阻塞发送"""
|
# 设置配置为禁用流式模式
|
mock_settings.dify_streaming_enabled = False
|
|
# 模拟阻塞发送方法
|
with patch.object(self.client, 'send_chat_message') as mock_blocking:
|
mock_blocking.return_value = {"answer": "test response", "conversation_id": "test-conv"}
|
|
result = self.client.send_message("test query", "test_user")
|
|
# 验证调用了阻塞方法
|
mock_blocking.assert_called_once_with("test query", "test_user", None, None)
|
assert result["answer"] == "test response"
|
|
@patch('app.services.dify_client.settings')
|
def test_send_message_force_streaming_override(self, mock_settings):
|
"""测试强制流式模式覆盖配置"""
|
# 设置配置为禁用流式模式
|
mock_settings.dify_streaming_enabled = False
|
|
# 模拟流式发送方法
|
with patch.object(self.client, 'send_chat_message_stream') as mock_stream:
|
mock_stream.return_value = {"answer": "test response", "conversation_id": "test-conv"}
|
|
# 强制使用流式模式
|
result = self.client.send_message("test query", "test_user", force_streaming=True)
|
|
# 验证调用了流式方法(覆盖了配置)
|
mock_stream.assert_called_once_with("test query", "test_user", None, None)
|
assert result["answer"] == "test response"
|
|
def test_process_stream_response_with_ping_events(self):
|
"""测试处理包含ping事件的流式响应"""
|
# 模拟包含ping事件的响应数据
|
stream_data = [
|
"data: {\"event\": \"ping\"}",
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \"Hello\", \"created_at\": 1705398420}",
|
"data: {\"event\": \"ping\"}",
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \" World\", \"created_at\": 1705398420}",
|
"data: {\"event\": \"message_end\", \"metadata\": {\"usage\": {\"total_tokens\": 10}}}",
|
]
|
|
# 创建模拟响应对象
|
mock_response = Mock()
|
mock_response.iter_lines.return_value = stream_data
|
|
# 测试处理流式响应
|
result = self.client._process_stream_response(mock_response, "test_user")
|
|
# 验证结果(ping事件应该被忽略)
|
assert result is not None
|
assert result["answer"] == "Hello World"
|
assert result["conversation_id"] == "test-conv"
|
|
def test_process_stream_response_with_invalid_json(self):
|
"""测试处理包含无效JSON的流式响应"""
|
# 模拟包含无效JSON的响应数据
|
stream_data = [
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \"Hello\", \"created_at\": 1705398420}",
|
"data: invalid json data",
|
"data: {\"event\": \"message\", \"task_id\": \"test-task\", \"id\": \"test-msg\", \"conversation_id\": \"test-conv\", \"answer\": \" World\", \"created_at\": 1705398420}",
|
"data: {\"event\": \"message_end\"}",
|
]
|
|
# 创建模拟响应对象
|
mock_response = Mock()
|
mock_response.iter_lines.return_value = stream_data
|
|
# 测试处理流式响应
|
result = self.client._process_stream_response(mock_response, "test_user")
|
|
# 验证结果(无效JSON应该被跳过)
|
assert result is not None
|
assert result["answer"] == "Hello World"
|
assert result["conversation_id"] == "test-conv"
|
|
|
if __name__ == "__main__":
|
pytest.main([__file__])
|