from typing import Union, Tuple
import numpy as np

from predictor import Predictor

class DecisionNode:
    """Internal node of a binary decision tree.

    It stores which feature is tested and the threshold it is tested
    against. It also stores the two subtrees: one for samples where the
    tested condition holds, one for samples where it does not. The
    comparison itself is performed by the tree code, not by this class.
    """

    def __init__(
        self,
        feature_index: int,
        threshold: float,
        condition_met_child: Union['DecisionNode', 'LeafNode'],
        condition_not_met_child: Union['DecisionNode', 'LeafNode'],
    ):
        # Column index of the feature examined at this node and the value it is compared to.
        self.feature_index, self.threshold = feature_index, threshold
        # Subtrees reached when the condition is / is not satisfied.
        self.condition_met_child = condition_met_child
        self.condition_not_met_child = condition_not_met_child

class LeafNode:
    """Terminal node of a decision tree that stores a single ``value``.

    Presumably ``value`` is the prediction for samples that reach this leaf;
    this class only stores it unchanged.
    """

    def __init__(self, value) -> None:
        self.value = value  # kept exactly as given, no conversion

class DecisionTree(Predictor):
    """Binary classification tree grown greedily by minimizing Gini impurity.

    Every internal node tests ``x[feature_index] <= threshold``. Samples for
    which the test holds go to ``condition_met_child``; all other samples go
    to ``condition_not_met_child``. Leaves store the majority label of the
    training samples that reached them.

    Args:
        max_depth (int): Maximum depth of the tree. The root has depth 0, so
            ``max_depth=0`` yields a single leaf and ``max_depth=1`` a stump.
        min_samples (int): Minimum number of samples a node must contain to be
            considered for splitting; nodes with fewer samples become leaves.
    """

    def __init__(self, max_depth: int, min_samples: int) -> None:
        super().__init__()
        self.max_depth = max_depth
        self.min_samples = min_samples
        self.root = None

    def fit(self, data: np.ndarray, labels: np.ndarray) -> None:
        """Creates the decision tree by growing it from the root node.

        Args:
            data (np.ndarray): Training samples, shape (n_samples, n_features).
            labels (np.ndarray): One label per sample (length n_samples).

        Raises:
            ValueError: If ``data`` is not 2-D or is empty, or if the number of
                labels does not match the number of samples.
        """
        data = np.asarray(data)
        # ravel() also accepts labels given as a column vector (n_samples, 1).
        labels = np.asarray(labels).ravel()
        if data.ndim != 2:
            raise ValueError(
                f"data must be 2-D (n_samples, n_features), got shape {data.shape}"
            )
        if data.shape[0] == 0:
            raise ValueError("cannot fit a decision tree on an empty dataset")
        if labels.shape[0] != data.shape[0]:
            raise ValueError(
                f"got {labels.shape[0]} labels for {data.shape[0]} samples"
            )
        self.root = self._grow_tree(data, labels, depth=0)

    def _grow_tree(self, data: np.ndarray, labels: np.ndarray, depth: int) -> Union[DecisionNode, LeafNode]:
        """Recursively grows the decision tree.

        A node becomes a leaf when any of these holds:
        - the maximum depth is reached;
        - it holds fewer than ``min_samples`` samples;
        - all of its labels are identical;
        - no feature can separate its samples.

        Args:
            data (np.ndarray): Data of the current node.
            labels (np.ndarray): Labels of the data.
            depth (int): Current depth of this node.

        Returns:
            Union[DecisionNode, LeafNode]: The current node of the decision tree after growing all its child nodes.
        """
        is_pure = np.unique(labels).size <= 1
        if depth >= self.max_depth or labels.shape[0] < self.min_samples or is_pure:
            return LeafNode(self._majority_label(labels))

        feature_index, threshold = self._find_best_split(data, labels)
        if feature_index is None:
            # Every feature is constant on this node, so no split can separate samples.
            return LeafNode(self._majority_label(labels))

        goes_left = data[:, feature_index] <= threshold
        return DecisionNode(
            feature_index,
            threshold,
            self._grow_tree(data[goes_left], labels[goes_left], depth + 1),
            self._grow_tree(data[~goes_left], labels[~goes_left], depth + 1),
        )

    @staticmethod
    def _majority_label(labels: np.ndarray):
        """Returns the most frequent label; ties go to the smallest label (np.unique sorts)."""
        values, counts = np.unique(labels, return_counts=True)
        return values[np.argmax(counts)]

    def _find_best_split(self, data: np.ndarray, labels: np.ndarray) -> Tuple[int, float]:
        """Finds the best split for the given data and labels.

        Every feature is scanned. Candidate thresholds are the midpoints
        between consecutive distinct values of that feature. The pair with
        the lowest weighted Gini impurity wins; on ties the first one found
        (lowest feature index, then lowest threshold) is kept.

        Args:
            data (np.ndarray): Current data to split, shape (n_samples, n_features).
            labels (np.ndarray): According labels for the data.

        Returns:
            Tuple[int, float]: The best feature index and threshold to split the data,
                or ``(None, None)`` if every feature is constant on ``data``.
        """
        best_feature, best_threshold = None, None
        best_gini = np.inf

        for feature_index in range(data.shape[1]):
            column = data[:, feature_index]
            distinct = np.unique(column)
            # Only thresholds between distinct values yield different partitions.
            # Midpoints generalize better to unseen values than the observed values.
            candidates = distinct[:-1] + np.diff(distinct) / 2
            for threshold in candidates:
                goes_left = column <= threshold
                n_left = np.count_nonzero(goes_left)
                # Float rounding can push a midpoint onto an endpoint; skip empty sides.
                if n_left == 0 or n_left == column.size:
                    continue
                gini = self._gini_impurity(labels[goes_left], labels[~goes_left])
                if gini < best_gini:
                    best_feature, best_threshold, best_gini = feature_index, float(threshold), gini

        return best_feature, best_threshold

    def _gini_impurity(self, split_1_labels: np.ndarray, split_2_labels: np.ndarray) -> float:
        """Calculates the Gini impurity of a split. Following the formula:
        Gini = 1 - sum(p_i^2), where p_i is the proportion of samples that belong to class i in that split.
        Gini_impurity = (size_of_split_1 / total_size) * Gini(split_1) + (size_of_split_2 / total_size) * Gini(split_2)

        Args:
            split_1_labels (np.ndarray): Classifications of the first part of the split.
            split_2_labels (np.ndarray): Classifications of the second part of the split.

        Returns:
            float: The Gini impurity of the split (0.0 if both parts are empty).
        """
        n_1, n_2 = len(split_1_labels), len(split_2_labels)
        total = n_1 + n_2
        if total == 0:
            return 0.0
        return float(
            (n_1 / total) * self._gini(split_1_labels)
            + (n_2 / total) * self._gini(split_2_labels)
        )

    @staticmethod
    def _gini(labels: np.ndarray) -> float:
        """Gini impurity 1 - sum(p_i^2) of a single label set (0.0 when empty)."""
        if len(labels) == 0:
            return 0.0
        _, counts = np.unique(labels, return_counts=True)
        proportions = counts / counts.sum()
        return float(1.0 - np.sum(proportions ** 2))

    def predict(self, data: np.ndarray) -> np.ndarray:
        """Predicts the labels for the given data using the decision tree.

        Args:
            data (np.ndarray): The data to predict the labels for, shape (n_samples, n_features).

        Returns:
            np.ndarray: The predicted labels for the data, one per sample.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.root is None:
            raise RuntimeError("DecisionTree.predict() called before fit()")
        data = np.asarray(data)
        return np.array([self._traverse_tree(sample, self.root) for sample in data])

    def _traverse_tree(self, x: np.ndarray, node: Union[DecisionNode, LeafNode]) -> int:
        """Traverses the decision tree to predict the label of the given data point.

        The walk is iterative, so prediction cannot hit Python's recursion limit.

        Args:
            x (np.ndarray): The single data point to predict the label for.
            node (Union[DecisionNode, LeafNode]): The current node in the decision tree.

        Returns:
            int: The predicted label of the data point.
        """
        while isinstance(node, DecisionNode):
            if x[node.feature_index] <= node.threshold:
                node = node.condition_met_child
            else:
                node = node.condition_not_met_child
        return node.value