-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathapp.py
75 lines (58 loc) · 1.76 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from dataclasses import dataclass
from typing import Optional, List
from fastapi import FastAPI
import uvicorn
from fastapi.params import Query
from models.context_detector import ContextSimilarity, LinguisticAcceptability
from models.chatbot import Chatbot
from models.config import args
from models import grammar
@dataclass
class Response:
    """Payload returned by the ``/message/`` endpoint for a single chat turn."""

    # The chatbot's reply to the user's sentence.
    message: str
    # Result of ContextSimilarity.predict(human_history, gold_history).
    # NOTE(review): annotated as int, but the predictor's actual return type
    # is not visible in this file — confirm it is not a float score.
    similarity: int
    # Result of LinguisticAcceptability.predict(human_history); same caveat.
    acceptability: int
    # Persona lines as reported by chatbot.get_personality().
    personality: List[str]
    # Output of grammar.correct() applied to the user's sentence.
    correction: str
    # Optional flag; defaults to False and nothing in this file sets it.
    changed: Optional[bool] = False
@dataclass
class Message:
    """Request body accepted by the ``/message/`` endpoint."""

    # Raw text typed by the user; the handler strips surrounding whitespace.
    user_input: str
# ASGI application; every route decorator below registers on it.
app = FastAPI()
# Module-level singletons shared by all requests in this process.
# `chatbot` holds the conversation state (history, persona) that the
# handlers read, shuffle and clear; the two scorers are used by /message/.
# NOTE(review): these constructors live in models.*; whether they load
# model weights (and how slow that is) is not visible here — confirm.
chatbot = Chatbot()
similarity = ContextSimilarity()
linguistic = LinguisticAcceptability()
@app.post("/message/")
async def message(item: Message):
    """Process one user chat turn and return the bot reply plus feedback.

    The stripped user input is sent to the chatbot. The human side of the
    conversation is then scored against the gold history (context
    similarity) and on its own (linguistic acceptability). The input is also
    grammar-corrected, and everything is returned in a ``Response`` together
    with the current persona.
    """
    user_sentence = item.user_input.strip()

    # Send first: the histories are fetched after this call, so they
    # presumably already contain the new turn (depends on Chatbot — confirm).
    bot_reply = chatbot.send(user_sentence)
    human_turns = chatbot.get_human_history()
    gold_turns = chatbot.get_gold_history()

    context_score = similarity.predict(human_turns, gold_turns)
    acceptability_score = linguistic.predict(human_turns)
    corrected_text = grammar.correct(user_sentence)

    payload = Response(
        message=bot_reply,
        similarity=context_score,
        acceptability=acceptability_score,
        personality=chatbot.get_personality(),
        correction=corrected_text,
    )

    # Server-side debug trace of the active persona, one line per entry.
    persona_text = "\n".join(chatbot.get_personality())
    print(f"Current Persona: {persona_text}")
    return payload
@app.get("/personality/")
async def read_persona():
    """Return the chatbot's current persona as given by ``get_personality()``."""
    current_persona = chatbot.get_personality()
    return current_persona
@app.get('/personality/shuffle/')
async def shuffle_persona():
    """Shuffle the chatbot's persona and acknowledge with ``"Success"``."""
    # NOTE(review): a state-changing operation exposed via GET; consider POST
    # (would require updating clients).
    chatbot.shuffle()
    status = "Success"
    return status
# Route paths must start with '/': Starlette matches an anchored regex against
# the request path, which always begins with '/', so a path without the
# leading slash can never match. The trailing slash mirrors the sibling
# routes; redirect_slashes (on by default) also serves '/history/clear'.
@app.get('/history/clear/')
async def clear_history():
    """Erase the chatbot's conversation history and acknowledge with ``"Success"``."""
    # NOTE(review): a state-changing operation exposed via GET; consider POST
    # (would require updating clients).
    chatbot.clear_history()
    return "Success"
if __name__ == "__main__":
    # Running the file directly (not importing it) serves the API with
    # uvicorn on all network interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")