import gzip
import os
import struct

import numpy as np
import requests

def _download_file(url, dest, timeout):
    """Download *url* to the local path *dest* atomically.

    The payload is first written to ``dest + ".part"`` and then moved into
    place with ``os.replace``. An interrupted download therefore never leaves
    a truncated file at *dest*. Such a file would otherwise pass the
    existence check in ``load_mnist`` on the next run and be parsed as if
    complete.

    Raises:
        RuntimeError: if the server answers with a status other than 200.
        requests.RequestException: on connection errors or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to download {os.path.basename(dest)} "
            f"with status code {response.status_code}"
        )
    tmp_path = dest + ".part"
    try:
        with open(tmp_path, "wb") as file:
            file.write(response.content)
        os.replace(tmp_path, dest)
    except BaseException:
        # Remove the half-written temp file, then propagate the original error.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise


def _read_idx_ubyte(path, ndim):
    """Parse a gzip-compressed IDX file holding unsigned-byte data.

    The IDX header is a big-endian uint32 magic number ``0x0000080N``
    (``0x08`` = unsigned byte, ``N`` = number of dimensions). It is followed
    by ``N`` big-endian uint32 dimension sizes, and then the raw bytes.

    Args:
        path: Path to the ``.gz`` file.
        ndim: Number of dimensions the file is expected to declare.

    Returns:
        A uint8 array shaped as declared by the header. It is a read-only
        view over the decompressed bytes (``np.frombuffer``). Copy it before
        modifying it in place.

    Raises:
        ValueError: if the header is truncated, the magic number does not
            match, or the payload size disagrees with the header.
    """
    with gzip.open(path, "rb") as file:
        raw = file.read()

    header_size = 4 + 4 * ndim
    if len(raw) < header_size:
        raise ValueError(f"{path}: truncated IDX header")

    expected_magic = 0x0800 | ndim
    (magic,) = struct.unpack_from(">I", raw, 0)
    if magic != expected_magic:
        raise ValueError(
            f"{path}: bad IDX magic number {magic:#010x}, "
            f"expected {expected_magic:#010x}"
        )

    dims = struct.unpack_from(f">{ndim}I", raw, 4)
    data = np.frombuffer(raw, np.uint8, offset=header_size)
    expected_size = int(np.prod(dims, dtype=np.int64))
    if data.size != expected_size:
        raise ValueError(
            f"{path}: payload has {data.size} bytes, "
            f"header declares {expected_size}"
        )
    return data.reshape(dims)


def load_mnist(*, data_dir=".", base_url="http://yann.lecun.com/exdb/mnist/",
               timeout=60):
    """Download (if needed) and load the MNIST train and test splits.

    Each of the four gzip-compressed IDX files is downloaded from
    ``base_url`` only if it is not already present in ``data_dir``.

    Args:
        data_dir: Directory where the ``.gz`` files are cached. It is created
            if missing. Defaults to the current working directory.
        base_url: URL prefix the file names are appended to.
            NOTE(review): the default host has been reported to reject
            anonymous downloads. If so, pass a mirror URL here.
        timeout: Seconds to wait for the server before aborting a download.

    Returns:
        ``((train_images, train_labels), (test_images, test_labels))``.
        Images are uint8 arrays of shape ``(n, rows * cols)``, i.e.
        ``(n, 784)`` for MNIST. Labels are uint8 arrays of shape ``(n,)``.
        The arrays are read-only views (see ``_read_idx_ubyte``).

    Raises:
        RuntimeError: if a download returns a non-200 status.
        ValueError: if a file is malformed, or image and label counts
            disagree.
    """
    file_names = ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                  "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]

    os.makedirs(data_dir, exist_ok=True)
    for file_name in file_names:
        path = os.path.join(data_dir, file_name)
        if not os.path.exists(path):
            _download_file(base_url + file_name, path, timeout)

    def load_images(file_name):
        images = _read_idx_ubyte(os.path.join(data_dir, file_name), 3)
        # Flatten each (rows, cols) image into one row. Explicit sizes
        # (rather than -1) also work for an empty file.
        return images.reshape(images.shape[0], images.shape[1] * images.shape[2])

    def load_labels(file_name):
        return _read_idx_ubyte(os.path.join(data_dir, file_name), 1)

    train = (load_images("train-images-idx3-ubyte.gz"),
             load_labels("train-labels-idx1-ubyte.gz"))
    test = (load_images("t10k-images-idx3-ubyte.gz"),
            load_labels("t10k-labels-idx1-ubyte.gz"))

    for split, (images, labels) in (("train", train), ("test", test)):
        if images.shape[0] != labels.shape[0]:
            raise ValueError(
                f"{split} split has {images.shape[0]} images "
                f"but {labels.shape[0]} labels"
            )

    return train, test