#!/usr/bin/env python
# coding: utf-8

# # Binary Search and AVL trees
# 
# In this exercise we will implement commonly used methods of binary search trees (AVL-trees).
# 
# The implementation of `insert` and `delete` is optional. As an incentive, _should_ you decide to implement them to hone your understanding of AVL-trees, a correct implementation will automatically be awarded full points (i.e. 2) for this exercise. Why? To implement insert and delete properly, you will need some of the other functions, and the pen-and-paper exercises then become a subset of what is required for these two functions.


import unittest
from typing import Optional, Union

# Attribute names that ``Node.__setattr__`` accepts; assigning any other name is rejected.
NODE_ATTRS = set(("left", "right", "value", "parent"))
from utils import _build_tree_string

class Node:
    """A simple binary tree node.

    Attributes:
        value: value associated with the `Node`
        left: left child `Node`, or ``None``
        right: right child `Node`, or ``None``
        parent: parent `Node`, or ``None``; lets you walk up the tree

    Only the attribute names in ``NODE_ATTRS`` may be assigned; any other name
    raises ``AttributeError``. This is a real exception rather than an
    ``assert``, so the check also survives ``python -O``. Assigning a child
    does NOT update that child's ``parent`` pointer; keep it in sync yourself
    if you rely on it.

    Example:
        Attributes can be reassigned after construction::

            node = Node(3)
            node.value = 5

        ``print(node)`` renders the sub-tree spanned by the node, which is
        handy for debugging::

            >> root = Node(10)
            >> root.left = Node(1)
            >> root.right = Node(3)
            >> root.right.right = Node(30)
            >> print(root)
            >>   10
            >>  /  \\
            >> 1    3
            >>       \\
            >>        30
    """

    def __init__(
        self,
        value: int,
        left: Optional["Node"] = None,
        right: Optional["Node"] = None,
        parent: Optional["Node"] = None,
    ) -> None:
        """Initialize a binary tree node.

        Args:
            value: value associated with the `Node`
            left: left child `Node` of the node
            right: right child `Node` of the node
            parent: pointer to the parent `Node`
        """
        self.value = value
        self.left = left
        self.right = right
        self.parent = parent

    def __setattr__(self, attr: str, obj: Optional[Union[int, "Node"]]) -> None:
        """Set `attr`, allowing only the names listed in ``NODE_ATTRS``.

        Raises:
            AttributeError: if `attr` is not one of ``NODE_ATTRS``.
        """
        if attr not in NODE_ATTRS:
            raise AttributeError(
                f"{attr} is not a valid Node attribute ({NODE_ATTRS})"
            )
        object.__setattr__(self, attr, obj)

    def __repr__(self) -> str:
        # Short, non-recursive form so assertion failures stay readable.
        return f"{type(self).__name__}({self.value!r})"

    def __str__(self) -> str:
        # `_build_tree_string` (from utils) returns the rendered lines first.
        lines = _build_tree_string(self, 0, False, "-")[0]
        return "\n" + "\n".join(line.rstrip() for line in lines)


# We provide a utility function that prints trees as a string representation, which may help you debug your implementations.
# Note: the cell below only works if `utils.py` is in the same folder as the IPython notebook.


# Small demo tree: 3 with children 4 (left) and 5 (right).
node = Node(3, left=Node(4), right=Node(5))
print(node)


# Get the minimum of the binary search sub-tree spanned by `Node` root.



def get_min(root: "Node") -> int:
    """Get the minimum of the binary search tree spanned by `root`.

    In a binary search tree every smaller value lives in a left sub-tree, so
    the minimum is the value of the left-most node. The walk is iterative, so
    degenerate (list-like) trees cannot hit the recursion limit.

    Args:
        root: root `Node` of the tree.

    Returns:
        int: the minimum value of the binary search tree

    Raises:
        ValueError: if `root` is ``None`` (an empty tree has no minimum).
    """
    if root is None:
        raise ValueError("get_min() of an empty tree")
    node = root
    while node.left is not None:
        node = node.left
    return node.value


# Get the maximum of the binary search sub-tree spanned by `Node` root.



def get_max(root: "Node") -> int:
    """Get the maximum of the binary search tree spanned by `root`.

    In a binary search tree every larger value lives in a right sub-tree, so
    the maximum is the value of the right-most node. The walk is iterative, so
    degenerate (list-like) trees cannot hit the recursion limit.

    Args:
        root: root `Node` of the tree.

    Returns:
        int: the maximum value of the binary search tree

    Raises:
        ValueError: if `root` is ``None`` (an empty tree has no maximum).
    """
    if root is None:
        raise ValueError("get_max() of an empty tree")
    node = root
    while node.right is not None:
        node = node.right
    return node.value


# Search for the value in the binary sub-tree spanned by `Node` root.



def search(root: Optional["Node"], value: int) -> Optional["Node"]:
    """Returns the `Node` of `value` if found in tree.

    Assume:
        - the sub-tree of `root` spans a binary search tree

    Walks down from `root`, going left for smaller and right for larger
    values. The walk is iterative, so it costs O(height) time and O(1) extra
    memory.

    Args:
        root: root `Node` of the tree (``None`` is treated as an empty tree).
        value: Integer to `search` for in the tree

    Returns:
        Optional[Node]: the `Node` holding `value`, or ``None`` if `value` is
        not in the binary search tree of `root`
    """
    node = root
    while node is not None:
        if value == node.value:
            return node
        node = node.left if value < node.value else node.right
    return None


# ### Balance in binary search trees (AVL trees)
# 
# The constraint is generally applied recursively to every subtree. That is, the tree is only balanced if:
# 
#     The left and right subtrees' heights differ by at most one, AND
#     The left subtree is balanced, AND
#     The right subtree is balanced
# 
# According to this, the next tree is balanced:
# 
# <pre>
#       A
#     /   \
#    B     C  
#   /     / \  
#  D     E   F  
#       /  
#      G  
# </pre>
# 
# The next one is not balanced because the subtrees of C differ by 2 in their height:
# 
# <pre>
#      A
#    /   \
#   B     C   <-- difference = 2
#  /     /
# D     E  
#      /  
#     G  
# </pre>
# 
# That said, the specific constraint of the first point depends on the type of tree. The one listed above is the one typically used for AVL trees.\
# (Creds to this nice outline: https://stackoverflow.com/a/14712245)
# 
# Test whether the sub-tree spanned by `Node` `root` is balanced.
#  
# (Note: should you implement insertion and deletion, consider how the tree height will be needed in other functions, and think about implementing a helper function that contains the core logic of `is_balanced` and works primarily with tree heights.)



def _balanced_height(node: Optional["Node"]) -> int:
    """Return the height of `node`'s sub-tree, or -1 if any sub-tree in it is unbalanced.

    Heights: empty tree 0, single node 1. Returning -1 as a sentinel lets
    `is_balanced` check balance and compute heights in a single O(n) pass
    instead of recomputing heights at every node.
    """
    if node is None:
        return 0
    left_height = _balanced_height(node.left)
    if left_height < 0:
        return -1  # unbalanced below: no need to look at the right side
    right_height = _balanced_height(node.right)
    if right_height < 0 or abs(left_height - right_height) > 1:
        return -1
    return 1 + max(left_height, right_height)


def is_balanced(root: Optional["Node"]) -> bool:
    """Returns whether the tree spanned by the sub-tree of `root` is balanced.

    As per the lecture, a tree is balanced if the balance factor is valid
    (cf. slide 13): at every node the heights of the left and right sub-trees
    differ by at most one. An empty tree (``None``) is balanced. Only the
    shape is inspected; the node values are not checked.

    Args:
        root: the root node of the binary tree.

    Returns:
        bool: whether or not the tree is balanced.
    """
    return _balanced_height(root) >= 0


# Test whether the sub-tree spanned by `Node` `root` is a binary search tree.



def is_binary_search_tree(root: Optional["Node"]) -> bool:
    """Check if the binary tree is a BST (binary search tree).

    A tree is a BST exactly when its in-order traversal is strictly
    increasing. This also catches violations deeper than a direct child,
    e.g. a value in the left sub-tree of a right child that is smaller than
    the root. Duplicate values are rejected. The traversal uses an explicit
    stack, so degenerate trees cannot hit the recursion limit.

    Args:
        root: Root node of the binary tree (``None`` is an empty, valid BST).
    Returns:
        bool: `True` if the binary tree is a binary search tree, `False` otherwise.
    """
    stack = []
    previous = None  # last node visited in order
    node = root
    while stack or node is not None:
        # descend as far left as possible, remembering the path
        while node is not None:
            stack.append(node)
            node = node.left
        node = stack.pop()
        if previous is not None and node.value <= previous.value:
            return False
        previous = node
        node = node.right
    return True


# Implement insertion for `AVL-trees`. The implementation is analogous to the modular pseudo-code of the lecture.



def _avl_height(node: Optional["Node"]) -> int:
    """Return the height of the sub-tree spanned by `node` (empty: 0, single node: 1).

    `Node` stores no height, and ``NODE_ATTRS`` does not allow adding one, so
    this walks the whole sub-tree: O(size of the sub-tree).
    """
    if node is None:
        return 0
    return 1 + max(_avl_height(node.left), _avl_height(node.right))


def _avl_set_children(node: "Node", left: Optional["Node"], right: Optional["Node"]) -> None:
    """Attach `left`/`right` as children of `node`, keeping their `parent` pointers in sync."""
    node.left = left
    node.right = right
    for child in (left, right):
        if child is not None:
            child.parent = node


def _avl_rotate_right(node: "Node") -> None:
    """Rotate right around `node` in place, so the `node` object stays the sub-tree root.

    The values of `node` and its left child P are swapped and the sub-trees
    are re-hung. References held by callers, e.g. the tree root, therefore
    still point at the top of the sub-tree:

            node(x)                node(y)
             /   \\                 /    \\
          P(y)    C      ->       A     P(x)
          /  \\                         /  \\
         A    B                        B    C
    """
    pivot = node.left
    a, b, c = pivot.left, pivot.right, node.right
    node.value, pivot.value = pivot.value, node.value
    _avl_set_children(pivot, b, c)
    _avl_set_children(node, a, pivot)


def _avl_rotate_left(node: "Node") -> None:
    """Rotate left around `node` in place (mirror image of `_avl_rotate_right`).

            node(x)                  node(y)
            /    \\                   /    \\
           A     P(y)      ->      P(x)     C
                 /  \\              /  \\
                B    C            A    B
    """
    pivot = node.right
    a, b, c = node.left, pivot.left, pivot.right
    node.value, pivot.value = pivot.value, node.value
    _avl_set_children(pivot, a, b)
    _avl_set_children(node, pivot, c)


def _avl_rebalance(node: "Node") -> None:
    """Restore the AVL condition at `node`, assuming its children are AVL trees."""
    balance = _avl_height(node.left) - _avl_height(node.right)
    if balance > 1:
        # left-right case: rotate the left child first to get a left-left case
        if _avl_height(node.left.left) < _avl_height(node.left.right):
            _avl_rotate_left(node.left)
        _avl_rotate_right(node)
    elif balance < -1:
        # right-left case: rotate the right child first to get a right-right case
        if _avl_height(node.right.right) < _avl_height(node.right.left):
            _avl_rotate_right(node.right)
        _avl_rotate_left(node)


def _avl_insert(node: "Node", value: int) -> None:
    """Recursively insert `value` below `node`, rebalancing on the way back up."""
    if value < node.value:
        if node.left is None:
            node.left = Node(value, parent=node)
        else:
            _avl_insert(node.left, value)
    elif value > node.value:
        if node.right is None:
            node.right = Node(value, parent=node)
        else:
            _avl_insert(node.right, value)
    else:
        return  # value already present: the tree keeps unique values
    _avl_rebalance(node)


def insert(root: Optional["Node"], value: int) -> Optional["Node"]:
    """Insert `value` into `root`, maintaining AVL-tree properties.

    Rotations are done in place by swapping values, so the caller's `root`
    object remains the root of the tree even when the root is rotated.
    Inserting a value that is already present leaves the tree unchanged.
    Heights are recomputed on demand (nodes store none), so one insertion
    costs O(n) rather than the textbook O(log n).

    Note:
        - A `Node` is created from `value`.
        - An empty tree (``None``) cannot be grown in place; a new root
          `Node` is returned instead.

    Args:
        root: Root node of the binary tree; assumed to be a valid AVL tree.
        value: Integer to insert.

    Returns:
        Optional[Node]: the root of the tree, i.e. `root` itself, or a new
        `Node` if `root` was ``None``.
    """
    if root is None:
        return Node(value)
    _avl_insert(root, value)
    return root


# Implement deletion for `AVL-trees`. The implementation is analogous to the modular pseudo-code of the lecture.


def _avl_delete(node: Optional["Node"], value: int) -> Optional["Node"]:
    """Remove `value` from the sub-tree of `node`; return that sub-tree's root afterwards."""
    if node is None:
        return None  # value not present
    if value < node.value:
        new_left = _avl_delete(node.left, value)
        _avl_set_children(node, new_left, node.right)
    elif value > node.value:
        new_right = _avl_delete(node.right, value)
        _avl_set_children(node, node.left, new_right)
    elif node.left is None or node.right is None:
        # zero or one child: splice the node out; the caller re-links the child
        return node.left if node.left is not None else node.right
    else:
        # two children: take the in-order predecessor (maximum of the left
        # sub-tree), then remove that predecessor from the left sub-tree
        predecessor = node.left
        while predecessor.right is not None:
            predecessor = predecessor.right
        node.value = predecessor.value
        new_left = _avl_delete(node.left, predecessor.value)
        _avl_set_children(node, new_left, node.right)
    _avl_rebalance(node)
    return node


def delete(root: Optional["Node"], value: int) -> Optional["Node"]:
    """Delete value and maintain AVL-tree properties.

    Requires:
        - Identifying the node with `value` in the tree spanned by `Node` root
        - Implementing the appropriate rotations (cf. insert)

    A node with two children is replaced by its in-order predecessor. As in
    `insert`, rotations swap values, so the caller's `root` object stays the
    root. If the root itself is removed and has a single child, that child is
    pulled up into the `root` object. Deleting a value that is not present
    leaves the tree unchanged.

    Args:
        root: Root node of the binary tree; assumed to be a valid AVL tree.
        value: Integer to delete.

    Returns:
        Optional[Node]: `root`, or ``None`` if the tree is now empty. A
        single-node tree cannot be emptied in place, so when ``None`` is
        returned the caller's `root` object is left untouched and should be
        discarded.
    """
    if root is None:
        return None
    new_root = _avl_delete(root, value)
    if new_root is root:
        return root
    if new_root is None:
        return None  # the root was the only node
    # root was removed and had exactly one child: move that child into `root`
    root.value = new_root.value
    _avl_set_children(root, new_root.left, new_root.right)
    return root


# Finally, test your implementations with the below template.



class TestAVLTree(unittest.TestCase):
    """Unit tests for the BST helpers and the AVL `insert` / `delete` operations.

    Trees are built by hand through direct child assignment; `parent`
    pointers are not set on these hand-built nodes.
    """

    def test_get_min(self):
        node = Node(12)
        self.assertEqual(get_min(node), 12)
        node.left = Node(6)
        self.assertEqual(get_min(node), 6)
        node.right = Node(19)
        self.assertEqual(get_min(node), 6)
        node.left.left = Node(3)
        self.assertEqual(get_min(node), 3)
        node.left.right = Node(9)
        self.assertEqual(get_min(node), 3)
        node.right.left = Node(14)
        self.assertEqual(get_min(node), 3)
        node.right.right = Node(24)
        self.assertEqual(get_min(node), 3)
        node.left.left.left = Node(2)
        self.assertEqual(get_min(node), 2)
        node.left.left.right = Node(5)
        self.assertEqual(get_min(node), 2)
        node.left.right.left = Node(8)
        self.assertEqual(get_min(node), 2)
        node.left.right.right = Node(11)
        self.assertEqual(get_min(node), 2)
        node.right.left.left = Node(13)
        self.assertEqual(get_min(node), 2)
        node.right.left.right = Node(17)
        self.assertEqual(get_min(node), 2)
        node.right.right.left = Node(21)
        self.assertEqual(get_min(node), 2)
        node.right.right.right = Node(27)
        self.assertEqual(get_min(node), 2)
        self.assertEqual(get_min(node.right), 13)
        self.assertEqual(get_min(node.right.right), 21)

    def test_get_max(self):
        node = Node(12)
        self.assertEqual(get_max(node), 12)
        node.left = Node(6)
        self.assertEqual(get_max(node), 12)
        node.right = Node(19)
        self.assertEqual(get_max(node), 19)
        node.left.left = Node(3)
        self.assertEqual(get_max(node), 19)
        node.left.right = Node(9)
        self.assertEqual(get_max(node), 19)
        node.right.left = Node(14)
        self.assertEqual(get_max(node), 19)
        node.right.right = Node(24)
        self.assertEqual(get_max(node), 24)
        node.left.left.left = Node(2)
        self.assertEqual(get_max(node), 24)
        node.left.left.right = Node(5)
        self.assertEqual(get_max(node), 24)
        node.left.right.left = Node(8)
        self.assertEqual(get_max(node), 24)
        node.left.right.right = Node(11)
        self.assertEqual(get_max(node), 24)
        node.right.left.left = Node(13)
        self.assertEqual(get_max(node), 24)
        node.right.left.right = Node(17)
        self.assertEqual(get_max(node), 24)
        node.right.right.left = Node(21)
        self.assertEqual(get_max(node), 24)
        node.right.right.right = Node(27)
        self.assertEqual(get_max(node), 27)
        self.assertEqual(get_max(node.left), 11)

    def test_search(self):
        node = Node(12)
        node.left = Node(6)
        node.right = Node(19)
        node.left.left = Node(3)
        node.left.right = Node(9)
        node.right.right = Node(24)
        # `search` must return the node object itself, not just a node with an equal value
        self.assertIs(search(node, 12), node)
        self.assertIs(search(node, 9), node.left.right)
        self.assertIs(search(node, 24), node.right.right)
        self.assertIsNone(search(node, 7))
        self.assertIsNone(search(node, 100))

    def test_is_balanced(self):
        root = Node(1)
        self.assertEqual(is_balanced(root), True)
        root.left = Node(2)
        self.assertEqual(is_balanced(root), True)
        root.right = Node(3)
        self.assertEqual(is_balanced(root), True)
        root.left.left = Node(4)
        self.assertEqual(is_balanced(root), True)
        root.right.left = Node(5)
        self.assertEqual(is_balanced(root), True)
        root.right.left.left = Node(6)
        self.assertEqual(is_balanced(root), False)
        root.left.left.left = Node(7)
        self.assertEqual(is_balanced(root), False)

    def test_is_binary_search_tree(self):
        node = Node(12)
        node.left = Node(6)
        node.right = Node(19)
        node.left.left = Node(3)
        node.left.right = Node(9)
        node.right.left = Node(14)
        node.right.right = Node(24)
        node.left.left.left = Node(2)
        node.left.left.right = Node(5)
        node.left.right.left = Node(8)
        node.left.right.right = Node(11)
        node.right.left.left = Node(13)
        node.right.left.right = Node(17)
        node.right.right.left = Node(21)
        node.right.right.right = Node(27)
        self.assertEqual(True, is_binary_search_tree(node))
        node = Node(12)
        node.left = Node(6)
        node.right = Node(19)
        node.left.left = Node(3)
        node.left.right = Node(9)
        node.right.left = Node(14)
        node.right.right = Node(24)
        node.left.left.left = Node(82)
        node.left.left.right = Node(5)
        node.left.right.left = Node(8)
        node.left.right.right = Node(11)
        node.right.left.left = Node(13)
        node.right.left.right = Node(17)
        node.right.right.left = Node(21)
        node.right.right.right = Node(27)
        self.assertEqual(False, is_binary_search_tree(node))
        self.assertEqual(False, is_binary_search_tree(node.left))
        self.assertEqual(True, is_binary_search_tree(node.right))
        node = Node(12)
        node.left = Node(6)
        node.right = Node(19)
        node.left.left = Node(3)
        node.left.right = Node(9)
        node.right.left = Node(9999)
        node.right.right = Node(24)
        node.left.left.left = Node(2)
        node.left.left.right = Node(5)
        node.left.right.left = Node(8)
        node.left.right.right = Node(11)
        node.right.left.left = Node(13)
        node.right.left.right = Node(17)
        node.right.right.left = Node(21)
        node.right.right.right = Node(27)
        self.assertEqual(True, is_binary_search_tree(node.left))
        self.assertEqual(False, is_binary_search_tree(node.right))

    # As a cautionary note: the tests for insert and deletion are quite likely not fully exhaustive.
    # They hold on to the original root object, so rotations must keep that object at the top.
    def test_insert(self):
        node = Node(9)
        insert(node, 15)
        self.assertEqual(node.right.value, 15)
        insert(node, 20)
        self.assertEqual(node.value, 15)
        self.assertEqual(node.left.value, 9)
        self.assertEqual(node.right.value, 20)
        insert(node, 8)
        self.assertEqual(node.left.left.value, 8)
        insert(node, 7)
        self.assertEqual(node.left.value, 8)
        self.assertEqual(node.left.left.value, 7)
        self.assertEqual(node.left.right.value, 9)
        insert(node, 13)
        self.assertEqual(node.value, 9)
        self.assertEqual(node.right.value, 15)
        self.assertEqual(node.right.left.value, 13)
        self.assertEqual(node.right.right.value, 20)
        insert(node, 10)
        self.assertEqual(node.right.left.left.value, 10)

        node = Node(3)
        insert(node, 5)
        self.assertEqual(node.right.value, 5)
        insert(node, 10)
        self.assertEqual(node.value, 5)
        self.assertEqual(node.left.value, 3)
        self.assertEqual(node.right.value, 10)
        insert(node, 15)
        self.assertEqual(node.right.right.value, 15)
        insert(node, 23)
        self.assertEqual(node.right.value, 15)
        self.assertEqual(node.right.left.value, 10)
        self.assertEqual(node.right.right.value, 23)
        insert(node, 1)
        self.assertEqual(node.left.left.value, 1)
        insert(node, 4)
        self.assertEqual(node.left.right.value, 4)
        insert(node, 100)
        self.assertEqual(node.right.right.right.value, 100)

    def test_delete(self):
        node = Node(9)
        node.left = Node(8)
        node.left.left = Node(7)
        node.right = Node(15)
        node.right.left = Node(13)
        node.right.right = Node(20)
        node.right.left.left = Node(10)
        delete(node, 9)
        self.assertEqual(node.value, 13)
        self.assertEqual(node.left.value, 8)
        self.assertEqual(node.left.left.value, 7)
        self.assertEqual(node.left.right.value, 10)
        self.assertEqual(node.right.value, 15)
        self.assertEqual(node.right.right.value, 20)
        delete(node, 13)
        self.assertEqual(node.value, 10)
        self.assertEqual(node.left.value, 8)
        self.assertEqual(node.left.left.value, 7)
        self.assertEqual(node.right.value, 15)
        self.assertEqual(node.right.right.value, 20)

        node = Node(5)
        node.left = Node(3)
        node.left.left = Node(1)
        node.left.right = Node(4)
        node.right = Node(15)
        node.right.left = Node(10)
        node.right.right = Node(23)
        delete(node, 5)
        self.assertEqual(node.value, 4)
        self.assertEqual(node.left.value, 3)
        self.assertEqual(node.left.left.value, 1)
        self.assertEqual(node.right.value, 15)
        self.assertEqual(node.right.left.value, 10)
        self.assertEqual(node.right.right.value, 23)

        node = Node(5)
        node.left = Node(3)
        node.left.left = Node(1)
        node.left.right = Node(4)
        node.right = Node(15)
        node.right.left = Node(10)
        node.right.right = Node(23)
        delete(node, 15)
        self.assertEqual(node.value, 5)
        self.assertEqual(node.left.value, 3)
        self.assertEqual(node.left.left.value, 1)
        self.assertEqual(node.left.right.value, 4)
        self.assertEqual(node.right.value, 10)
        self.assertEqual(node.right.right.value, 23)
        delete(node, 5)
        self.assertEqual(node.value, 4)
        self.assertEqual(node.left.value, 3)
        self.assertEqual(node.left.left.value, 1)
        self.assertEqual(node.right.value, 10)
        self.assertEqual(node.right.right.value, 23)
        delete(node, 3)
        delete(node, 1)
        self.assertEqual(node.value, 10)
        self.assertEqual(node.left.value, 4)
        self.assertEqual(node.right.value, 23)


# Run the suite in-process: the dummy argv keeps unittest from parsing the
# notebook kernel's own arguments, and exit=False keeps the kernel alive.
unittest.main(
    argv=[""],
    exit=False,
    verbosity=2,
)
