Skip to content

searchers

pgvector-backed SNOMED searchers for attribute and reference retrieval.

AbstractSnomedSearcher

Bases: ABC

ABC for pgvector-backed SNOMED searchers.

Enforces lifecycle management (close / context-manager) and cost tracking — matching the patterns used by PgvectorConceptSearcher and LlmMapper.

Subclasses must implement :meth:search.

Source code in src/ariadne/hierarchy/searchers.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class AbstractSnomedSearcher(ABC):
    """ABC for pgvector-backed SNOMED searchers.

    Enforces lifecycle management (``close`` / context-manager) and cost
    tracking — matching the patterns used by ``PgvectorConceptSearcher`` and
    ``LlmMapper``.

    Subclasses must implement :meth:`search`.
    """

    def __init__(self, cfg: "HierarchySettings | None" = None):
        """Open an autocommit pgvector connection and configure HNSW search.

        Args:
            cfg: Hierarchy settings; loaded via ``load_hierarchy_settings()``
                when ``None``.
        """
        if cfg is None:
            cfg = load_hierarchy_settings()
        self.cfg = cfg
        conn_str = self._normalize_conn_str(get_environment_variable("VOCAB_CONNECTION_STRING"))
        self.connection = psycopg.connect(conn_str, autocommit=True)
        try:
            register_vector(self.connection)
            with self.connection.cursor() as cur:
                # SET cannot take bind parameters; coerce to int so a
                # misconfigured setting cannot inject arbitrary SQL.
                cur.execute(f"SET hnsw.ef_search = {int(self.cfg.retrieval.hnsw_ef_search)}")
            self.schema = get_environment_variable("VOCAB_SCHEMA")
        except Exception:
            # Don't leak the connection if setup fails part-way through.
            self.connection.close()
            raise
        self._cost = 0.0  # accumulated embedding spend in USD

    @staticmethod
    def _normalize_conn_str(conn_str: str) -> str:
        """Strip SQLAlchemy driver suffixes so plain psycopg accepts the URL.

        ``+psycopg2`` must be removed *before* ``+psycopg``: the previous
        order stripped the ``+psycopg`` prefix out of ``+psycopg2`` first,
        leaving a stray ``2`` in the scheme
        (``postgresql+psycopg2://…`` → ``postgresql2://…``).
        """
        return conn_str.replace("+psycopg2", "").replace("+psycopg", "")

    # -- abstract contract --------------------------------------------------

    @abstractmethod
    def search(self, text: str, *args, **kwargs):
        """Run a similarity search.  Signature varies by subclass."""

    # -- lifecycle ----------------------------------------------------------

    def get_total_cost(self) -> float:
        """Return accumulated embedding cost (USD)."""
        return self._cost

    def close(self):
        """Close the database connection."""
        self.connection.close()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()

close()

Close the database connection.

Source code in src/ariadne/hierarchy/searchers.py
116
117
118
def close(self):
    """Close the underlying psycopg database connection."""
    self.connection.close()

get_total_cost()

Return accumulated embedding cost (USD).

Source code in src/ariadne/hierarchy/searchers.py
112
113
114
def get_total_cost(self) -> float:
    """Return the embedding cost (USD) accumulated so far by this searcher."""
    return self._cost

search(text, *args, **kwargs) abstractmethod

Run a similarity search. Signature varies by subclass.

Source code in src/ariadne/hierarchy/searchers.py
106
107
108
# Subclasses extend the signature beyond ``text`` via *args/**kwargs.
@abstractmethod
def search(self, text: str, *args, **kwargs):
    """Run a similarity search.  Signature varies by subclass."""

SnomedAttributeSearcher

Bases: AbstractSnomedSearcher

Source code in src/ariadne/hierarchy/searchers.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
class SnomedAttributeSearcher(AbstractSnomedSearcher):
    """Embedding-similarity search over the ``snomed_attribute`` table.

    Provides single (:meth:`search`) and batched (:meth:`search_batch`)
    category-filtered ANN lookups, plus a 1-hop hierarchy expansion
    (:meth:`expand_via_hierarchy`) that needs no embedding API call.
    """

    # Column order must match the SELECT list of the ANN query below.
    _RESULT_COLUMNS = [
        "concept_id",
        "concept_code",
        "concept_name",
        "attribute_category",
        "similarity",
    ]

    def _attribute_query(self):
        """Return the schema-qualified, category-filtered ANN query."""
        return sql.SQL("""
            WITH q AS (SELECT %s::vector AS vec)
            SELECT concept_id, concept_code, concept_name, attribute_category,
                   1 - (embedding <=> q.vec) AS similarity
            FROM {schema}.{table}, q
            WHERE attribute_category = %s
            ORDER BY embedding <=> q.vec
            LIMIT %s
        """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_attribute"))

    def _fetch_category(self, query, vector, category_name: str, top_k: int) -> pd.DataFrame:
        """Execute *query* for one vector/category pair.

        Always returns a DataFrame with ``_RESULT_COLUMNS`` (empty when there
        are no matches) so downstream code sees a stable schema.
        """
        with self.connection.cursor() as cur:
            cur.execute(query, (vector.tolist(), category_name, top_k))
            rows = cur.fetchall()
        return pd.DataFrame(rows, columns=self._RESULT_COLUMNS)

    def search(self, text: str, category_name: str, top_k: int | None = None) -> SearchResult:
        """Embed *text* and return the closest concepts in the given attribute category.

        Args:
            text: Free-text value to embed and search for.
            category_name: SNOMED attribute category filter.
            top_k: Maximum number of results (defaults to ``cfg.retrieval.top_k_per_category``).

        Returns:
            SearchResult(dataframe, cost).
        """
        top_k = top_k if top_k is not None else self.cfg.retrieval.top_k_per_category
        result = get_embedding_vectors([text])
        vector = result["embeddings"][0]
        cost = result["usage"]["total_cost_usd"]
        self._cost += cost
        frame = self._fetch_category(self._attribute_query(), vector, category_name, top_k)
        return SearchResult(frame, cost)

    def search_batch(
        self,
        texts_and_categories: list[tuple[str, str, str]],
        top_k: int | None = None,
    ) -> SearchBatchResult:
        """Embed all *texts* in a single API call and run per-category queries.

        Args:
            texts_and_categories: List of ``(attr_key, text, category_name)`` tuples.
            top_k: Number of candidates per category (defaults to ``cfg.retrieval.top_k_per_category``).

        Returns:
            SearchBatchResult(results, total_cost).
        """
        top_k = top_k if top_k is not None else self.cfg.retrieval.top_k_per_category
        if not texts_and_categories:
            return SearchBatchResult({}, 0.0)

        texts = [t[1] for t in texts_and_categories]
        result = get_embedding_vectors(texts)
        vectors = result["embeddings"]
        cost = result["usage"]["total_cost_usd"]
        self._cost += cost

        # The query text is loop-invariant — build it once, not per item.
        query = self._attribute_query()
        results: dict[str, pd.DataFrame] = {}
        for (attr_key, _text, category_name), vector in zip(texts_and_categories, vectors):
            results[attr_key] = self._fetch_category(query, vector, category_name, top_k)
        return SearchBatchResult(results, cost)

    def expand_via_hierarchy(
        self,
        concept_ids: list[int],
        attribute_category: str,
    ) -> pd.DataFrame:
        """Expand candidate concept IDs by 1-hop SNOMED hierarchy.

        For each *concept_id*, retrieves its immediate parents and children
        via ``Is a`` rows in ``concept_relationship`` (parents: the seed is
        ``concept_id_1``; children: the seed is ``concept_id_2``), then
        filters to those that exist in ``snomed_attribute`` under the same
        *attribute_category*.  No embedding API call is needed.

        Args:
            concept_ids: Seed concept IDs to expand.
            attribute_category: SNOMED attribute category filter.

        Returns:
            DataFrame with columns
            ``[concept_id, concept_code, concept_name, attribute_category, similarity]``
            where *similarity* is set to ``cfg.scoring.hierarchy_similarity``.
        """
        if not concept_ids:
            # Same column set as every other return path (the previous
            # empty-input frame was missing ``concept_code``).
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        query = sql.SQL("""
            WITH seeds AS (
                SELECT unnest(%s::int[]) AS cid
            ),
            neighbors AS (
                -- parents (seed "Is a" parent)
                SELECT DISTINCT cr.concept_id_2 AS cid
                FROM {schema}.{concept_rel} cr
                JOIN seeds s ON cr.concept_id_1 = s.cid
                WHERE cr.relationship_id = 'Is a'
                  AND cr.invalid_reason IS NULL
                UNION
                -- children (child "Is a" seed)
                SELECT DISTINCT cr.concept_id_1 AS cid
                FROM {schema}.{concept_rel} cr
                JOIN seeds s ON cr.concept_id_2 = s.cid
                WHERE cr.relationship_id = 'Is a'
                  AND cr.invalid_reason IS NULL
            )
            SELECT DISTINCT sa.concept_id, sa.concept_code, sa.concept_name, sa.attribute_category
            FROM {schema}.{snomed_attr} sa
            JOIN neighbors n ON sa.concept_id = n.cid
            WHERE sa.attribute_category = %s
              AND sa.concept_id NOT IN (SELECT cid FROM seeds)
        """).format(
            schema=sql.Identifier(self.schema),
            concept_rel=sql.Identifier("concept_relationship"),
            snomed_attr=sql.Identifier("snomed_attribute"),
        )
        params = [concept_ids, attribute_category]
        with self.connection.cursor() as cur:
            cur.execute(query, params)
            rows = cur.fetchall()

        if not rows:
            return pd.DataFrame(columns=self._RESULT_COLUMNS)

        # Query columns are _RESULT_COLUMNS minus the synthesized similarity.
        df = pd.DataFrame(rows, columns=self._RESULT_COLUMNS[:-1])
        df = df.drop_duplicates(subset=["concept_id"])
        df["similarity"] = self.cfg.scoring.hierarchy_similarity
        return df

expand_via_hierarchy(concept_ids, attribute_category)

Expand candidate concept IDs by 1-hop SNOMED hierarchy.

For each concept_id, retrieves its immediate parents and children via Is a rows in concept_relationship (parents: the seed is concept_id_1; children: the seed is concept_id_2), then filters to those that exist in snomed_attribute under the same attribute_category. No embedding API call is needed.

Parameters:

Name Type Description Default
concept_ids list[int]

Seed concept IDs to expand.

required
attribute_category str

SNOMED attribute category filter.

required

Returns:

Type Description
DataFrame

DataFrame with columns

DataFrame

[concept_id, concept_code, concept_name, attribute_category, similarity]

DataFrame

where similarity is set to cfg.scoring.hierarchy_similarity.

Source code in src/ariadne/hierarchy/searchers.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# NOTE(review): rendered copy of SnomedAttributeSearcher.expand_via_hierarchy —
# keep in sync with the class definition above.
def expand_via_hierarchy(
    self,
    concept_ids: list[int],
    attribute_category: str,
) -> pd.DataFrame:
    """Expand candidate concept IDs by 1-hop SNOMED hierarchy.

    For each *concept_id*, retrieves its immediate parents and children via
    ``Is a`` rows in ``concept_relationship`` (parents: seed is
    ``concept_id_1``; children: seed is ``concept_id_2``), then filters to
    those that exist in ``snomed_attribute`` under the same
    *attribute_category*.  No embedding API call is needed.

    Args:
        concept_ids: Seed concept IDs to expand.
        attribute_category: SNOMED attribute category filter.

    Returns:
        DataFrame with columns
        ``[concept_id, concept_code, concept_name, attribute_category, similarity]``
        where *similarity* is set to ``cfg.scoring.hierarchy_similarity``.
    """
    if not concept_ids:
        # NOTE(review): this empty-input frame omits ``concept_code``, unlike
        # the populated path and the no-rows path below — confirm intended.
        return pd.DataFrame(
            columns=["concept_id", "concept_name", "attribute_category", "similarity"]
        )

    query = sql.SQL("""
        WITH seeds AS (
            SELECT unnest(%s::int[]) AS cid
        ),
        neighbors AS (
            -- parents (seed "Is a" parent)
            SELECT DISTINCT cr.concept_id_2 AS cid
            FROM {schema}.{concept_rel} cr
            JOIN seeds s ON cr.concept_id_1 = s.cid
            WHERE cr.relationship_id = 'Is a'
              AND cr.invalid_reason IS NULL
            UNION
            -- children (child "Is a" seed)
            SELECT DISTINCT cr.concept_id_1 AS cid
            FROM {schema}.{concept_rel} cr
            JOIN seeds s ON cr.concept_id_2 = s.cid
            WHERE cr.relationship_id = 'Is a'
              AND cr.invalid_reason IS NULL
        )
        SELECT DISTINCT sa.concept_id, sa.concept_code, sa.concept_name, sa.attribute_category
        FROM {schema}.{snomed_attr} sa
        JOIN neighbors n ON sa.concept_id = n.cid
        WHERE sa.attribute_category = %s
          AND sa.concept_id NOT IN (SELECT cid FROM seeds)
    """).format(
        schema=sql.Identifier(self.schema),
        concept_rel=sql.Identifier("concept_relationship"),
        snomed_attr=sql.Identifier("snomed_attribute"),
    )
    params = [concept_ids, attribute_category]
    with self.connection.cursor() as cur:
        cur.execute(query, params)
        rows = cur.fetchall()

    if not rows:
        return pd.DataFrame(
            columns=["concept_id", "concept_code", "concept_name", "attribute_category", "similarity"]
        )

    # Query returns four columns; similarity is synthesized from config below.
    df = pd.DataFrame(rows, columns=["concept_id", "concept_code", "concept_name", "attribute_category"])
    df = df.drop_duplicates(subset=["concept_id"])
    df["similarity"] = self.cfg.scoring.hierarchy_similarity
    return df

search(text, category_name, top_k=None)

Embed text and return the closest concepts in the given attribute category.

Parameters:

Name Type Description Default
text str

Free-text value to embed and search for.

required
category_name str

SNOMED attribute category filter.

required
top_k int | None

Maximum number of results (defaults to cfg.retrieval.top_k_per_category).

None

Returns:

Type Description
SearchResult

SearchResult(dataframe, cost).

Source code in src/ariadne/hierarchy/searchers.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# NOTE(review): rendered copy of SnomedAttributeSearcher.search — keep in
# sync with the class definition above.
def search(self, text: str, category_name: str, top_k: int | None = None) -> SearchResult:
    """Embed *text* and return the closest concepts in the given attribute category.

    Args:
        text: Free-text value to embed and search for.
        category_name: SNOMED attribute category filter.
        top_k: Maximum number of results (defaults to ``cfg.retrieval.top_k_per_category``).

    Returns:
        SearchResult(dataframe, cost).
    """
    top_k = top_k if top_k is not None else self.cfg.retrieval.top_k_per_category
    result = get_embedding_vectors([text])
    vector = result["embeddings"][0]
    cost = result["usage"]["total_cost_usd"]
    self._cost += cost

    # ``<=>`` is pgvector's distance operator; similarity = 1 - distance.
    query = sql.SQL("""
        WITH q AS (SELECT %s::vector AS vec)
        SELECT concept_id, concept_code, concept_name, attribute_category,
               1 - (embedding <=> q.vec) AS similarity
        FROM {schema}.{table}, q
        WHERE attribute_category = %s
        ORDER BY embedding <=> q.vec
        LIMIT %s
    """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_attribute"))
    with self.connection.cursor() as cur:
        cur.execute(query, (vector.tolist(), category_name, top_k))
        rows = cur.fetchall()

    if not rows:
        return SearchResult(pd.DataFrame(columns=["concept_id", "concept_code", "concept_name", "attribute_category", "similarity"]), cost)

    return SearchResult(pd.DataFrame(rows, columns=["concept_id", "concept_code", "concept_name", "attribute_category", "similarity"]), cost)

search_batch(texts_and_categories, top_k=None)

Embed all texts in a single API call and run per-category queries.

Parameters:

Name Type Description Default
texts_and_categories list[tuple[str, str, str]]

List of (attr_key, text, category_name) tuples.

required
top_k int | None

Number of candidates per category (defaults to cfg.retrieval.top_k_per_category).

None

Returns:

Type Description
SearchBatchResult

SearchBatchResult(results, total_cost).

Source code in src/ariadne/hierarchy/searchers.py
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# NOTE(review): rendered copy of SnomedAttributeSearcher.search_batch — keep
# in sync with the class definition above.
def search_batch(
    self,
    texts_and_categories: list[tuple[str, str, str]],
    top_k: int | None = None,
) -> SearchBatchResult:
    """Embed all *texts* in a single API call and run per-category queries.

    Args:
        texts_and_categories: List of ``(attr_key, text, category_name)`` tuples.
        top_k: Number of candidates per category (defaults to ``cfg.retrieval.top_k_per_category``).

    Returns:
        SearchBatchResult(results, total_cost).
    """
    top_k = top_k if top_k is not None else self.cfg.retrieval.top_k_per_category
    if not texts_and_categories:
        return SearchBatchResult({}, 0.0)

    texts = [t[1] for t in texts_and_categories]
    result = get_embedding_vectors(texts)
    vectors = result["embeddings"]
    cost = result["usage"]["total_cost_usd"]
    self._cost += cost

    results: dict[str, pd.DataFrame] = {}
    for (attr_key, _text, category_name), vector in zip(texts_and_categories, vectors):
        # NOTE(review): the query text is loop-invariant and could be built
        # once before the loop.
        query = sql.SQL("""
            WITH q AS (SELECT %s::vector AS vec)
            SELECT concept_id, concept_code, concept_name, attribute_category,
                   1 - (embedding <=> q.vec) AS similarity
            FROM {schema}.{table}, q
            WHERE attribute_category = %s
            ORDER BY embedding <=> q.vec
            LIMIT %s
        """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_attribute"))
        with self.connection.cursor() as cur:
            cur.execute(query, (vector.tolist(), category_name, top_k))
            rows = cur.fetchall()

        if rows:
            results[attr_key] = pd.DataFrame(
                rows, columns=["concept_id", "concept_code", "concept_name", "attribute_category", "similarity"]
            )
        else:
            results[attr_key] = pd.DataFrame(
                columns=["concept_id", "concept_code", "concept_name", "attribute_category", "similarity"]
            )
    return SearchBatchResult(results, cost)

SnomedReferenceSearcher

Bases: AbstractSnomedSearcher

Source code in src/ariadne/hierarchy/searchers.py
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
class SnomedReferenceSearcher(AbstractSnomedSearcher):
    """Similarity search over the ``snomed_reference`` table.

    Returns reference terms (deduplicated by ``concept_id_1``) together with
    all of their attribute rows, optionally excluding a fixed set of concept
    IDs from every search.
    """

    def __init__(self, cfg: HierarchySettings | None = None,
                 exclude_concept_ids: set[int] | None = None):
        """Connect and record an optional set of concept IDs to exclude.

        Args:
            cfg: Hierarchy settings; loaded from the environment when ``None``.
            exclude_concept_ids: Concept IDs never to return from :meth:`search`.
        """
        super().__init__(cfg)
        # Sorted list (not set) so it can be bound as a Postgres array param.
        self._exclude_concept_ids: list[int] = sorted(exclude_concept_ids) if exclude_concept_ids else []
        if self._exclude_concept_ids:
            logger.info(
                "SnomedReferenceSearcher: excluding %d concept IDs from results.",
                len(self._exclude_concept_ids),
            )

    def search(
        self,
        text: str,
        top_k: int | None = None,
        embedding: "np.ndarray | None" = None,
    ) -> ReferenceSearchResult:
        """Embed *text* and return similar reference terms with their attributes.

        Args:
            text: Free-text term to embed (ignored when *embedding* is supplied).
            top_k: Maximum number of reference concepts to return
                (defaults to ``cfg.retrieval.num_reference_examples``).
            embedding: Optional precomputed embedding vector (shape ``[dim]``).
                When provided the API call is skipped and cost is recorded as
                zero.  Pass the vector produced by
                ``PgvectorConceptSearcher.search_terms(..., return_embeddings=True)``
                to avoid re-embedding the same term in Step 2.

        Returns:
            ReferenceSearchResult(examples, cost).
        """
        top_k = top_k if top_k is not None else self.cfg.retrieval.num_reference_examples
        if embedding is not None:
            import numpy as np  # local import — only needed here
            vector = np.asarray(embedding, dtype=float)
            cost = 0.0
        else:
            result = get_embedding_vectors([text])
            vector = result["embeddings"][0]
            cost = result["usage"]["total_cost_usd"]
            self._cost += cost

        top_terms = self._fetch_top_terms(vector, top_k)
        if not top_terms:
            return ReferenceSearchResult([], cost)

        attrs_by_id = self._fetch_attributes([r[0] for r in top_terms])
        return ReferenceSearchResult([
            {
                "concept_id": concept_id_1,
                "concept_name": concept_name_1,
                "similarity": similarity,
                "attributes": attrs_by_id.get(concept_id_1, []),
            }
            for concept_id_1, concept_name_1, similarity in top_terms
        ], cost)

    def _fetch_top_terms(self, vector, top_k: int) -> list:
        """Run the ranked ANN query, deduplicated by ``concept_id_1``.

        The exclusion predicate is spliced in only when exclusions exist —
        previously this was two near-identical ~25-line query branches.
        """
        # Over-fetch 10x inside the ANN scan so that, after collapsing
        # duplicate concept_id_1 rows (rn = 1), top_k distinct terms remain.
        inner_limit = top_k * 10
        exclude = sql.SQL("WHERE concept_id_1 != ALL(%s)") if self._exclude_concept_ids else sql.SQL("")
        query = sql.SQL("""
            WITH q AS (SELECT %s::vector AS vec),
                 ranked AS (
                SELECT concept_id_1, concept_name_1,
                       1 - (embedding <=> q.vec) AS similarity,
                       ROW_NUMBER() OVER (
                           PARTITION BY concept_id_1
                           ORDER BY embedding <=> q.vec
                       ) AS rn
                FROM {schema}.{table}, q
                {exclude}
                ORDER BY embedding <=> q.vec
                LIMIT %s
            )
            SELECT concept_id_1, concept_name_1, similarity
            FROM ranked
            WHERE rn = 1
            ORDER BY similarity DESC
            LIMIT %s
        """).format(
            schema=sql.Identifier(self.schema),
            table=sql.Identifier("snomed_reference"),
            exclude=exclude,
        )
        params: list = [vector.tolist()]
        if self._exclude_concept_ids:
            params.append(self._exclude_concept_ids)
        params.extend([inner_limit, top_k])
        with self.connection.cursor() as cur:
            cur.execute(query, params)
            return cur.fetchall()

    def _fetch_attributes(self, ids: list) -> dict[int, list]:
        """Fetch all attribute rows for *ids*, grouped by ``concept_id_1``."""
        attr_query = sql.SQL("""
            SELECT concept_id_1, concept_id_2, concept_code_2, concept_name_2, attribute_category
            FROM {schema}.{table}
            WHERE concept_id_1 = ANY(%s)
        """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_reference"))
        with self.connection.cursor() as cur:
            cur.execute(attr_query, (ids,))
            attr_rows = cur.fetchall()

        attrs_by_id: dict[int, list] = {}
        for concept_id_1, concept_id_2, concept_code_2, concept_name_2, attribute_category in attr_rows:
            attrs_by_id.setdefault(concept_id_1, []).append({
                "concept_id_2": concept_id_2,
                "concept_code_2": concept_code_2,
                "concept_name_2": concept_name_2,
                "attribute_category": attribute_category,
            })
        return attrs_by_id

search(text, top_k=None, embedding=None)

Embed text and return similar reference terms with their attributes.

Parameters:

Name Type Description Default
text str

Free-text term to embed (ignored when embedding is supplied).

required
top_k int | None

Maximum number of reference concepts to return (defaults to cfg.retrieval.num_reference_examples).

None
embedding ndarray | None

Optional precomputed embedding vector (shape [dim]). When provided the API call is skipped and cost is recorded as zero. Pass the vector produced by PgvectorConceptSearcher.search_terms(..., return_embeddings=True) to avoid re-embedding the same term in Step 2.

None

Returns:

Type Description
ReferenceSearchResult

ReferenceSearchResult(examples, cost).

Source code in src/ariadne/hierarchy/searchers.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
# NOTE(review): rendered copy of SnomedReferenceSearcher.search — keep in
# sync with the class definition above.
def search(
    self,
    text: str,
    top_k: int | None = None,
    embedding: "np.ndarray | None" = None,
) -> ReferenceSearchResult:
    """Embed *text* and return similar reference terms with their attributes.

    Args:
        text: Free-text term to embed (ignored when *embedding* is supplied).
        top_k: Maximum number of reference concepts to return
            (defaults to ``cfg.retrieval.num_reference_examples``).
        embedding: Optional precomputed embedding vector (shape ``[dim]``).
            When provided the API call is skipped and cost is recorded as
            zero.  Pass the vector produced by
            ``PgvectorConceptSearcher.search_terms(..., return_embeddings=True)``
            to avoid re-embedding the same term in Step 2.

    Returns:
        ReferenceSearchResult(examples, cost).
    """
    top_k = top_k if top_k is not None else self.cfg.retrieval.num_reference_examples
    if embedding is not None:
        import numpy as np  # local import — only needed here
        vector = np.asarray(embedding, dtype=float)
        cost = 0.0
    else:
        result = get_embedding_vectors([text])
        vector = result["embeddings"][0]
        cost = result["usage"]["total_cost_usd"]
        self._cost += cost

    # Over-fetch before collapsing duplicate concept_id_1 rows (rn = 1 below).
    inner_limit = top_k * 10
    # NOTE(review): the two branches differ only by the exclusion WHERE
    # clause; they could be unified with composed SQL.
    if self._exclude_concept_ids:
        query = sql.SQL("""
            WITH q AS (SELECT %s::vector AS vec),
                 ranked AS (
                SELECT concept_id_1, concept_name_1,
                       1 - (embedding <=> q.vec) AS similarity,
                       ROW_NUMBER() OVER (
                           PARTITION BY concept_id_1
                           ORDER BY embedding <=> q.vec
                       ) AS rn
                FROM {schema}.{table}, q
                WHERE concept_id_1 != ALL(%s)
                ORDER BY embedding <=> q.vec
                LIMIT %s
            )
            SELECT concept_id_1, concept_name_1, similarity
            FROM ranked
            WHERE rn = 1
            ORDER BY similarity DESC
            LIMIT %s
        """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_reference"))
        with self.connection.cursor() as cur:
            cur.execute(query, (vector.tolist(), self._exclude_concept_ids, inner_limit, top_k))
            top_terms = cur.fetchall()
    else:
        query = sql.SQL("""
            WITH q AS (SELECT %s::vector AS vec),
                 ranked AS (
                SELECT concept_id_1, concept_name_1,
                       1 - (embedding <=> q.vec) AS similarity,
                       ROW_NUMBER() OVER (
                           PARTITION BY concept_id_1
                           ORDER BY embedding <=> q.vec
                       ) AS rn
                FROM {schema}.{table}, q
                ORDER BY embedding <=> q.vec
                LIMIT %s
            )
            SELECT concept_id_1, concept_name_1, similarity
            FROM ranked
            WHERE rn = 1
            ORDER BY similarity DESC
            LIMIT %s
        """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_reference"))
        with self.connection.cursor() as cur:
            cur.execute(query, (vector.tolist(), inner_limit, top_k))
            top_terms = cur.fetchall()

    if not top_terms:
        return ReferenceSearchResult([], cost)

    # Fetch all attribute rows for those concept_id_1 values
    ids = [r[0] for r in top_terms]
    attr_query = sql.SQL("""
        SELECT concept_id_1, concept_id_2, concept_code_2, concept_name_2, attribute_category
        FROM {schema}.{table}
        WHERE concept_id_1 = ANY(%s)
    """).format(schema=sql.Identifier(self.schema), table=sql.Identifier("snomed_reference"))
    with self.connection.cursor() as cur:
        cur.execute(attr_query, (ids,))
        attr_rows = cur.fetchall()

    # Group attributes by concept_id_1
    attrs_by_id: dict[int, list] = {}
    for concept_id_1, concept_id_2, concept_code_2, concept_name_2, attribute_category in attr_rows:
        attrs_by_id.setdefault(concept_id_1, []).append({
            "concept_id_2": concept_id_2,
            "concept_code_2": concept_code_2,
            "concept_name_2": concept_name_2,
            "attribute_category": attribute_category,
        })

    return ReferenceSearchResult([
        {
            "concept_id": concept_id_1,
            "concept_name": concept_name_1,
            "similarity": similarity,
            "attributes": attrs_by_id.get(concept_id_1, []),
        }
        for concept_id_1, concept_name_1, similarity in top_terms
    ], cost)