Source code for autorag.vectordb.chroma

from typing import List, Optional, Dict, Tuple

from chromadb import (
	EphemeralClient,
	PersistentClient,
	DEFAULT_TENANT,
	DEFAULT_DATABASE,
	CloudClient,
	AsyncHttpClient,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import IncludeEnum, QueryResult

from autorag.utils.util import apply_recursive
from autorag.vectordb.base import BaseVectorStore


[docs] class Chroma(BaseVectorStore): def __init__( self, embedding_model: str, collection_name: str, embedding_batch: int = 100, client_type: str = "persistent", similarity_metric: str = "cosine", path: str = None, host: str = "localhost", port: int = 8000, ssl: bool = False, headers: Optional[Dict[str, str]] = None, api_key: Optional[str] = None, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE, ): super().__init__(embedding_model, similarity_metric, embedding_batch) if client_type == "ephemeral": self.client = EphemeralClient(tenant=tenant, database=database) elif client_type == "persistent": assert path is not None, "path must be provided for persistent client" self.client = PersistentClient(path=path, tenant=tenant, database=database) elif client_type == "http": self.client = AsyncHttpClient( host=host, port=port, ssl=ssl, headers=headers, tenant=tenant, database=database, ) elif client_type == "cloud": self.client = CloudClient( tenant=tenant, database=database, api_key=api_key, ) else: raise ValueError( f"client_type {client_type} is not supported\n" "supported client types are: ephemeral, persistent, http, cloud" ) self.collection = self.client.get_or_create_collection( name=collection_name, metadata={"hnsw:space": similarity_metric}, )
[docs] async def add(self, ids: List[str], texts: List[str]): texts = self.truncated_inputs(texts) text_embeddings = await self.embedding.aget_text_embedding_batch(texts) if isinstance(self.collection, AsyncCollection): await self.collection.add(ids=ids, embeddings=text_embeddings) else: self.collection.add(ids=ids, embeddings=text_embeddings)
[docs] async def fetch(self, ids: List[str]) -> List[List[float]]: if isinstance(self.collection, AsyncCollection): fetch_result = await self.collection.get( ids, include=[IncludeEnum.embeddings] ) else: fetch_result = self.collection.get(ids, include=[IncludeEnum.embeddings]) fetch_embeddings = fetch_result["embeddings"] return fetch_embeddings
[docs] async def is_exist(self, ids: List[str]) -> List[bool]: if isinstance(self.collection, AsyncCollection): fetched_result = await self.collection.get(ids, include=[]) else: fetched_result = self.collection.get(ids, include=[]) existed_ids = fetched_result["ids"] return list(map(lambda x: x in existed_ids, ids))
[docs] async def query( self, queries: List[str], top_k: int, **kwargs ) -> Tuple[List[List[str]], List[List[float]]]: queries = self.truncated_inputs(queries) query_embeddings: List[ List[float] ] = await self.embedding.aget_text_embedding_batch(queries) if isinstance(self.collection, AsyncCollection): query_result: QueryResult = await self.collection.query( query_embeddings=query_embeddings, n_results=top_k ) else: query_result: QueryResult = self.collection.query( query_embeddings=query_embeddings, n_results=top_k ) ids = query_result["ids"] scores = query_result["distances"] scores = apply_recursive(lambda x: 1 - x, scores) return ids, scores
[docs] async def delete(self, ids: List[str]): if isinstance(self.collection, AsyncCollection): await self.collection.delete(ids) else: self.collection.delete(ids)