220 lines
8.3 KiB
Python
220 lines
8.3 KiB
Python
import json
|
|
import random
|
|
import time
|
|
import sys
|
|
import traceback
|
|
from pathlib import Path
|
|
from flask import Flask, request, jsonify, Response, stream_with_context
|
|
from typing import Generator
|
|
|
|
# Add parent directory to path for imports
|
|
parent_dir = Path(__file__).parent.parent
|
|
sys.path.insert(0, str(parent_dir))
|
|
|
|
from api import Client
|
|
from dto import ChatCompletionRequest, ChatCompletionResponse, Choice, Message, ChunkResponse, ChunkChoice, Delta, Model
|
|
from kv import Cache, ChatData
|
|
from solver import Solver
|
|
|
|
|
|
class Application:
    """Flask application exposing a small chat-completion HTTP API.

    Routes registered on ``self.app``:

    * ``GET /``                  -- liveness probe, returns ``"started"``.
    * ``GET /models``            -- static list of advertised model ids.
    * ``POST /chat/completions`` -- runs a completion through ``Client``.

    The solver is handed to every ``Client`` that is created; the cache is
    used to look up and store per-conversation ``ChatData``.
    """

    # Model identifiers advertised by ``GET /models``.
    _MODEL_IDS = ("r1", "deepseek-chat", "deepseek-reasoner")

    # (prompt prefix, title prefix) pairs: a first message starting with the
    # prompt prefix gets a timestamped placeholder title instead of its text.
    _TASK_TITLE_PREFIXES = (
        ("### Task:\nGenerate a concise, 3-5 word", "title_req_"),
        ("### Task:\nSuggest 3-5", "follow_req_"),
        ("### Task:\nGenerate 1-3 broad", "tags_req_"),
    )

    def __init__(self, solver: "Solver", cache: "Cache"):
        """Store the collaborators, create the Flask app and register routes."""
        self.solver = solver
        self.cache = cache
        self.app = Flask(__name__)
        self._setup_routes()

    def _setup_routes(self):
        """Setup Flask routes"""

        @self.app.route('/', methods=['GET'])
        def health():
            return "started", 200

        @self.app.route('/models', methods=['GET'])
        def models():
            models_list = {
                "object": "list",
                "data": [
                    {"id": model_id, "object": "model", "owned_by": "deepseek"}
                    for model_id in self._MODEL_IDS
                ],
            }
            return jsonify(models_list), 200

        @self.app.route('/chat/completions', methods=['POST'])
        def chat():
            return self._handle_chat()

    @staticmethod
    def _extract_api_key(auth_header: str) -> str:
        """Return the credential carried by an ``Authorization`` header value.

        A leading ``Bearer`` scheme (matched case-insensitively) is removed;
        only the prefix is stripped, so a token that happens to contain the
        text ``"Bearer "`` is left intact. A header without that scheme is
        returned whitespace-stripped. A bare ``Bearer`` yields ``""``.
        """
        value = auth_header.strip()
        scheme, _, credential = value.partition(' ')
        if scheme.lower() == 'bearer':
            return credential.strip()
        return value

    @staticmethod
    def _next_message_id(current) -> str:
        """Return ``current`` advanced by 2, treating a falsy value as ``"0"``."""
        return str(int(current or "0") + 2)

    @classmethod
    def _chat_title(cls, first_message: str) -> str:
        """Return the title to store for a chat whose first message is given.

        Title / follow-up / tags task prompts get a timestamped placeholder;
        any other message is used verbatim.
        """
        for prompt_prefix, title_prefix in cls._TASK_TITLE_PREFIXES:
            if first_message.startswith(prompt_prefix):
                return f"{title_prefix}{int(time.time())}"
        return first_message

    def _handle_chat(self):
        """Handle chat completion request.

        Returns:
            * 401 when the ``Authorization`` header is missing or carries no key.
            * 400 when the JSON body cannot be parsed into a
              ``ChatCompletionRequest``, has no messages, or chat setup fails.
            * For ``stream`` requests, a ``text/event-stream`` response. The
              completion runs lazily inside the stream, so a failure is
              reported as a ``data:`` event carrying ``error``/``traceback``.
            * Otherwise a JSON response: the completion on success, or a 500
              with ``error``/``traceback`` when the completion fails.
        """
        print("[DEBUG] _handle_chat called")

        auth_header = request.headers.get('Authorization', '')
        if not auth_header:
            print("[DEBUG] No authorization header")
            return jsonify({"error": "Authorization header required"}), 401

        api_key = self._extract_api_key(auth_header)
        # NOTE(review): this writes a prefix of the caller's credential to stdout.
        print(f"[DEBUG] API key (first 10 chars): {api_key[:10]}...")

        # Validate API key is not empty
        if not api_key:
            print("[DEBUG] API key is empty")
            return jsonify({"error": "API key cannot be empty. Please provide a valid Bearer token."}), 401

        try:
            data = request.get_json()
            print(f"[DEBUG] Request data: {data}")
            req = ChatCompletionRequest.from_dict(data)
            print(f"[DEBUG] Parsed request: model={req.model}, stream={req.stream}, messages={len(req.messages)}, thinking_enabled={req.thinking_enabled}, search_enabled={req.search_enabled}")
        except Exception as e:
            print(f"[DEBUG] Failed to parse request: {e}")
            return jsonify({"error": str(e)}), 400

        # Both the cache key (first message) and the prompt (last message)
        # need at least one message; fail with a clear error instead of an
        # IndexError further down.
        if not req.messages:
            print("[DEBUG] Request has no messages")
            return jsonify({"error": "messages must not be empty"}), 400

        print("[DEBUG] Creating API client")
        api_client = Client(self.solver, api_key)

        try:
            print("[DEBUG] Getting chat data from cache")
            chat_data = self.cache.get_chat_data(api_key, req.messages[0].content)
            print(f"[DEBUG] Cache data: chat_id={chat_data.chat_id}, current_msg_id={chat_data.current_message_id}")

            if not chat_data.chat_id:
                print("[DEBUG] Creating new chat")
                chat_data.chat_id = api_client.create_chat()
                print(f"[DEBUG] Created chat: {chat_data.chat_id}")
            else:
                print("[DEBUG] Using existing chat")
                chat_data.current_message_id = self._next_message_id(
                    chat_data.current_message_id
                )
                print(f"[DEBUG] Updated message ID: {chat_data.current_message_id}")
        except Exception as e:
            print(f"[DEBUG] Error in chat setup: {e}")
            print(f"[DEBUG] Error traceback: {traceback.format_exc()}")
            return jsonify({"error": str(e)}), 400

        def persist_chat():
            """Retitle the chat and store its state; failures are only logged."""
            # Special handling for title/follow-up/tags requests
            title = self._chat_title(req.messages[0].content)
            try:
                api_client.change_title(chat_data.chat_id, title)
                self.cache.set_chat_data(api_key, req.messages[0].content, chat_data)
            except Exception as e:
                print(f"Error saving chat data: {e}")

        def run_completion():
            """Run the completion, persist chat state, return collected pieces."""
            pieces = []

            def collect_response(msg: str):
                pieces.append(msg)

            print(f"[DEBUG] Calling completion with thinking_enabled={req.thinking_enabled}, search_enabled={req.search_enabled}")
            # NOTE(review): the two flags logged above are not forwarded --
            # literal False is passed for both positional arguments. Confirm
            # whether req.thinking_enabled / req.search_enabled belong here.
            api_client.completion(
                chat_data.chat_id,
                chat_data.current_message_id,
                req.messages[-1].content,
                False,
                False,
                collect_response,
            )
            print("[DEBUG] Completion finished")

            persist_chat()
            return pieces

        def stream_events() -> Generator[str, None, None]:
            """Yield the completion as server-sent events, ending with [DONE]."""
            try:
                pieces = run_completion()

                # One completion keeps a single id/timestamp across its chunks.
                completion_id = f"chatcmpl-{random.randint(0, 1000000)}"
                created = int(time.time())

                for piece in pieces:
                    chunk = ChunkResponse(
                        id=completion_id,
                        object="chat.completion.chunk",
                        created=created,
                        model=req.model,
                        choices=[
                            ChunkChoice(
                                index=0,
                                delta=Delta(content=piece),
                                finish_reason=None,
                            )
                        ],
                    )
                    yield f"data: {json.dumps(chunk.to_dict())}\n\n"
                    # Pace the already-collected pieces between events.
                    time.sleep(random.uniform(0.1, 0.2))

                yield "data: [DONE]\n\n"
            except Exception as e:
                # Headers are already sent at this point, so the failure can
                # only be reported in-band as an event.
                error_tb = traceback.format_exc()
                print(f"Error in completion: {e}")
                print(error_tb)
                yield f"data: {json.dumps({'error': str(e), 'traceback': error_tb})}\n\n"

        if req.stream:
            return Response(
                stream_with_context(stream_events()),
                mimetype='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'Access-Control-Allow-Origin': '*',
                },
            )

        # Non-streaming: build the body eagerly so a failure can be reported
        # with an error status instead of a 200 carrying an error payload.
        try:
            answer = "".join(run_completion())
            response = ChatCompletionResponse(
                id=f"chatcmpl-{random.randint(0, 1000000)}",
                object="chat.completion",
                created=int(time.time()),
                model=req.model,
                choices=[
                    Choice(
                        index=0,
                        message=Message(role="assistant", content=answer),
                        finish_reason="stop",
                    )
                ],
            )
            body = json.dumps(response.to_dict())
        except Exception as e:
            error_tb = traceback.format_exc()
            print(f"Error in completion: {e}")
            print(error_tb)
            # NOTE(review): the traceback is returned to the client; consider
            # dropping it from the payload outside of debugging.
            return Response(
                json.dumps({"error": str(e), "traceback": error_tb}),
                status=500,
                mimetype='application/json',
            )

        return Response(body, mimetype='application/json')

    def run(self, host: str = "0.0.0.0", port: int = 8080, debug: bool = False):
        """Run the Flask application"""
        self.app.run(host=host, port=port, debug=debug, threaded=True)

    def close(self):
        """Close the cache, then the solver.

        The solver is closed even when closing the cache raises; the cache
        error then propagates to the caller.
        """
        try:
            self.cache.close()
        finally:
            self.solver.close()
|