Claude Comparet DB
import os
import sys
import logging
import datetime
import psutil
from takeTime import Timer
import pretty_errors
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import fitz
from text_processing import clear_text, is_radar_header
from semantic_text_splitter import TextSplitter
from result_printer import res_printer
from sklearn.metrics.pairwise import cosine_similarity
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import re
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
logging.basicConfig(
    filename="./data/logs/compareT_db_memory_usage.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
pretty_errors.configure(
    display_timestamp=1,
    timestamp_function=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    lines_before=2,
    lines_after=1,
    display_locals=1,
)
# Module-level cache so the embedding model is loaded only once
EMBEDDING_MODEL = None

def initialize_model():
    """Initialize the embedding model with optimal settings once globally."""
    global EMBEDDING_MODEL
    if EMBEDDING_MODEL is None:
        # Check if GPU is available and set device accordingly
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {device}")
        EMBEDDING_MODEL = SentenceTransformer(
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
            device=device,
        )
    return EMBEDDING_MODEL
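
# Usage (illustrative): the model is created on the first call and reused after.
# model = initialize_model()
# vectors = model.encode(["example sentence"], normalize_embeddings=True)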
def log_memory_usage(step=""):
    """Logs the memory usage of the script with an optional step name."""
    process = psutil.Process(os.getpid())
    memory_used = process.memory_info().rss / 1024**2  # Convert bytes to MB
    logging.info(f"Memory Usage after {step}: {memory_used:.2f} MB")
def extract_list_blocks_sem_chunker(file_path):
    """Extract text chunks from a PDF using LangChain with optimized settings."""
    log_memory_usage("Before extracting text")
    try:
        loader = PyPDFLoader(file_path)
        # Splitter parameters are illustrative; the original values were elided
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        # Extract documents
        documents = loader.load_and_split(text_splitter=text_splitter)
        text_chunks = [doc.page_content for doc in documents]
        # Clean each chunk and drop recurring radar headers
        clean_chunks = [clear_text(c) for c in text_chunks if not is_radar_header(c)]
        # Free memory
        del documents
        del text_chunks
        gc.collect()
        return clean_chunks
    except Exception as e:
        logging.error(f"Error extracting text: {e}")
        return []
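
# Example (illustrative; the path below is hypothetical):
# chunks = extract_list_blocks_sem_chunker("./data/input/radar_report.pdf")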
def merge_text_if_needed(chunks):
    """
    Optimize the text merging process by improving boundary detection
    and reducing memory allocations.
    """
    if not chunks:
        return []
    # Create a new list to avoid modifying the input list during iteration
    # NOTE: the boundary-detection logic was elided in the original; this
    # currently returns an unmodified copy
    merged_chunks = chunks.copy()
    return merged_chunks
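
# Sketch (assumption): the boundary-aware merge the docstring describes is not
# present above. One plausible version appends a chunk to its predecessor when
# the predecessor is short or breaks mid-sentence. The name and min_len
# threshold are illustrative, not from the original.
def _merge_on_sentence_boundaries(chunks, min_len=200):
    merged = []
    for chunk in chunks:
        prev = merged[-1] if merged else None
        if prev is not None and (len(prev) < min_len or not prev.rstrip().endswith((".", "!", "?"))):
            merged[-1] = prev.rstrip() + " " + chunk.lstrip()
        else:
            merged.append(chunk)
    return merged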
def get_optimal_batch_size(texts):
    """Pick a batch size from average text length (thresholds are illustrative;
    the original's available-memory tuning was elided)."""
    avg_length = sum(len(text) for text in texts) / len(texts) if texts else 0
    # Longer texts -> smaller batches to bound memory use
    return 16 if avg_length > 1000 else 32 if avg_length > 500 else 64
def get_embeddings_batch(batch):
    """Process a single batch of embeddings."""
    model = initialize_model()
    try:
        embeddings = model.encode(
            batch, normalize_embeddings=True, show_progress_bar=False
        )
        return embeddings
    except Exception as e:
        logging.error(f"Error generating embeddings: {e}")
        # Return empty embeddings for failed batch
        return np.zeros((len(batch), model.get_sentence_embedding_dimension()))
def get_embeddings_parallel(texts, max_workers=4):
    """Generate embeddings for text sections using parallel batches.

    Args:
        texts (list of str): List of text sections to encode
        max_workers (int): Maximum number of parallel workers

    Returns:
        numpy.ndarray: Array of embedding vectors
    """
    # Ensure model is initialized first
    initialize_model()
    if not texts:
        return []
    batch_size = get_optimal_batch_size(texts)
    # Create batches
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    all_embeddings = [None] * len(batches)
    # Encode batches concurrently, preserving batch order in the output
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(get_embeddings_batch, b): i for i, b in enumerate(batches)}
        for future in as_completed(futures):
            batch_idx = futures[future]
            try:
                all_embeddings[batch_idx] = future.result()
            except Exception as e:
                logging.error(f"Error processing batch {batch_idx}: {e}")
                dim = initialize_model().get_sentence_embedding_dimension()
                all_embeddings[batch_idx] = np.zeros((len(batches[batch_idx]), dim))
    return np.vstack(all_embeddings)
def find_matches(query_embeddings, db_embeddings, threshold=0.8, batch_size=256):
    """Return (query, db) index pairs above a cosine-similarity threshold.
    (The function name and threshold/batch_size defaults are illustrative;
    the originals were elided.)"""
    result_indices, result_similarities = [], []
    for i in range(0, len(query_embeddings), batch_size):  # batch queries to bound memory
        similarities = cosine_similarity(query_embeddings[i:i + batch_size], db_embeddings)
        rows, cols = np.where(similarities >= threshold)
        batch_matches = (rows + i, cols)  # shift rows to global query indices
        # Store results
        for q_idx, d_idx in zip(*batch_matches):
            sim_value = similarities[q_idx - i, d_idx]
            if sim_value < 1.0:  # Exclude exact matches (self-matches)
                result_indices.append((q_idx, d_idx))
                result_similarities.append(sim_value)
    return result_indices, result_similarities
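
# Illustrative check of the matching logic above (values made up):
# cosine_similarity on two identical unit vectors yields 1.0, which the
# `sim_value < 1.0` guard filters out as a self-match; near-duplicates
# (e.g. 0.93) are kept as (query_idx, db_idx) pairs.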
def db_connect():
    """Database connection function - imported from db_worker module."""
    from db_worker import db_connect
    return db_connect()
def main():
    radar_file = sys.argv[1]
    if not os.path.exists(radar_file):
        logging.error("Compare - Radar file not found.")
        sys.exit(1)
    log_memory_usage("Start of main()")
    # Extract, clean, and merge text blocks from the radar PDF
    radar_blocks = merge_text_if_needed(extract_list_blocks_sem_chunker(radar_file))
    num_workers = 4  # assumption: the original worker count was elided
    radar_embeddings = get_embeddings_parallel(radar_blocks, max_workers=num_workers)
    log_memory_usage("After generating embeddings")
    conn = db_connect()
    # NOTE: the retrieval query was elided in the original; db_data_list is
    # expected to hold (block_text, embedding) rows fetched via db_worker
    db_data_list = []
    if not db_data_list:
        logging.error("No embeddings retrieved from database.")
        sys.exit(1)
    try:
        db_embeddings = np.array([row[1] for row in db_data_list])
        results = find_matches(radar_embeddings, db_embeddings)
        # Output results
        res_printer(results, "compareT_db", "RALF")
    except Exception as e:
        logging.error(f"Error during result processing or printing: {e}")
        print(f"Error during result processing or printing: {e}")
if __name__ == "__main__":
    main()
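
# Run (illustrative; the script filename is assumed):
#   python compareT_db.py path/to/radar.pdf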