import matplotlib.pyplot as plt
from IPython.display import clear_output
import numpy as np
import torch
from typing import List

def plot_accuracy(training_accuracies: List[float], validation_accuracies: List[float]) -> None:
    """Redraw the per-epoch training and validation accuracy curves.

    The current figure and the notebook cell output are cleared on every call,
    so calling this once per epoch keeps updating a single chart instead of
    stacking new ones.

    Args:
        training_accuracies (List[float]): Training accuracy recorded for each epoch.
        validation_accuracies (List[float]): Validation accuracy recorded for each epoch.
    """
    # Start from a blank figure and replace the previous cell output.
    plt.clf()
    clear_output(wait=True)

    axes = plt.gca()
    curves = (
        (training_accuracies, 'Training Accuracy'),
        (validation_accuracies, 'Validation Accuracy'),
    )
    for values, name in curves:
        axes.plot(values, label=name)

    axes.set_title('Training and Validation Accuracy')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Accuracy')

    # The y-range is pinned to 0-100, so accuracies are presumably given as percentages.
    axes.set_xlim(0, len(training_accuracies))
    axes.set_ylim(0, 100)
    axes.grid(True)
    axes.legend(loc='upper left')

    plt.show()

def plot_loss(training_losses: List[float], validation_losses: List[float]) -> None:
    """Plot the training and validation loss.

    The current figure and the notebook cell output are cleared first, so
    calling this once per epoch redraws a single, updated chart.

    Args:
        training_losses (List[float]): The training losses per epoch.
        validation_losses (List[float]): The validation losses per epoch.
    """
    import math

    # Clear the current figure and cell output
    plt.clf()
    clear_output(wait=True)

    # Plot training and validation loss
    plt.plot(training_losses, label='Training Loss')
    plt.plot(validation_losses, label='Validation Loss')

    # Add title and labels
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    # An empty history would give the degenerate range xlim(0, 0).
    if len(training_losses) > 0:
        plt.xlim(0, len(training_losses))

    # Use the largest finite loss as the top of the y-axis. Non-finite values
    # (a diverged run) are skipped because matplotlib rejects inf/nan limits,
    # and empty histories no longer make max() raise. If no positive finite
    # loss exists, matplotlib's autoscaling is left in charge.
    finite_losses = [
        value for value in [*training_losses, *validation_losses] if math.isfinite(value)
    ]
    upper = max(finite_losses, default=0.0)
    if upper > 0:
        plt.ylim(0, upper)
    plt.grid(True)

    # Add legend
    plt.legend(loc='upper right')

    # Show the plot
    plt.show()

def tensor_to_image(tensor: torch.Tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a NumPy array to be displayed as an image.

    Args:
        tensor (torch.Tensor): The image to convert, shaped (channels, height, width),
            optionally with leading size-1 batch axes such as (1, channels, height, width),
            or already (height, width). Values are expected in the range [0, 1]; anything
            outside that range is clipped.

    Returns:
        np.ndarray: A ``uint8`` array with values in [0, 255], shaped (height, width, channels),
            or (height, width) for single-channel input, which is the 2-D grayscale layout
            ``plt.imshow`` expects.

    Raises:
        ValueError: If the tensor cannot be interpreted as a single image.
    """
    array = tensor.detach().cpu().numpy()

    # Strip leading size-1 batch axes only. A blanket squeeze() would also drop a
    # channel, height or width of size 1 and break the transpose below.
    while array.ndim > 3 and array.shape[0] == 1:
        array = array[0]
    if array.ndim > 3:
        # Any remaining extra axes: remove every size-1 axis.
        array = array.squeeze()

    if array.ndim == 3:
        # (C, H, W) -> (H, W, C)
        array = np.transpose(array, (1, 2, 0))
        # Represent single-channel images as plain (H, W) grayscale.
        if array.shape[-1] == 1:
            array = array[..., 0]
    elif array.ndim != 2:
        raise ValueError(
            f"Expected an image tensor shaped (C, H, W) or (H, W), got {tuple(tensor.shape)}"
        )

    # Clip first so out-of-range values saturate instead of overflowing on the uint8 cast.
    return (np.clip(array, 0.0, 1.0) * 255.).astype(np.uint8)

def _plot_reconstruction_grid(autoencoder: torch.nn.Module,
                              data: torch.utils.data.Dataset,
                              number_samples: int,
                              title: str) -> None:
    """Show one figure: original images on the top row, their reconstructions below."""
    # Send inputs to the device holding the model's weights (when it has any), so
    # CPU-side data can be plotted against a model living on another device.
    first_parameter = next(autoencoder.parameters(), None)
    device = None if first_parameter is None else first_parameter.device

    # Never index past the end of the data; keep the requested count when the
    # container does not report a length.
    try:
        count = min(number_samples, len(data))
    except TypeError:
        count = number_samples

    plt.figure(figsize=(20, 4))
    plt.suptitle(title, fontsize=20)
    for i in range(count):
        image = data[i]
        # Items given as (input, target) pairs: only the input is plotted.
        if isinstance(image, (tuple, list)):
            image = image[0]

        model_input = image.unsqueeze(0)
        if device is not None:
            model_input = model_input.to(device)
        with torch.no_grad():
            reconstruction = autoencoder(model_input)[0]

        panels = ((image, "original"), (reconstruction, "reconstruction"))
        for row, (picture, caption) in enumerate(panels):
            ax = plt.subplot(2, count, row * count + i + 1)
            # cmap only affects single-channel images; passing it per call avoids
            # plt.gray(), which changes the default colormap for every later plot.
            ax.imshow(tensor_to_image(picture), cmap="gray")
            ax.set_title(caption)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()


def plot_reconstructions(autoencoder: torch.nn.Module,
                         train_data: torch.utils.data.Dataset,
                         valid_data: torch.utils.data.Dataset,
                         number_samples: int = 10) -> None:
    """Plot the original and reconstructed images for the training and validation data.

    Two figures are shown, one per data split. In each, the top row holds the
    original images and the bottom row their reconstructions.

    Args:
        autoencoder (torch.nn.Module): The autoencoder model.
        train_data (torch.utils.data.Dataset): Indexable training images (e.g. a Subset or a
            tensor); each item is a (channels, height, width) tensor or an (image, target) pair.
        valid_data (torch.utils.data.Dataset): Indexable validation images, same format.
        number_samples (int, optional): The number of samples to plot per split, capped at the
            number of available items. Defaults to 10.
    """
    # Remember the caller's mode so it can be restored exactly afterwards.
    was_training = autoencoder.training
    autoencoder.eval()
    try:
        _plot_reconstruction_grid(autoencoder, train_data, number_samples, 'Training Data')
        _plot_reconstruction_grid(autoencoder, valid_data, number_samples, 'Validation Data')
    finally:
        # Restore the previous mode rather than forcing training mode, even if plotting fails.
        autoencoder.train(was_training)