from typing import Optional, Tuple

import numpy as np

from predictor import Predictor # This is the base class for all classifiers we will be implementing in this exercise.
# This class has two abstract methods that need to be implemented in all subclasses: fit and predict. That's all there is to it.

class KMeans_Classifier(Predictor):
    """Classifier that labels points through k-means clusters.

    Fitting runs k-means on the training data and then gives every cluster the
    most frequent training label among its members. Prediction assigns each
    point the label of its nearest centroid.

    Attributes set by ``fit`` (both are ``None`` until then):
        centroids: Array of shape (num_clusters, n_features).
        cluster_label_assignments: Array of shape (num_clusters,) holding the
            label chosen for each cluster.
    """

    def __init__(self, num_clusters: int, num_iterations: int = 100, seed: Optional[int] = None):
        """Stores the hyper-parameters; no computation happens here.

        Args:
            num_clusters (int): Number of clusters to form.
            num_iterations (int): Upper bound on the number of k-means iterations.
            seed (Optional[int]): Seed for the random centroid initialisation.
                ``None`` draws fresh entropy, so results may differ between runs.
        """
        self.num_clusters = num_clusters
        self.num_iterations = num_iterations
        self.seed = seed
        # Populated by fit(); None marks the model as not fitted yet.
        self.centroids = None
        self.cluster_label_assignments = None

    def fit(self, data: np.ndarray, labels: np.ndarray) -> None:
        """Fits the KMeans classifier to the data.

        Two steps: first k-means finds the centroids that define the clusters,
        then every cluster is assigned the most frequent label of the data
        points it contains. This is not how k-means is usually used, but it
        turns the clustering into a simple classifier.

        Args:
            data (np.ndarray): A batch of vectors of shape (n_samples, n_features).
            labels (np.ndarray): Array of shape (n_samples,) with the label of
                each data point.

        Raises:
            ValueError: If ``data`` and ``labels`` differ in length, ``data`` is
                not two-dimensional, or ``num_clusters`` is not between 1 and
                the number of samples.
        """
        labels = np.asarray(labels)
        if len(labels) != len(data):
            raise ValueError(
                f"data and labels must have the same length, got {len(data)} and {len(labels)}"
            )
        cluster_ids, centroids = self._kmeans(
            data=data, num_clusters=self.num_clusters, num_iterations=self.num_iterations
        )
        cluster_labels = self._assign_labels_to_clusters(cluster_ids, labels)
        # Publish both results together so a failure above never leaves the
        # model half-fitted.
        self.centroids = centroids
        self.cluster_label_assignments = cluster_labels

    @staticmethod
    def _as_float_matrix(data: np.ndarray) -> np.ndarray:
        """Returns ``data`` as a 2-D floating point array, copying only if needed."""
        matrix = np.asarray(data)
        if matrix.ndim != 2:
            raise ValueError(
                f"data must have shape (n_samples, n_features), got {matrix.shape}"
            )
        # Integer inputs (e.g. uint8 pixels) would overflow when squared.
        if not np.issubdtype(matrix.dtype, np.floating):
            matrix = matrix.astype(np.float64)
        return matrix

    @staticmethod
    def _nearest_centroid_ids(points: np.ndarray, centroids: np.ndarray) -> np.ndarray:
        """Returns, for every row of ``points``, the index of its closest centroid.

        Ties are resolved in favour of the lowest centroid index.
        """
        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, and ||x||^2 is constant per
        # row, so it can be dropped from the argmin. This needs only an
        # (n_samples, n_clusters) buffer instead of a 3-D difference tensor.
        scores = (centroids ** 2).sum(axis=1)[None, :] - 2.0 * (points @ centroids.T)
        return scores.argmin(axis=1)

    @staticmethod
    def _most_frequent(values: np.ndarray):
        """Returns the most common element of ``values`` (smallest one on ties)."""
        unique_values, counts = np.unique(values, return_counts=True)
        # np.unique sorts its output, so argmax picks the smallest tied value.
        return unique_values[np.argmax(counts)]

    def _kmeans(self, data: np.ndarray, num_clusters: int, num_iterations: int) -> Tuple[np.ndarray, np.ndarray]:
        """Implementation of the naive k-means (Lloyd) algorithm.

        Centroids start as ``num_clusters`` distinct rows of ``data`` chosen
        with a generator seeded by ``self.seed``. Each iteration assigns every
        point to its nearest centroid and then moves each centroid to the mean
        of its members. Iteration stops early once the assignments no longer
        change, which yields the same result as running all iterations.

        Args:
            data (np.ndarray): A batch of vectors of shape (n_samples, n_features).
            num_clusters (int): Number of clusters to form.
            num_iterations (int): Maximum number of iterations to run.

        Returns:
            Tuple[np.ndarray, np.ndarray]:
            The cluster assignments, an integer array of shape (n_samples,)
            giving the cluster index of each data point, and the final
            centroids, an array of shape (num_clusters, n_features).

        Raises:
            ValueError: If ``data`` is not two-dimensional or ``num_clusters``
                is not between 1 and the number of samples.
        """
        points = self._as_float_matrix(data)
        num_samples = points.shape[0]
        if not 1 <= num_clusters <= num_samples:
            raise ValueError(
                f"num_clusters must be between 1 and n_samples={num_samples}, got {num_clusters}"
            )

        rng = np.random.default_rng(getattr(self, "seed", None))
        initial_rows = rng.choice(num_samples, size=num_clusters, replace=False)
        # Fancy indexing copies, so updating centroids never touches the data.
        centroids = points[initial_rows]

        cluster_ids = None
        for _ in range(num_iterations):
            new_ids = self._nearest_centroid_ids(points, centroids)
            if cluster_ids is not None and np.array_equal(new_ids, cluster_ids):
                break  # Converged: centroids already match these assignments.
            cluster_ids = new_ids
            for cluster in range(num_clusters):
                members = points[cluster_ids == cluster]
                # An empty cluster keeps its previous centroid instead of
                # becoming NaN from the mean of zero rows.
                if len(members):
                    centroids[cluster] = members.mean(axis=0)
        else:
            # The iteration budget ran out (or was zero): refresh the
            # assignments so they agree with the centroids being returned.
            cluster_ids = self._nearest_centroid_ids(points, centroids)

        return cluster_ids, centroids

    def _assign_labels_to_clusters(self, cluster_ids: np.ndarray, labels: np.ndarray) -> np.ndarray:
        """Assigns a label to each cluster based on the most frequent label of its members.

        Ties go to the smallest label. A cluster without any member receives
        the most frequent label of the whole training set.

        Args:
            cluster_ids (np.ndarray): Array of shape (n_samples,) with the
                cluster index of each data point.
            labels (np.ndarray): Array of shape (n_samples,) with the label of
                each data point.

        Returns:
            np.ndarray: Array of shape (num_clusters,) holding the label
            assigned to each cluster, with the dtype of ``labels``.

        Raises:
            ValueError: If the inputs are empty or their shapes differ.
        """
        cluster_ids = np.asarray(cluster_ids)
        labels = np.asarray(labels)
        if cluster_ids.shape != labels.shape:
            raise ValueError(
                f"cluster_ids and labels must have the same shape, got {cluster_ids.shape} and {labels.shape}"
            )
        if labels.size == 0:
            raise ValueError("cannot assign cluster labels without any labelled samples")

        fallback_label = self._most_frequent(labels)
        assigned_labels = np.empty(self.num_clusters, dtype=labels.dtype)
        for cluster in range(self.num_clusters):
            member_labels = labels[cluster_ids == cluster]
            if member_labels.size:
                assigned_labels[cluster] = self._most_frequent(member_labels)
            else:
                assigned_labels[cluster] = fallback_label
        return assigned_labels

    def predict(self, data: np.ndarray) -> np.ndarray:
        """Predicts labels from the nearest centroid and the label assigned to it.

        Args:
            data (np.ndarray): A batch of vectors of shape (n_samples, n_features).

        Returns:
            np.ndarray: Array of shape (n_samples,) with the predicted label of
            each data point.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
            ValueError: If ``data`` is not two-dimensional or its feature count
                differs from the one seen during ``fit``.
        """
        centroids = getattr(self, "centroids", None)
        cluster_labels = getattr(self, "cluster_label_assignments", None)
        if centroids is None or cluster_labels is None:
            raise RuntimeError("KMeans_Classifier must be fitted with fit() before calling predict()")

        points = self._as_float_matrix(data)
        if points.shape[1] != centroids.shape[1]:
            raise ValueError(
                f"data has {points.shape[1]} features, but the classifier was fitted with {centroids.shape[1]}"
            )
        nearest = self._nearest_centroid_ids(points, centroids)
        return cluster_labels[nearest]