# Source code for autorag.vectordb.qdrant

import logging

from qdrant_client import QdrantClient
from qdrant_client.models import (
	Distance,
	VectorParams,
	PointStruct,
	PointIdsList,
	HasIdCondition,
	Filter,
	SearchRequest,
)

from typing import List, Tuple

from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Qdrant(BaseVectorStore):
	"""Vector store backed by a Qdrant collection.

	Creates the target collection on first use (with the requested vector
	dimension and distance metric) and delegates text embedding to the model
	configured on :class:`BaseVectorStore`.
	"""

	def __init__(
		self,
		embedding_model: str,
		collection_name: str,
		embedding_batch: int = 100,
		similarity_metric: str = "cosine",
		client_type: str = "docker",
		url: str = "http://localhost:6333",
		host: str = "",
		api_key: str = "",
		dimension: int = 1536,
		ingest_batch: int = 64,
		parallel: int = 1,
		max_retries: int = 3,
	):
		"""
		:param embedding_model: Embedding model name passed to BaseVectorStore.
		:param collection_name: Qdrant collection to use (created if missing).
		:param embedding_batch: Batch size for embedding calls.
		:param similarity_metric: One of ``cosine``, ``ip`` (dot product), ``l2``.
		:param client_type: ``docker`` (connect by ``url``) or ``cloud``
			(connect by ``host`` + ``api_key``).
		:param url: Qdrant endpoint URL, used when ``client_type == "docker"``.
		:param host: Qdrant Cloud host, used when ``client_type == "cloud"``.
		:param api_key: API key for Qdrant Cloud.
		:param dimension: Vector size used when the collection is created.
		:param ingest_batch: Upload batch size for :meth:`add`.
		:param parallel: Number of parallel upload workers.
		:param max_retries: Upload retry count.
		:raises ValueError: If ``similarity_metric`` or ``client_type`` is
			not one of the supported values.
		"""
		super().__init__(embedding_model, similarity_metric, embedding_batch)

		self.collection_name = collection_name
		self.ingest_batch = ingest_batch
		self.parallel = parallel
		self.max_retries = max_retries

		# Map the public metric name onto qdrant's Distance enum.
		if similarity_metric == "cosine":
			distance = Distance.COSINE
		elif similarity_metric == "ip":
			distance = Distance.DOT
		elif similarity_metric == "l2":
			distance = Distance.EUCLID
		else:
			raise ValueError(
				f"similarity_metric {similarity_metric} is not supported\n"
				"supported similarity metrics are: cosine, ip, l2"
			)

		if client_type == "docker":
			self.client = QdrantClient(
				url=url,
			)
		elif client_type == "cloud":
			self.client = QdrantClient(
				host=host,
				api_key=api_key,
			)
		else:
			raise ValueError(
				f"client_type {client_type} is not supported\n"
				"supported client types are: docker, cloud"
			)

		# Create the collection on first use with the requested size/metric.
		if not self.client.collection_exists(collection_name):
			self.client.create_collection(
				collection_name,
				vectors_config=VectorParams(
					size=dimension,
					distance=distance,
				),
			)
		self.collection = self.client.get_collection(collection_name)

	async def add(self, ids: List[str], texts: List[str]):
		"""Embed ``texts`` and upsert them into the collection under ``ids``."""
		texts = self.truncated_inputs(texts)
		text_embeddings = await self.embedding.aget_text_embedding_batch(texts)
		points = [
			PointStruct(id=point_id, vector=vector)
			for point_id, vector in zip(ids, text_embeddings)
		]
		self.client.upload_points(
			collection_name=self.collection_name,
			points=points,
			batch_size=self.ingest_batch,
			parallel=self.parallel,
			max_retries=self.max_retries,
			wait=True,
		)

	async def fetch(self, ids: List[str]) -> List[List[float]]:
		"""Return the stored vectors for the given point ids."""
		fetched_results = self.client.retrieve(
			collection_name=self.collection_name,
			ids=ids,
			with_vectors=True,
		)
		return [result.vector for result in fetched_results]

	async def is_exist(self, ids: List[str]) -> List[bool]:
		"""Return, for each id in ``ids``, whether a point with that id exists."""
		if not ids:
			return []
		# BUG FIX: scroll() defaults to limit=10, which silently truncated the
		# result when more than 10 ids were checked, mis-reporting later ids
		# as missing. Request up to len(ids) matches explicitly.
		existed_result = self.client.scroll(
			collection_name=self.collection_name,
			scroll_filter=Filter(
				must=[
					HasIdCondition(has_id=ids),
				],
			),
			limit=len(ids),
		)
		# scroll() returns (records, next_page_offset); records is index 0.
		# Use a set for O(1) membership tests.
		existed_ids = {record.id for record in existed_result[0]}
		return [point_id in existed_ids for point_id in ids]

	async def query(
		self, queries: List[str], top_k: int, **kwargs
	) -> Tuple[List[List[str]], List[List[float]]]:
		"""Embed ``queries`` and return (ids, scores) of the top_k hits per query."""
		queries = self.truncated_inputs(queries)
		query_embeddings: List[
			List[float]
		] = await self.embedding.aget_text_embedding_batch(queries)

		search_queries = [
			SearchRequest(vector=embedding, limit=top_k, with_vector=True)
			for embedding in query_embeddings
		]
		search_result = self.client.search_batch(
			collection_name=self.collection_name, requests=search_queries
		)

		# Extract IDs and similarity scores per query.
		ids = [[str(hit.id) for hit in result] for result in search_result]
		scores = [[hit.score for hit in result] for result in search_result]
		return ids, scores

	async def delete(self, ids: List[str]):
		"""Delete the points with the given ids from the collection."""
		self.client.delete(
			collection_name=self.collection_name,
			points_selector=PointIdsList(points=ids),
		)

	def delete_collection(self):
		"""Drop the entire collection from the Qdrant server."""
		self.client.delete_collection(self.collection_name)