import logging

from pinecone.grpc import PineconeGRPC as Pinecone_client
from pinecone import ServerlessSpec

from typing import List, Optional, Tuple

from autorag.utils.util import make_batch, apply_recursive
from autorag.vectordb import BaseVectorStore

logger = logging.getLogger("AutoRAG")


class Pinecone(BaseVectorStore):
	"""Pinecone-backed vector store: embeds texts with the configured embedding model
	and stores/queries them in a Pinecone serverless index."""

	def __init__(
		self,
		embedding_model: str,
		index_name: str,
		embedding_batch: int = 100,
		dimension: int = 1536,
		similarity_metric: str = "cosine",  # "cosine", "ip", or "l2" (mapped to Pinecone's cosine/dotproduct/euclidean)
		cloud: Optional[str] = "aws",
		region: Optional[str] = "us-east-1",
		api_key: Optional[str] = None,
		deletion_protection: Optional[str] = "disabled",  # "enabled" or "disabled"
		namespace: Optional[str] = "default",
		ingest_batch: int = 200,
	):
		super().__init__(embedding_model, similarity_metric, embedding_batch)

		self.index_name = index_name
		self.namespace = namespace
		self.ingest_batch = ingest_batch

		self.client = Pinecone_client(api_key=api_key)

		# Map AutoRAG metric names to the names Pinecone expects.
		if similarity_metric == "ip":
			similarity_metric = "dotproduct"
		elif similarity_metric == "l2":
			similarity_metric = "euclidean"

		# Create the serverless index only if it does not already exist.
		if not self.client.has_index(index_name):
			self.client.create_index(
				name=index_name,
				dimension=dimension,
				metric=similarity_metric,
				spec=ServerlessSpec(
					cloud=cloud,
					region=region,
				),
				deletion_protection=deletion_protection,
			)

		self.index = self.client.Index(index_name)

	async def add(self, ids: List[str], texts: List[str]):
		"""Embed the given texts and upsert them into the index under their ids."""
		texts = self.truncated_inputs(texts)
		text_embeddings: List[
			List[float]
		] = await self.embedding.aget_text_embedding_batch(texts)

		vector_tuples = list(zip(ids, text_embeddings))
		batch_vectors = make_batch(vector_tuples, self.ingest_batch)

		async_res = [
			self.index.upsert(
				vectors=batch_vector_tuples,
				namespace=self.namespace,
				async_req=True,
			)
			for batch_vector_tuples in batch_vectors
		]
		# Wait for the async upsert requests to finish
		[async_result.result() for async_result in async_res]

	async def fetch(self, ids: List[str]) -> List[List[float]]:
		results = self.index.fetch(ids=ids, namespace=self.namespace)
		id_vector_dict = {
			str(key): val["values"] for key, val in results["vectors"].items()
		}
		result = [id_vector_dict[_id] for _id in ids]
		return result

	async def is_exist(self, ids: List[str]) -> List[bool]:
		fetched_result = self.index.fetch(ids=ids, namespace=self.namespace)
		existed_ids = list(map(str, fetched_result.get("vectors", {}).keys()))
		return list(map(lambda x: x in existed_ids, ids))

	async def query(
		self, queries: List[str], top_k: int, **kwargs
	) -> Tuple[List[List[str]], List[List[float]]]:
		"""Embed the queries and return the top_k matching ids and scores for each query."""
		queries = self.truncated_inputs(queries)
		query_embeddings: List[
			List[float]
		] = await self.embedding.aget_text_embedding_batch(queries)

		ids, scores = [], []
		for query_embedding in query_embeddings:
			response = self.index.query(
				vector=query_embedding,
				top_k=top_k,
				include_values=True,
				namespace=self.namespace,
			)
			ids.append([o.id for o in response.matches])
			scores.append([o.score for o in response.matches])

		# For L2 distance, smaller is better, so negate the scores to keep
		# "higher score means more similar" consistent across metrics.
		if self.similarity_metric in ["l2"]:
			scores = apply_recursive(lambda x: -x, scores)

		return ids, scores

	async def delete(self, ids: List[str]):
		# Delete entries by IDs
		self.index.delete(ids=ids, namespace=self.namespace)

	def delete_index(self):
		# Delete the index
		self.client.delete_index(self.index_name)
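

# The block below is an illustrative usage sketch, not part of the original module.
# It assumes a valid Pinecone API key and an AutoRAG-recognized embedding model key;
# the model name, index name, ids, and texts here are hypothetical placeholders.
if __name__ == "__main__":
	import asyncio

	async def _demo():
		vectordb = Pinecone(
			embedding_model="openai_embed_3_large",  # assumed embedding model key
			index_name="autorag-demo",  # hypothetical index name
			api_key="YOUR_PINECONE_API_KEY",  # replace with a real key
		)
		# Embed and upsert two toy documents, then query them back.
		await vectordb.add(ids=["doc-1", "doc-2"], texts=["hello world", "goodbye world"])
		ids, scores = await vectordb.query(queries=["hello"], top_k=2)
		print(ids, scores)

	asyncio.run(_demo())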