import os
import json
import logging
import faiss
import numpy as np

logger = logging.getLogger(__name__)

class VectorStore:
    """FAISS-backed vector store with JSON metadata persisted alongside the index.

    The FAISS index holds the embedding vectors; ``self.metadata`` maps a
    stringified integer file ID (``"1"``, ``"2"``, ...) to the metadata object
    supplied at insert time. Both are loaded eagerly in ``__init__`` and
    written back together by ``save_index``.
    """

    def __init__(self, index_path: str, metadata_path: str, dimension: int = 384):
        """Initialize the store and load any existing index/metadata from disk.

        Args:
            index_path: Path of the FAISS index file.
            metadata_path: Path of the JSON metadata file.
            dimension: Embedding dimensionality used when creating a fresh index.
        """
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.dimension = dimension
        self.index = None
        self.metadata = {}
        self._load_index()
        self._load_metadata()

    def _load_index(self):
        """Load the FAISS index from disk, or initialize an empty L2 flat index."""
        if os.path.exists(self.index_path):
            logger.info("Loading FAISS index from disk...")
            self.index = faiss.read_index(self.index_path)
            logger.info("FAISS index loaded with %d records", self.index.ntotal)
        else:
            logger.warning("FAISS index file not found. Initializing new index.")
            self.index = faiss.IndexFlatL2(self.dimension)

    def _load_metadata(self):
        """Load the metadata mapping from disk; fall back to empty on failure."""
        if os.path.exists(self.metadata_path):
            try:
                with open(self.metadata_path, "r") as f:
                    self.metadata = json.load(f)
                logger.info("Metadata loaded successfully with %d entries", len(self.metadata))
            # OSError: unreadable file; ValueError covers JSONDecodeError and
            # UnicodeDecodeError. Start fresh rather than crash on corruption.
            except (OSError, ValueError) as e:
                logger.error("Error loading metadata: %s", e)
                self.metadata = {}

    def is_index_initialized(self):
        """Return True if the FAISS index exists and contains at least one vector."""
        return self.index is not None and self.index.ntotal > 0

    def save_index(self):
        """Persist both the FAISS index and the metadata mapping to disk."""
        faiss.write_index(self.index, self.index_path)
        with open(self.metadata_path, "w") as f:
            json.dump(self.metadata, f)
        logger.info("Saved FAISS index to %s", self.index_path)
        logger.info("Saved metadata to %s", self.metadata_path)

    def get_next_file_id(self):
        """Return the next unused integer file ID (1 when no metadata exists)."""
        # max(..., default=0) already covers the empty mapping, so no
        # separate `if not self.metadata` branch is needed.
        return max(map(int, self.metadata.keys()), default=0) + 1

    def _is_file_indexed(self, file_name):
        """Return True if any stored metadata entry records this file name."""
        return any(
            isinstance(entry, dict) and entry.get("file_name") == file_name
            for entry in self.metadata.values()
        )

    def insert_records(self, embeddings, metadata, file_name):
        """Insert embeddings and metadata for one file, skipping indexed files.

        Bug fix: the previous duplicate check (``file_name in self.metadata``)
        compared the file name against the stringified-ID *keys*, so it never
        matched and duplicates were always re-inserted. The file name is now
        stored inside each dict metadata entry and duplicates are detected by
        scanning those entries. (Entries written by older versions lack the
        name and are simply never skipped — same behavior as before.)

        Args:
            embeddings: A single vector or a batch of vectors (any array-like).
            metadata: JSON-serializable metadata for the file; if it is a dict,
                a ``"file_name"`` key is added for duplicate detection.
            file_name: Name of the source file being indexed.

        Raises:
            ValueError: If the index was never initialized.
        """
        if self._is_file_indexed(file_name):
            logger.info(f"Skipping file {file_name} as it's already indexed.")
            return

        try:
            # FAISS requires 2-D float32 input; atleast_2d also accepts a
            # single 1-D vector, which `np.array(...)` alone would not.
            embeddings = np.atleast_2d(np.asarray(embeddings, dtype=np.float32))
            if self.index is None:
                raise ValueError("Index not initialized")

            file_id = self.get_next_file_id()  # Generate a new file ID

            # Record the file name so future inserts can detect duplicates.
            if isinstance(metadata, dict):
                metadata = {**metadata, "file_name": file_name}

            self.index.add(embeddings)
            self.metadata[str(file_id)] = metadata
            self.save_index()

            logger.info(f"Inserted {len(embeddings)} records for file ID {file_id}")
            logger.info(f"Total records in FAISS index: {self.index.ntotal}")
        except Exception as e:
            logger.error(f"Record insertion failed: {str(e)}")
            raise

    def search_index(self, query_embedding, top_k=5):
        """Search the FAISS index and return up to ``top_k`` scored results.

        Args:
            query_embedding: A single query vector (array-like).
            top_k: Maximum number of neighbors to return.

        Returns:
            A list of dicts with ``index`` (FAISS position), ``distance``
            (L2 distance), and ``metadata``; empty on error or empty index.
        """
        if not self.is_index_initialized():
            logger.error("Search failed: FAISS index is empty.")
            return []

        try:
            query_embedding = np.array(query_embedding, dtype=np.float32).reshape(1, -1)
            distances, indices = self.index.search(query_embedding, top_k)

            results = []
            for idx, dist in zip(indices[0], distances[0]):
                # FAISS pads with -1 when fewer than top_k vectors exist.
                if idx == -1:
                    continue

                # Convert NumPy types to native Python types for JSON safety.
                idx = int(idx)
                dist = float(dist)

                # NOTE(review): metadata is keyed by *file ID* (1-based) while
                # `idx` is a FAISS vector position (0-based); these only line
                # up under specific insert patterns — verify against callers.
                metadata_entry = self.metadata.get(str(idx), {})

                results.append({
                    "index": idx,
                    "distance": dist,
                    "metadata": metadata_entry
                })

            logger.info("Found %d search results.", len(results))
            return results

        except Exception as e:
            logger.error("Search failed: %s", e)
            return []

