import torch

from plotting import plot_accuracy, plot_loss
from metric import calculate_accuracy, calculate_reconstruction_loss


class ClassificationTrainer:
    """Trainer class to train a neural network for classification tasks.

    The trainer keeps a history of training and validation accuracy. The
    first call to ``train`` records a baseline entry for the untrained
    network; afterwards one entry per epoch is appended and the history is
    plotted after every epoch.
    """

    def __init__(self, network: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 criterion: torch.nn.Module,
                 batch_size: int) -> None:
        """Store the components used for training.

        Args:
            network (torch.nn.Module): The classifier to train.
            optimizer (torch.optim.Optimizer): Optimizer that updates the
                network's parameters.
            criterion (torch.nn.Module): Loss function, called as
                ``criterion(outputs, labels)``.
            batch_size (int): Number of samples per optimization step.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.batch_size = batch_size

        # Accuracy history: one baseline entry, then one entry per epoch.
        self.train_accuracy = []
        self.validation_accuracy = []

    def train(self,
              epochs: int,
              x_train: torch.Tensor,
              y_train: torch.Tensor,
              x_valid: torch.Tensor,
              y_valid: torch.Tensor) -> None:
        """Train the network for the given number of epochs.

        The training data is processed in its given order (no shuffling) in
        mini-batches of ``self.batch_size`` samples; the last batch may be
        smaller. The training accuracy of an epoch is computed from the
        forward passes made during that epoch.

        Args:
            epochs (int): The number of epochs to train the network. Values
                <= 0 make this a no-op.
            x_train (torch.Tensor): Data to train the network on.
            y_train (torch.Tensor): Labels of the training data, compared
                element-wise against the predicted class indices.
            x_valid (torch.Tensor): Data to validate the network on.
            y_valid (torch.Tensor): Labels of the validation data.

        Raises:
            ValueError: If training is requested on an empty ``x_train`` or
                if ``x_train`` and ``y_train`` differ in length.
        """
        if epochs <= 0:
            return
        if len(x_train) == 0:
            raise ValueError("x_train must contain at least one sample")
        if len(x_train) != len(y_train):
            raise ValueError(
                "x_train and y_train differ in length "
                f"({len(x_train)} != {len(y_train)})")

        # Record the accuracy of the untrained network once, so the history
        # starts with the pre-training performance.
        if not self.train_accuracy:
            self.train_accuracy.append(calculate_accuracy(self.network, x_train, y_train))
            self.validation_accuracy.append(calculate_accuracy(self.network, x_valid, y_valid))

        for _ in range(epochs):
            # calculate_accuracy may leave the network in eval mode (its body
            # is not visible here); make sure layers such as dropout or
            # batch-norm behave as during training.
            self.network.train()

            correct_train = 0
            total_train = 0
            for start in range(0, len(x_train), self.batch_size):
                stop = start + self.batch_size
                inputs = x_train[start:stop]
                labels = y_train[start:stop]

                # Forward pass and loss.
                outputs = self.network(inputs)
                loss = self.criterion(outputs, labels)

                # Backward pass and parameter update.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Batch accuracy from the forward pass above; detach so the
                # bookkeeping does not keep the autograd graph alive.
                predicted_train = outputs.detach().argmax(dim=1)
                total_train += labels.size(0)
                correct_train += (predicted_train == labels).sum().item()

            # total_train > 0 is guaranteed by the empty-input check above.
            self.train_accuracy.append(100 * correct_train / total_train)
            self.validation_accuracy.append(calculate_accuracy(self.network, x_valid, y_valid))

            plot_accuracy(self.train_accuracy, self.validation_accuracy)

class AutoencoderTrainer:
    """Trainer class to train an autoencoder to reconstruct its input.

    The trainer keeps a history of training and validation reconstruction
    loss. The first call to ``train`` records a baseline entry for the
    untrained autoencoder; afterwards one entry per epoch is appended and the
    history is plotted after every epoch.
    """

    def __init__(self, autoencoder: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 criterion: torch.nn.Module,
                 batch_size: int) -> None:
        """Store the components used for training.

        Args:
            autoencoder (torch.nn.Module): The autoencoder to train.
            optimizer (torch.optim.Optimizer): Optimizer that updates the
                autoencoder's parameters.
            criterion (torch.nn.Module): Loss function, called as
                ``criterion(reconstruction, inputs)``.
            batch_size (int): Number of samples per optimization step.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.autoencoder = autoencoder
        self.optimizer = optimizer
        self.criterion = criterion
        self.batch_size = batch_size

        # Loss history: one baseline entry, then one entry per epoch.
        self.train_loss = []
        self.validation_loss = []

    def train(self,
              epochs: int,
              x_train: torch.Tensor,
              x_valid: torch.Tensor) -> None:
        """Train the autoencoder for the given number of epochs.

        The training data is processed in its given order (no shuffling) in
        mini-batches of ``self.batch_size`` samples; the last batch may be
        smaller.

        Args:
            epochs (int): The number of epochs to train the autoencoder.
                Values <= 0 make this a no-op.
            x_train (torch.Tensor): Data to train the autoencoder on.
            x_valid (torch.Tensor): Data to validate the autoencoder on.

        Raises:
            ValueError: If training is requested on an empty ``x_train``.
        """
        if epochs <= 0:
            return
        if len(x_train) == 0:
            raise ValueError("x_train must contain at least one sample")

        # Record the loss of the untrained autoencoder once, so the history
        # starts with the pre-training performance.
        if not self.train_loss:
            self.train_loss.append(calculate_reconstruction_loss(self.autoencoder, self.criterion, x_train))
            self.validation_loss.append(calculate_reconstruction_loss(self.autoencoder, self.criterion, x_valid))

        for _ in range(epochs):
            # calculate_reconstruction_loss may leave the model in eval mode
            # (its body is not visible here); restore training behaviour.
            self.autoencoder.train()

            running_loss = 0.0
            for start in range(0, len(x_train), self.batch_size):
                inputs = x_train[start:start + self.batch_size]

                # Reconstruct the batch and compare it with the original.
                reconstruction = self.autoencoder(inputs)
                loss = self.criterion(reconstruction, inputs)

                # Backward pass and parameter update.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Weight each batch by its size so the epoch value is a
                # per-sample average even when the last batch is smaller.
                # NOTE(review): assumes the criterion returns a batch mean
                # (PyTorch's default reduction='mean') — confirm.
                running_loss += loss.item() * inputs.size(0)

            self.train_loss.append(running_loss / len(x_train))
            self.validation_loss.append(calculate_reconstruction_loss(self.autoencoder, self.criterion, x_valid))

            plot_loss(self.train_loss, self.validation_loss)
