-
-
Notifications
You must be signed in to change notification settings - Fork 16
[#441] Preload embedding model at startup #461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
59b40f0
3ffb74a
12b09a7
da9afaa
5ce7782
50a8bd3
795f218
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,3 +4,7 @@ | |
| class ApiConfig(AppConfig): | ||
| default_auto_field = 'django.db.models.BigAutoField' | ||
| name = 'api' | ||
|
|
||
| def ready(self): | ||
| from .services.sentencetTransformer_model import TransformerModel | ||
| TransformerModel.get_instance() | ||
|
Comment on lines
+8
to
+10
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,18 +11,17 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| def get_closest_embeddings( | ||
| user, message_data, document_name=None, guid=None, num_results=10 | ||
| ): | ||
|
|
||
| def build_query(user, embedding_vector, document_name=None, guid=None, num_results=10): | ||
| """ | ||
| Find the closest embeddings to a given message for a specific user. | ||
| Build an unevaluated QuerySet for the closest embeddings. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| user : User | ||
| The user whose uploaded documents will be searched | ||
| message_data : str | ||
| The input message to find similar embeddings for | ||
| embedding_vector : array-like | ||
| Pre-computed embedding vector to compare against | ||
| document_name : str, optional | ||
| Filter results to a specific document name | ||
| guid : str, optional | ||
|
|
@@ -32,59 +31,52 @@ def get_closest_embeddings( | |
|
|
||
| Returns | ||
| ------- | ||
| list[dict] | ||
| List of dictionaries containing embedding results with keys: | ||
| - name: document name | ||
| - text: embedded text content | ||
| - page_number: page number in source document | ||
| - chunk_number: chunk number within the document | ||
| - distance: L2 distance from query embedding | ||
| - file_id: GUID of the source file | ||
| QuerySet | ||
| Unevaluated Django QuerySet ordered by L2 distance, sliced to num_results | ||
| """ | ||
|
|
||
| encoding_start = time.time() | ||
| transformerModel = TransformerModel.get_instance().model | ||
| embedding_message = transformerModel.encode(message_data) | ||
| encoding_time = time.time() - encoding_start | ||
|
|
||
| db_query_start = time.time() | ||
|
|
||
| # Django QuerySets are lazily evaluated | ||
| if user.is_authenticated: | ||
| # User sees their own files + files uploaded by superusers | ||
| closest_embeddings_query = ( | ||
| Embeddings.objects.filter( | ||
| Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) | ||
| ) | ||
| .annotate( | ||
| distance=L2Distance("embedding_sentence_transformers", embedding_message) | ||
| ) | ||
| .order_by("distance") | ||
| queryset = Embeddings.objects.filter( | ||
| Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) | ||
| ) | ||
| else: | ||
| # Unauthenticated users only see superuser-uploaded files | ||
| closest_embeddings_query = ( | ||
| Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) | ||
| .annotate( | ||
| distance=L2Distance("embedding_sentence_transformers", embedding_message) | ||
| ) | ||
| .order_by("distance") | ||
| ) | ||
| queryset = Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) | ||
|
|
||
| queryset = ( | ||
| queryset | ||
| .annotate(distance=L2Distance("embedding_sentence_transformers", embedding_vector)) | ||
| .order_by("distance") | ||
| ) | ||
|
|
||
| # Filtering to a document GUID takes precedence over a document name | ||
| if guid: | ||
| closest_embeddings_query = closest_embeddings_query.filter( | ||
| upload_file__guid=guid | ||
| ) | ||
| queryset = queryset.filter(upload_file__guid=guid) | ||
| elif document_name: | ||
| closest_embeddings_query = closest_embeddings_query.filter(name=document_name) | ||
| queryset = queryset.filter(name=document_name) | ||
|
|
||
| # Slicing is equivalent to SQL's LIMIT clause | ||
| closest_embeddings_query = closest_embeddings_query[:num_results] | ||
| return queryset[:num_results] | ||
|
Comment on lines
15
to
+60
|
||
|
|
||
|
|
||
| def evaluate_query(queryset): | ||
| """ | ||
| Evaluate a QuerySet and return a list of result dicts. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| queryset : iterable | ||
| Iterable of Embeddings objects (or any objects with the expected attributes) | ||
|
|
||
| Returns | ||
| ------- | ||
| list[dict] | ||
| List of dicts with keys: name, text, page_number, chunk_number, distance, file_id | ||
| """ | ||
| # Iterating evaluates the QuerySet and hits the database | ||
| # TODO: Research improving the query evaluation performance | ||
| results = [ | ||
| return [ | ||
| { | ||
| "name": obj.name, | ||
| "text": obj.text, | ||
|
|
@@ -93,13 +85,36 @@ def get_closest_embeddings( | |
| "distance": obj.distance, | ||
| "file_id": obj.upload_file.guid if obj.upload_file else None, | ||
| } | ||
| for obj in closest_embeddings_query | ||
| for obj in queryset | ||
| ] | ||
|
|
||
| db_query_time = time.time() - db_query_start | ||
|
|
||
| def log_usage( | ||
| results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time | ||
| ): | ||
| """ | ||
| Create a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| results : list[dict] | ||
| The search results, each containing a "distance" key | ||
| message_data : str | ||
| The original search query text | ||
| user : User | ||
| The user who performed the search | ||
| guid : str or None | ||
| Document GUID filter used in the search | ||
| document_name : str or None | ||
| Document name filter used in the search | ||
| num_results : int | ||
| Number of results requested | ||
| encoding_time : float | ||
| Time in seconds to encode the query | ||
| db_query_time : float | ||
| Time in seconds for the database query | ||
| """ | ||
| try: | ||
| # Handle user having no uploaded docs or doc filtering returning no matches | ||
| if results: | ||
| distances = [r["distance"] for r in results] | ||
| SemanticSearchUsage.objects.create( | ||
|
|
@@ -113,11 +128,10 @@ def get_closest_embeddings( | |
| num_results_returned=len(results), | ||
| max_distance=max(distances), | ||
| median_distance=median(distances), | ||
| min_distance=min(distances) | ||
| min_distance=min(distances), | ||
| ) | ||
| else: | ||
| logger.warning("Semantic search returned no results") | ||
|
|
||
| SemanticSearchUsage.objects.create( | ||
| query_text=message_data, | ||
| user=user if (user and user.is_authenticated) else None, | ||
|
|
@@ -129,9 +143,58 @@ def get_closest_embeddings( | |
| num_results_returned=0, | ||
| max_distance=None, | ||
| median_distance=None, | ||
| min_distance=None | ||
| min_distance=None, | ||
| ) | ||
| except Exception as e: | ||
| logger.error(f"Failed to create semantic search usage database record: {e}") | ||
|
|
||
|
|
||
| def get_closest_embeddings( | ||
| user, message_data, document_name=None, guid=None, num_results=10 | ||
| ): | ||
| """ | ||
| Find the closest embeddings to a given message for a specific user. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| user : User | ||
| The user whose uploaded documents will be searched | ||
| message_data : str | ||
| The input message to find similar embeddings for | ||
| document_name : str, optional | ||
| Filter results to a specific document name | ||
| guid : str, optional | ||
| Filter results to a specific document GUID (takes precedence over document_name) | ||
| num_results : int, default 10 | ||
| Maximum number of results to return | ||
|
|
||
| Returns | ||
| ------- | ||
| list[dict] | ||
| List of dictionaries containing embedding results with keys: | ||
| - name: document name | ||
| - text: embedded text content | ||
| - page_number: page number in source document | ||
| - chunk_number: chunk number within the document | ||
| - distance: L2 distance from query embedding | ||
| - file_id: GUID of the source file | ||
|
|
||
| Notes | ||
| ----- | ||
| Creates a SemanticSearchUsage record. Swallows exceptions so search isn't interrupted. | ||
| """ | ||
| encoding_start = time.time() | ||
| model = TransformerModel.get_instance().model | ||
| embedding_vector = model.encode(message_data) | ||
| encoding_time = time.time() - encoding_start | ||
|
|
||
| db_query_start = time.time() | ||
| queryset = build_query(user, embedding_vector, document_name, guid, num_results) | ||
| results = evaluate_query(queryset) | ||
| db_query_time = time.time() - db_query_start | ||
|
|
||
| log_usage( | ||
| results, message_data, user, guid, document_name, num_results, encoding_time, db_query_time | ||
| ) | ||
|
|
||
| return results | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from api.services.embedding_services import evaluate_query, log_usage | ||
|
|
||
|
|
||
| def test_evaluate_query_maps_fields(): | ||
| obj = MagicMock() | ||
| obj.name = "doc.pdf" | ||
| obj.text = "some text" | ||
| obj.page_num = 3 | ||
| obj.chunk_number = 1 | ||
| obj.distance = 0.42 | ||
| obj.upload_file.guid = "abc-123" | ||
|
|
||
| results = evaluate_query([obj]) | ||
|
|
||
| assert results == [ | ||
| { | ||
| "name": "doc.pdf", | ||
| "text": "some text", | ||
| "page_number": 3, | ||
| "chunk_number": 1, | ||
| "distance": 0.42, | ||
| "file_id": "abc-123", | ||
| } | ||
| ] | ||
|
|
||
|
|
||
| def test_evaluate_query_none_upload_file(): | ||
| obj = MagicMock() | ||
| obj.name = "doc.pdf" | ||
| obj.text = "some text" | ||
| obj.page_num = 1 | ||
| obj.chunk_number = 0 | ||
| obj.distance = 1.0 | ||
| obj.upload_file = None | ||
|
|
||
| results = evaluate_query([obj]) | ||
|
|
||
| assert results[0]["file_id"] is None | ||
|
|
||
|
|
||
| @patch("api.services.embedding_services.SemanticSearchUsage.objects.create") | ||
| def test_log_usage_computes_distance_stats(mock_create): | ||
| results = [{"distance": 1.0}, {"distance": 3.0}, {"distance": 2.0}] | ||
| user = MagicMock(is_authenticated=True) | ||
|
|
||
| log_usage( | ||
| results, | ||
| message_data="test query", | ||
| user=user, | ||
| guid=None, | ||
| document_name=None, | ||
| num_results=10, | ||
| encoding_time=0.1, | ||
| db_query_time=0.2, | ||
| ) | ||
|
|
||
| mock_create.assert_called_once() | ||
| kwargs = mock_create.call_args.kwargs | ||
| assert kwargs["min_distance"] == 1.0 | ||
| assert kwargs["max_distance"] == 3.0 | ||
| assert kwargs["median_distance"] == 2.0 | ||
| assert kwargs["num_results_returned"] == 3 | ||
|
|
||
|
|
||
| @patch( | ||
| "api.services.embedding_services.SemanticSearchUsage.objects.create", | ||
| side_effect=Exception("DB error"), | ||
| ) | ||
| def test_log_usage_swallows_exceptions(mock_create): | ||
| results = [{"distance": 1.0}] | ||
| user = MagicMock(is_authenticated=True) | ||
|
|
||
| # pytest fails the test if it catches unhandled Exception | ||
| log_usage( | ||
| results, | ||
| message_data="test query", | ||
| user=user, | ||
| guid=None, | ||
| document_name=None, | ||
| num_results=10, | ||
| encoding_time=0.1, | ||
| db_query_time=0.2, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| [pytest] | ||
| DJANGO_SETTINGS_MODULE = balancer_backend.settings | ||
| pythonpath = . |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,4 +19,6 @@ PyMuPDF==1.24.0 | |
| Pillow | ||
| pytesseract | ||
| anthropic | ||
| drf-spectacular | ||
| pytest | ||
| pytest-django | ||
| drf-spectacular | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pytest.iniwas added underserver/, but this workflow runspytest server/ -vfrom the repo root. Pytest won’t automatically discover config files in subdirectories, soDJANGO_SETTINGS_MODULE/pythonpathmay not be applied and Django tests can fail to initialize. Fix by eithercd server && pytest -vor runningpytest -c server/pytest.ini server/ -v(or settingDJANGO_SETTINGS_MODULEin the workflow env).