import inspect
from typing import List, Optional, Dict, Tuple

from chromadb import (
    EphemeralClient,
    PersistentClient,
    DEFAULT_TENANT,
    DEFAULT_DATABASE,
    CloudClient,
    AsyncHttpClient,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.types import IncludeEnum, QueryResult

from autorag.utils.util import apply_recursive
from autorag.vectordb.base import BaseVectorStore
class Chroma(BaseVectorStore):
    """Vector store backed by a single ChromaDB collection.

    Texts are embedded with the embedding model set up by ``BaseVectorStore``
    and only ids and embeddings are written to the collection. Every data
    method is a coroutine and works with both synchronous Chroma collections
    and ``AsyncCollection`` (used by the ``"http"`` client type).
    """

    def __init__(
        self,
        embedding_model: str,
        collection_name: str,
        embedding_batch: int = 100,
        client_type: str = "persistent",
        similarity_metric: str = "cosine",
        path: Optional[str] = None,
        host: str = "localhost",
        port: int = 8000,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ):
        """Create the Chroma client and open (or create) the collection.

        :param embedding_model: Embedding model identifier, forwarded to
            ``BaseVectorStore``.
        :param collection_name: Name of the collection to get or create.
        :param embedding_batch: Embedding batch size, forwarded to
            ``BaseVectorStore``.
        :param client_type: One of ``"ephemeral"``, ``"persistent"``,
            ``"http"`` or ``"cloud"``.
        :param similarity_metric: Stored as the collection's ``hnsw:space``
            metadata and forwarded to ``BaseVectorStore``.
        :param path: Storage directory; required when ``client_type`` is
            ``"persistent"``.
        :param host: Server host for the ``"http"`` client.
        :param port: Server port for the ``"http"`` client.
        :param ssl: Whether the ``"http"`` client uses SSL.
        :param headers: Extra HTTP headers for the ``"http"`` client.
        :param api_key: API key for the ``"cloud"`` client.
        :param tenant: Chroma tenant.
        :param database: Chroma database.
        :raises AssertionError: If ``client_type`` is ``"persistent"`` and
            ``path`` is ``None``.
        :raises ValueError: If ``client_type`` is not supported.
        """
        super().__init__(embedding_model, similarity_metric, embedding_batch)

        # Kept on the instance so the collection can also be opened lazily
        # (see _get_collection) with exactly the same arguments.
        self._collection_name = collection_name
        self._collection_metadata = {"hnsw:space": similarity_metric}
        self._async_client_kwargs = None

        if client_type == "ephemeral":
            self.client = EphemeralClient(tenant=tenant, database=database)
        elif client_type == "persistent":
            # Left as an assert so the exception type seen by existing callers
            # does not change.
            assert path is not None, "path must be provided for persistent client"
            self.client = PersistentClient(path=path, tenant=tenant, database=database)
        elif client_type == "http":
            # AsyncHttpClient is a coroutine function and cannot be awaited
            # from this synchronous constructor. Calling it here would leave a
            # bare coroutine in self.client, so the connection is deferred to
            # the first async method call instead.
            self.client = None
            self._async_client_kwargs = {
                "host": host,
                "port": port,
                "ssl": ssl,
                "headers": headers,
                "tenant": tenant,
                "database": database,
            }
        elif client_type == "cloud":
            self.client = CloudClient(
                tenant=tenant,
                database=database,
                api_key=api_key,
            )
        else:
            raise ValueError(
                f"client_type {client_type} is not supported\n"
                "supported client types are: ephemeral, persistent, http, cloud"
            )

        if self.client is None:
            # Async HTTP client: the collection is opened on first use.
            self.collection = None
        else:
            self.collection = self.client.get_or_create_collection(
                name=self._collection_name,
                metadata=self._collection_metadata,
            )

    async def _get_collection(self):
        """Return the collection, connecting the async HTTP client on first use.

        For the synchronous client types the collection already exists after
        ``__init__`` and is returned as is.
        """
        if self.collection is None:
            if self.client is None:
                client = AsyncHttpClient(**self._async_client_kwargs)
                if inspect.isawaitable(client):
                    client = await client
                self.client = client
            collection = self.client.get_or_create_collection(
                name=self._collection_name,
                metadata=self._collection_metadata,
            )
            # The async client returns a coroutine here, the sync ones do not.
            if inspect.isawaitable(collection):
                collection = await collection
            self.collection = collection
        return self.collection

    async def add(self, ids: List[str], texts: List[str]):
        """Embed ``texts`` and store the embeddings under ``ids``.

        Only ids and embeddings are written; the texts themselves are not
        stored in the collection.

        :param ids: Identifiers, one per text.
        :param texts: Texts to embed. They are passed through
            ``self.truncated_inputs`` (inherited, not shown here) first.
        """
        if not ids and not texts:
            # Nothing to do: skip both the embedding and the database call.
            return

        texts = self.truncated_inputs(texts)
        text_embeddings = await self.embedding.aget_text_embedding_batch(texts)

        collection = await self._get_collection()
        if isinstance(collection, AsyncCollection):
            await collection.add(ids=ids, embeddings=text_embeddings)
        else:
            collection.add(ids=ids, embeddings=text_embeddings)

    async def fetch(self, ids: List[str]) -> List[List[float]]:
        """Return the stored embeddings for ``ids``.

        :param ids: Identifiers to look up.
        :return: The ``"embeddings"`` field of the collection's ``get``
            result; an empty list when ``ids`` is empty.
        """
        if not ids:
            return []

        collection = await self._get_collection()
        if isinstance(collection, AsyncCollection):
            fetch_result = await collection.get(ids, include=[IncludeEnum.embeddings])
        else:
            fetch_result = collection.get(ids, include=[IncludeEnum.embeddings])
        return fetch_result["embeddings"]

    async def is_exist(self, ids: List[str]) -> List[bool]:
        """Report, for each id, whether it is present in the collection.

        :param ids: Identifiers to check.
        :return: A list aligned with ``ids``; ``True`` where the id exists.
        """
        if not ids:
            return []

        collection = await self._get_collection()
        # include=[] asks for ids only; no embeddings or other payload needed.
        if isinstance(collection, AsyncCollection):
            fetched_result = await collection.get(ids, include=[])
        else:
            fetched_result = collection.get(ids, include=[])

        # Set lookup keeps this linear in len(ids) instead of quadratic.
        existed_ids = set(fetched_result["ids"])
        return [doc_id in existed_ids for doc_id in ids]

    async def query(
        self, queries: List[str], top_k: int, **kwargs
    ) -> Tuple[List[List[str]], List[List[float]]]:
        """Find the ``top_k`` nearest stored ids for each query text.

        :param queries: Query texts. They are passed through
            ``self.truncated_inputs`` (inherited, not shown here) first.
        :param top_k: Number of results requested per query.
        :param kwargs: Accepted for interface compatibility; unused here.
        :return: ``(ids, scores)``, each with one inner list per query.
            Scores are ``1 - distance`` of the distances returned by the
            collection. Both lists are empty when ``queries`` is empty.
        """
        if not queries:
            return [], []

        queries = self.truncated_inputs(queries)
        query_embeddings: List[
            List[float]
        ] = await self.embedding.aget_text_embedding_batch(queries)

        collection = await self._get_collection()
        query_result: QueryResult
        if isinstance(collection, AsyncCollection):
            query_result = await collection.query(
                query_embeddings=query_embeddings, n_results=top_k
            )
        else:
            query_result = collection.query(
                query_embeddings=query_embeddings, n_results=top_k
            )

        ids = query_result["ids"]
        # Turn distances (smaller is closer) into scores (larger is closer).
        scores = apply_recursive(lambda x: 1 - x, query_result["distances"])
        return ids, scores

    async def delete(self, ids: List[str]):
        """Delete the entries stored under ``ids``.

        :param ids: Identifiers to remove. An empty list is a no-op.
        """
        if not ids:
            # Never hand an empty id filter to the backend's delete call.
            return

        collection = await self._get_collection()
        if isinstance(collection, AsyncCollection):
            await collection.delete(ids)
        else:
            collection.delete(ids)