/*---------------------------------

  BS_Tree.h

  (P) 1998 Laurentiu Cristofor

  interface and implementation
  for a generic binary search tree

-----------------------------------*/

#ifndef __BS_TREE_H_
#define __BS_TREE_H_

#include <vector>
#include <iterator>
#include <functional> // for less<T>
#include <algorithm>  // for binary_search()

// Forward declaration of the node type used by BS_Tree.
template<class T>
struct Tree_Node;

// BS_Tree: a generic (unbalanced) binary search tree.
//
//   T   - type of the stored values
//   Cmp - function object implementing a strict "less than" on T;
//         two values are treated as equivalent by the tree when
//         neither compares less than the other
//
// Iteration is provided through an internal sorted vector that is
// rebuilt lazily from the tree whenever the tree has changed since
// the last rebuild. Iterators therefore refer to that vector and
// should be re-obtained after any modification of the tree.
template<class T, class Cmp = std::less<T> >
class BS_Tree
{
public:
  // Iterator type returned by find(), begin() and end().
  typedef typename std::vector<T>::const_iterator const_iterator;

private:
  // Function object for comparisons (<)
  Cmp cmp;

  Tree_Node<T>* root;

  // this is modified by the following:
  // deleteBST(), insert(), remove()
  // and initialized to zero in the
  // constructor and the copy constructor
  int bst_size;

  // for iterators
  std::vector<T> node_vector;

  // if true this means that
  // node vector must be updated
  // because the contents of
  // the BST have changed
  bool vector_dirty;

  // Frees the subtree rooted at pnode.
  void deleteBST(Tree_Node<T>* pnode);

  // Inserts every value of the subtree rooted at pnode into this tree.
  void copyBST(Tree_Node<T>* pnode);

  // Appends the values of the subtree rooted at pnode to node_vector.
  void update_node_vector(Tree_Node<T>* pnode);

  // Rebuilds node_vector from the tree and clears vector_dirty.
  void synchronize_vector();

public:
  // Creates an empty tree.
  BS_Tree()
    : cmp(), root(0), bst_size(0), node_vector(), vector_dirty(true)
  {
  }

  // Creates a tree holding copies of all the values of bst.
  // The comparator is default-constructed, not copied from bst.
  BS_Tree(const BS_Tree& bst)
    : cmp(), root(0), bst_size(0), node_vector(), vector_dirty(true)
  {
    copyBST(bst.root);
  }

  // Replaces the contents of this tree with copies of the values
  // of bst. The copy is built first and then swapped in, so this
  // tree never walks nodes that have already been freed, and
  // self-assignment is a no-op.
  BS_Tree& operator= (const BS_Tree& bst)
  {
    if (this != &bst)
      {
        BS_Tree replacement(bst);
        swap(replacement);
        // the previous nodes are released when
        // replacement goes out of scope
      }

    return *this;
  }

  ~BS_Tree()
  {
    deleteBST(root);
  }

  // Number of values stored in the tree.
  int size() const
  {
    return bst_size;
  }

  // True when the tree holds no values.
  bool empty() const
  {
    return bst_size == 0;
  }

  // Exchanges the contents of a BS_Tree
  // with the contents of another BS_Tree.
  // Only the node structure and the size are exchanged;
  // both cached vectors are flagged for rebuilding.
  // NOTE(review): the comparators are not exchanged, which
  // only matters for stateful Cmp types -- confirm intent.
  static void swap(BS_Tree& bst1, BS_Tree& bst2)
  {
    Tree_Node<T>* rtemp = bst1.root;
    int           stemp = bst1.bst_size;

    bst1.root     = bst2.root;
    bst2.root     = rtemp;

    bst1.bst_size = bst2.bst_size;
    bst2.bst_size = stemp;

    bst1.vector_dirty = bst2.vector_dirty = true;
  }

  // Exchanges the contents of this BS_Tree
  // with the contents of another BS_Tree
  // (same semantics as the static overload).
  void swap(BS_Tree& bst)
  {
    Tree_Node<T>* rtemp = root;
    int           stemp = bst_size;

    root         = bst.root;
    bst.root     = rtemp;

    bst_size     = bst.bst_size;
    bst.bst_size = stemp;

    vector_dirty = bst.vector_dirty = true;
  }

  // inserts a value into the BST
  void insert(const T& value);

  // looks for and removes value from the BST;
  // returns true if it has found value,
  // false if value was not found
  bool remove(const T& value);

  // returns true if value exists in a tree
  bool contains(const T& value);

  // finds the first occurrence of value
  // in the BST and returns an iterator to it;
  // if the value is not found, an iterator to
  // past the last value is returned (same as
  // the one returned by end())
  const_iterator find(const T& value);

  // Iterator to the first value of the in-order (sorted)
  // sequence; rebuilds the cached vector first if needed.
  const_iterator begin()
  {
    if (vector_dirty)
      synchronize_vector();

    return node_vector.begin();
  }

  // Iterator past the last value of the in-order (sorted)
  // sequence; rebuilds the cached vector first if needed.
  const_iterator end()
  {
    if (vector_dirty)
      synchronize_vector();

    return node_vector.end();
  }
};

////////////////////////////////////////////////////

/*------------------------------------------------
  Recursively deletes every node of the subtree
  rooted at pnode (children first, then the node
  itself) and decrements bst_size once per node.

  If the node being freed is the tree's root, the
  root pointer is reset to 0 so that the tree is
  left in a valid empty state instead of pointing
  at freed memory.
--------------------------------------------------*/
template<class T, class Cmp>
void BS_Tree<T, Cmp>::deleteBST(Tree_Node<T>* pnode)
{
  if (pnode == 0)
    return;

  deleteBST(pnode->left);
  deleteBST(pnode->right);

  // compare before the delete: the pointer
  // value must not be used once it is freed
  if (pnode == root)
    root = 0;

  delete pnode;
  --bst_size;
}

/*---------------------------------------------
  Recursively copies the subtree rooted at
  pnode into this tree by calling insert() on
  each value. The node's own value is inserted
  before the values of its left and right
  subtrees (pre-order traversal).
-----------------------------------------------*/
template<class T, class Cmp>
void BS_Tree<T, Cmp>::copyBST(Tree_Node<T>* pnode)
{
  if (pnode == 0)
    return;

  insert(pnode->value);
  copyBST(pnode->left);
  copyBST(pnode->right);
}

/*------------------------------------------------
  Recursively appends the values of the subtree
  rooted at pnode to node_vector: left subtree
  first, then the node's own value, then the
  right subtree (in-order traversal).
--------------------------------------------------*/
template<class T, class Cmp>
void BS_Tree<T, Cmp>::update_node_vector(Tree_Node<T>* pnode)
{
  if (pnode == 0)
    return;

  update_node_vector(pnode->left);
  node_vector.push_back(pnode->value);
  update_node_vector(pnode->right);
}

/*------------------------------
  Rebuilds node_vector from the
  current contents of the tree
  and clears the vector_dirty
  flag. Any iterator previously
  obtained from node_vector must
  be considered invalid after
  this call.
--------------------------------*/
template<class T, class Cmp>
void BS_Tree<T, Cmp>::synchronize_vector()
{
  // drop the stale snapshot
  node_vector.clear();

  // one allocation up front instead of
  // repeated growth during the traversal
  node_vector.reserve(bst_size);

  update_node_vector(root);
  vector_dirty = false;
}

/*--------------------------------------
  Returns true if the tree holds a value
  equivalent to the given one under cmp.

  The lookup is done with the STL
  binary_search() on node_vector, which
  is brought up to date first when the
  tree has changed.
----------------------------------------*/
template<class T, class Cmp>
bool BS_Tree<T, Cmp>::contains(const T& value)
{
  if (vector_dirty)
    synchronize_vector();

  return std::binary_search(node_vector.begin(), node_vector.end(),
                            value, cmp);
}

/*-----------------------------
  use STL find() to do the job
-------------------------------*/
template<class T, class Cmp = less<T> > 
BS_Tree<T, Cmp>::const_iterator BS_Tree<T, Cmp>::find(const T& value)
{
  if (vector_dirty)
    synchronize_vector();
    
  return std::find(node_vector.begin(), node_vector.end(), value);
}

/*-------------------------------------------------------
  Inserts a copy of value as a new leaf of the tree.

  Starting at the root, the search descends to the left
  when value compares less than the current node's value
  and to the right otherwise, so values equivalent to an
  existing one are kept and end up in its right subtree.

  The size counter and the dirty flag are only updated
  once the new node has been allocated, so a failed
  allocation leaves the tree unchanged.
---------------------------------------------------------*/
template<class T, class Cmp>
void BS_Tree<T, Cmp>::insert(const T& value)
{
  // link is the address of the pointer through which
  // the current position is reached: either root or
  // the left/right member of the parent node
  Tree_Node<T>** link = &root;

  while (*link != 0)
    {
      Tree_Node<T>* node = *link;

      if (cmp(value, node->value))
        link = &node->left;
      else
        link = &node->right;
    }

  // *link is the empty slot where the new leaf belongs
  *link = new Tree_Node<T>(value);

  ++bst_size;
  vector_dirty = true;
}

/*---------------------------------------------------
  Removes one node holding a value equivalent to the
  given one (the first such node met on the search
  path from the root). Returns true if a node was
  removed and false if no such value was found.

  The algorithm used for removing values from
  a binary search tree is the one described by
  Robert Sedgewick in his book: "Algorithms in C++":

   - no right child: the left child takes the
     place of the removed node;
   - right child without a left child: the right
     child takes its place and adopts the removed
     node's left subtree;
   - otherwise: the leftmost node of the right
     subtree is unlinked and moved into the place
     of the removed node.

  The position of the removed node is tracked through
  the address of the pointer that refers to it, so the
  root needs no special treatment.
-----------------------------------------------------*/
template<class T, class Cmp>
bool BS_Tree<T, Cmp>::remove(const T& value)
{
  // address of the pointer that refers to the node
  // under inspection: root, or a parent's left/right
  Tree_Node<T>** link = &root;

  // first look for node containing value;
  // we implement (a != b) as (a < b || b < a)
  while (*link != 0)
    {
      Tree_Node<T>* node = *link;

      if (cmp(value, node->value))
        link = &node->left;
      else if (cmp(node->value, value))
        link = &node->right;
      else
        break;
    }

  // node not found
  if (*link == 0)
    return false;

  // store node that has to be deleted
  Tree_Node<T>* condemned = *link;

  if (condemned->right == 0)
    {
      *link = condemned->left;
    }
  else if (condemned->right->left == 0)
    {
      Tree_Node<T>* heir = condemned->right;

      heir->left = condemned->left;
      *link      = heir;
    }
  else
    {
      // the hard case: find the parent of the leftmost
      // node in the right subtree of the condemned node
      Tree_Node<T>* heir_parent = condemned->right;

      while (heir_parent->left->left != 0)
        heir_parent = heir_parent->left;

      Tree_Node<T>* heir = heir_parent->left;

      // unlink the leftmost node; its right subtree
      // (it has no left one) moves up into its slot
      heir_parent->left = heir->right;

      // and put it in the place of the condemned node
      heir->left  = condemned->left;
      heir->right = condemned->right;
      *link       = heir;
    }

  // final act
  delete condemned;

  --bst_size;
  vector_dirty = true;

  return true;
}

// Node of the binary search tree: a stored value
// plus pointers to the left and right children
// (0 when the corresponding child is absent).
template<class T>
struct Tree_Node
{
  Tree_Node *left;
  Tree_Node *right;

  // what we store in node
  T value;

  // Creates a leaf node holding a copy of val.
  // The value is copy-constructed in the initializer
  // list, so T does not need a default constructor.
  Tree_Node(const T& val)
    : left(0), right(0), value(val)
  {
  }
};

#endif// __BS_TREE_H_

