
import torch

def calculate_accuracy(model: torch.nn.Module, data: torch.Tensor, class_ids: torch.Tensor, batch_size: int = 256) -> float:
    """Calculate the accuracy of the model on the given data.

    The model is run in eval mode under ``torch.no_grad()``, one batch at a
    time. Its previous training/eval mode is restored afterwards, even if an
    exception is raised during evaluation.

    Args:
        model (torch.nn.Module): The model to evaluate. Its output for a batch is
            expected to hold per-class scores along dimension 1.
        data (torch.Tensor): The data to evaluate the model on, batched along dimension 0.
        class_ids (torch.Tensor): The class ids of the data, one per sample in ``data``.
        batch_size (int, optional): size of the batches to use when calculating the accuracy. Defaults to 256.

    Returns:
        float: The accuracy of the model on the given data, as a percentage in [0, 100].

    Raises:
        ValueError: If ``batch_size`` is not positive, if ``data`` and
            ``class_ids`` differ in length, or if ``data`` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_samples = len(data)
    if num_samples != len(class_ids):
        raise ValueError(
            f"data and class_ids must have the same length, got {num_samples} and {len(class_ids)}"
        )
    if num_samples == 0:
        raise ValueError("cannot compute accuracy on empty data")

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for start in range(0, num_samples, batch_size):
                batch = data[start:start + batch_size]
                batch_ids = class_ids[start:start + batch_size]
                predicted = model(batch).argmax(dim=1)
                correct += (predicted == batch_ids).sum().item()
    finally:
        # Restore the caller's mode instead of unconditionally switching to
        # train mode; a model evaluated while in eval mode must stay there.
        model.train(was_training)

    return 100 * correct / num_samples

def calculate_reconstruction_loss(autoencoder: torch.nn.Module, criterion: torch.nn.Module, data: torch.Tensor, batch_size: int = 256) -> float:
    """ Calculate the reconstruction loss of the autoencoder on the given data.

    The autoencoder is run in eval mode under ``torch.no_grad()``, one batch at
    a time, and each batch is scored with ``criterion(outputs, batch)``. Batch
    losses are weighted by batch size, so a smaller final batch does not count
    as much as a full one. This assumes ``criterion`` returns the mean loss over
    the batch (e.g. the default ``reduction='mean'`` of PyTorch losses). Under
    that assumption, the result is the dataset-wide mean and does not depend on
    ``batch_size``. The autoencoder's previous training/eval mode is restored
    afterwards, even if an exception is raised.

    Args:
        autoencoder (torch.nn.Module): autoencoder model
        criterion (torch.nn.Module): criterion used to calculate the loss
        data (torch.Tensor): data to calculate the loss on, batched along dimension 0
        batch_size (int, optional): size of the batches to use when calculating the loss. Defaults to 256.

    Returns:
        float: reconstruction loss of the autoencoder on the given data

    Raises:
        ValueError: If ``batch_size`` is not positive or ``data`` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_samples = len(data)
    if num_samples == 0:
        raise ValueError("cannot compute reconstruction loss on empty data")

    was_training = autoencoder.training
    autoencoder.eval()
    weighted_loss = 0.0
    try:
        with torch.no_grad():
            for start in range(0, num_samples, batch_size):
                batch = data[start:start + batch_size]
                outputs = autoencoder(batch)
                # Weight by the number of samples in this batch; an unweighted
                # mean of batch losses over-weights a short final batch.
                weighted_loss += criterion(outputs, batch).item() * len(batch)
    finally:
        # Restore the caller's mode instead of unconditionally switching to
        # train mode.
        autoencoder.train(was_training)

    return weighted_loss / num_samples