# sbert.py: FastAPI service exposing sentence-transformers embeddings,
# token counting, and embedding-dimension lookup endpoints.
import argparse
import logging
import os
from typing import List

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

import sentence_transformers_pool

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Cache directory for downloaded models; the transformers library reads this
# environment variable itself, so it is only captured here for reference.
transformers_cache = os.environ.get('TRANSFORMERS_CACHE')

# One pool of SentenceTransformer instances per requested model name. Memory
# constraints normally call for limiting how many models are loaded at once;
# this example does not limit the number of models.
transformers_dict = {}
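# sentence_transformers_pool is a local module that is not shown here. The code
# below only assumes its SentenceTransformerPool takes (model_name, max_size)
# and offers get()/release(); a minimal sketch under that assumption (the real
# module may differ, e.g. by creating instances lazily):
#
#   import queue
#
#   class SentenceTransformerPool:
#       def __init__(self, model_name: str, max_size: int = 3):
#           # Pre-create max_size instances and hand them out via a queue
#           self._pool = queue.Queue(maxsize=max_size)
#           for _ in range(max_size):
#               self._pool.put(SentenceTransformer(model_name))
#
#       def get(self) -> SentenceTransformer:
#           # Blocks until an instance is available
#           return self._pool.get()
#
#       def release(self, model: SentenceTransformer) -> None:
#           # Return a borrowed instance to the pool
#           self._pool.put(model)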
app = FastAPI()


class SentenceInput(BaseModel):
    sentences: List[str]
    modelName: str = "all-mpnet-base-v2"


class DimensionsInput(BaseModel):
    modelName: str = "all-mpnet-base-v2"


class CountTokenInput(BaseModel):
    modelName: str = "all-mpnet-base-v2"
    text: str
def load_model(model_name: str) -> SentenceTransformer:
    # Lazily create a pool for this model name, then borrow an instance from it
    if model_name not in transformers_dict:
        transformers_dict[model_name] = sentence_transformers_pool.SentenceTransformerPool(model_name=model_name, max_size=3)
    return transformers_dict[model_name].get()


def release_model(model_name: str, model: SentenceTransformer) -> None:
    # Return the borrowed instance to its pool
    transformers_dict[model_name].release(model)
@app.get("/ping")
async def ping():
    return "ping"
@app.post("/dimensions")
async def dimensions(input_data: DimensionsInput):
model = load_model(input_data.modelName)
try:
return { "model" : input_data.modelName,
"maxSequenceLength" : model.max_seq_length,
"dimension" : model.get_sentence_embedding_dimension() }
finally:
release_model(input_data.modelName, model)
@app.post("/process-sentences")
async def process_sentences(input_data: SentenceInput):
sentences = input_data.sentences
model = load_model(input_data.modelName)
try:
encoded = model.encode(sentences)
return {"embeddings": encoded.tolist(), "model" : input_data.modelName}
finally:
release_model(input_data.modelName, model)
@app.post("/count-tokens")
async def count_tokens(input_data: CountTokenInput):
model = load_model(input_data.modelName)
try:
tokenizer = model.tokenizer
tokens = tokenizer.tokenize(input_data.text)
return { "count" : len(tokens) }
finally:
logging.info("Releasing model")
release_model(input_data.modelName, model)
if __name__ == "__main__":
    # Parse the port number from the command line (defaults to 8000)
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000, help="Specify the port number")
    args = parser.parse_args()
    port = args.port

    uvicorn.run(app, host="0.0.0.0", port=port)
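# Example client call against the endpoints above (hypothetical host/port;
# requires the requests package):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/process-sentences",
#       json={"sentences": ["hello world"], "modelName": "all-mpnet-base-v2"},
#   )
#   print(resp.json()["model"], len(resp.json()["embeddings"]))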