#include <stdio.h>
#include <stdlib.h>
#include <astar.h>

// Debug helper: print a heading that names the list, then hand every node
// of the "next"-linked list starting at `list` to `printfunc`, in order.
// An empty list (list == NULL) prints only the heading.
void printNodeList(Node * list, char *name, void (*printfunc) (Node *));
void printNodeList(Node * list, char *name, void (*printfunc) (Node *))
{
  printf("Printing list %s\n", name);

  for (Node *node = list; node != NULL; node = node->next)
    printfunc(node);
}

/*
 * Internal helpers for AStarSearch.  OPEN and CLOSED are doubly linked
 * lists threaded through each Node's next/prev fields; the parent field
 * links a node to the node it was generated from.
 */

// Value stored in Node.depth to flag a node whose subtree has been
// superseded by a cheaper path and is about to be freed.
enum { ASTAR_DEAD_DEPTH = -9999 };

// Splice `node` out of the doubly linked list headed by *head, updating
// the head if `node` was first.  The node's own next/prev are left as-is.
static void astarUnlink(Node **head, Node *node)
{
  Node *before = (Node *) node->prev;
  Node *after = (Node *) node->next;

  if (before != NULL)
    before->next = (void *) after;
  if (after != NULL)
    after->prev = (void *) before;
  if (*head == node)
    *head = after;
}

// Release every node of a "next"-linked list through freeNode.
static void astarFreeList(Node *list, void (*freeNode) (Node *))
{
  while (list != NULL) {
    Node *following = (Node *) list->next;
    freeNode(list);
    list = following;
  }
}

// Return the first node on `list` for which nodeEqual(listNode, node)
// is non-zero, or NULL if there is none.
static Node *astarFindEqual(Node *list, Node *node,
                            int (*nodeEqual) (Node *, Node *))
{
  for (Node *p = list; p != NULL; p = (Node *) p->next) {
    if (nodeEqual(p, node))
      return p;
  }
  return NULL;
}

// Insert `node` into the f-sorted list headed by *head, placing it before
// the first node whose f is >= node->f (so ties favour the newcomer).
static void astarInsertSorted(Node **head, Node *node)
{
  Node *p = *head;
  Node *last = NULL;

  while (p != NULL && p->f < node->f) {
    last = p;
    p = (Node *) p->next;
  }

  node->next = (void *) p;
  node->prev = (void *) last;
  if (p != NULL)
    p->prev = (void *) node;
  if (last != NULL)
    last->next = (void *) node;
  else
    *head = node;
}

// If the topmost ancestor of `node` is flagged dead, `node` belongs to a
// discarded subtree: flag every node on the way up as dead and cut its
// parent link, so later walks stop at the first flagged node.
static void astarMarkIfDead(Node *node)
{
  Node *top = node;

  while (top->parent != NULL)
    top = (Node *) top->parent;
  if (top->depth != ASTAR_DEAD_DEPTH)
    return;

  while (node->parent != NULL) {
    Node *up = (Node *) node->parent;
    node->depth = ASTAR_DEAD_DEPTH;
    node->parent = NULL;
    node = up;
  }
}

// Unlink and free every node on the list headed by *head whose depth
// carries the dead flag.
static void astarPurgeDead(Node **head, void (*freeNode) (Node *))
{
  Node *r = *head;

  while (r != NULL) {
    Node *following = (Node *) r->next;
    if (r->depth == ASTAR_DEAD_DEPTH) {
      astarUnlink(head, r);
      freeNode(r);
    }
    r = following;
  }
}

// A cheaper path to the state of the CLOSED node `stale` has been found:
// remove `stale` and free it together with every node descended from it.
// Those descendants may be on OPEN, on CLOSED, or be `current` itself
// (which is on neither list while it is being expanded).
static void astarDiscardSubtree(Node *stale, Node *current,
                                Node **openList, Node **closedList,
                                Node **childList, void (*freeNode) (Node *))
{
  Node *r;

  astarUnlink(closedList, stale);
  stale->prev = NULL;
  stale->parent = NULL;
  stale->depth = ASTAR_DEAD_DEPTH;

  for (r = *closedList; r != NULL; r = (Node *) r->next)
    astarMarkIfDead(r);
  for (r = *openList; r != NULL; r = (Node *) r->next)
    astarMarkIfDead(r);
  // current is not on either list yet; without this its parent pointer
  // could be left pointing at a freed ancestor.
  astarMarkIfDead(current);

  // The not-yet-processed children are about to be attached to current,
  // so they are doomed exactly when current is.  (Checking each child's
  // parent field here is unsafe: it is only assigned when the child is
  // processed, and children() may leave it NULL.)
  if (current->depth == ASTAR_DEAD_DEPTH) {
    astarFreeList(*childList, freeNode);
    *childList = NULL;
  }

  astarPurgeDead(closedList, freeNode);
  astarPurgeDead(openList, freeNode);
  freeNode(stale);
}

// Detach the chain goal -> parent -> ... -> root from CLOSED (every
// ancestor of an expanded node has itself been expanded) and relink it,
// in root-to-goal order, through next and prev.  Returns the first node.
static Node *astarExtractPath(Node *goal, Node **closedList)
{
  Node *path = goal;
  Node *p = (Node *) goal->parent;

  goal->next = NULL;
  goal->prev = NULL;
  while (p != NULL) {
    Node *up = (Node *) p->parent;

    astarUnlink(closedList, p);
    p->next = (void *) path;
    p->prev = NULL;
    path->prev = (void *) p;
    path = p;
    p = up;
  }
  return path;
}

/*
 * Generalized A* search.  The caller supplies the root node and the
 * problem-specific callbacks; the Node structure is defined in astar.h.
 * Problem-specific data can be attached to a node by storing a pointer
 * to your own structure in its nodeInfo field, then read back from any
 * Node* handed to a callback.
 *
 * root:      node from which to begin the search.  Its next, prev and
 *            parent links are reset here.  The caller is expected to have
 *            set its depth, id and f: children inherit depth + 1, id is
 *            compared against MAXNODES, and f is compared against
 *            duplicates.  This function never computes them for the root.
 * gcalc:     returns the g value of a node (cost from the initial state).
 * hcalc:     returns the h value of a node (estimated cost to the goal).
 * goalNode:  returns non-zero if the node is a goal, 0 otherwise.
 * children:  returns a "next"-linked list of the node's successors, or
 *            NULL if there are none.  Their parent, state, depth, id, g,
 *            h, f, next and prev fields are (re)set by this function.
 * freeNode:  releases the memory associated with a node.
 * nodeEqual: returns non-zero if two nodes represent the same state.
 * printNode: accepted for interface compatibility; not called.
 * graphNode: called with each node as it is expanded and the character
 *            'x' (e.g. to mark a grid cell); may be NULL.
 *
 * OPEN is kept sorted by ascending f; a new node goes before existing
 * nodes with an equal f.  A successor that duplicates a state already on
 * OPEN or CLOSED is dropped unless its f is strictly lower.  If it beats
 * a CLOSED node, that node and everything descended from it are freed.
 *
 * Returns the first node of the path from root to the goal, linked
 * forward through next and backward through prev.  Those nodes now
 * belong to the caller; every other node the search created is released
 * with freeNode.  Returns NULL, after printing a message, if root is
 * NULL, if OPEN empties without reaching a goal, or if a node whose id
 * exceeds MAXNODES (astar.h) is expanded.  The original notes MAXNODES
 * defaults to 10000.
 */
Node *AStarSearch(Node *root,
                  double (*gcalc) (Node *),
                  double (*hcalc) (Node *),
                  int (*goalNode) (Node *),
                  Node * (*children) (Node *),
                  void (*freeNode) (Node *),
                  int (*nodeEqual) (Node *, Node *),
                  void (*printNode) (Node *),
                  void (*graphNode) (Node *, char))
{
  Node *openList;
  Node *closedList = NULL;
  int gblID = 1;      // next id to hand out == nodes created so far, root included
  int gblExpand = 0;  // nodes taken off OPEN so far

  (void) printNode;   // part of the public signature, unused here

  if (root == NULL) {
    printf("No path to goal\n");
    return NULL;
  }

  root->next = NULL;
  root->prev = NULL;
  root->parent = NULL;  // path extraction and dead-subtree walks stop at NULL
  openList = root;

  while (openList != NULL) {
    // OPEN is kept sorted by f, so its head is the best candidate.
    Node *current = openList;
    Node *childList;

    openList = (Node *) current->next;
    if (openList != NULL)
      openList->prev = NULL;

    gblExpand++;
    if (graphNode != NULL)
      graphNode(current, 'x');

    if (goalNode(current)) {
      Node *path;

      printf("Goal state reached with %d nodes created and %d nodes expanded\n",
             gblID, gblExpand);

      path = astarExtractPath(current, &closedList);
      astarFreeList(openList, freeNode);
      astarFreeList(closedList, freeNode);
      return path;
    }

    // Expand current and merge its successors into OPEN by f value.
    childList = children(current);
    while (childList != NULL) {
      Node *curChild = childList;
      Node *match;

      childList = (Node *) curChild->next;

      curChild->parent = (void *) current;
      curChild->state = OPEN;
      curChild->depth = current->depth + 1;
      curChild->id = gblID++;
      curChild->next = NULL;
      curChild->prev = NULL;

      // f = g + h
      curChild->g = gcalc(curChild);
      curChild->h = hcalc(curChild);
      curChild->f = curChild->g + curChild->h;

      // Already expanded this state?  Keep whichever path is cheaper.
      match = astarFindEqual(closedList, curChild, nodeEqual);
      if (match != NULL) {
        if (match->f <= curChild->f) {
          freeNode(curChild);
          continue;
        }
        astarDiscardSubtree(match, current, &openList, &closedList,
                            &childList, freeNode);
      }

      // Already waiting on OPEN?  An OPEN node has no descendants yet, so
      // replacing it only requires removing that one node.
      match = astarFindEqual(openList, curChild, nodeEqual);
      if (match != NULL) {
        if (match->f <= curChild->f) {
          freeNode(curChild);
          continue;
        }
        astarUnlink(&openList, match);
        freeNode(match);
      }

      astarInsertSorted(&openList, curChild);
    }

    // current has been expanded: push it onto CLOSED.
    current->next = (void *) closedList;
    if (closedList != NULL)
      closedList->prev = (void *) current;
    closedList = current;
    current->prev = NULL;
    current->state = CLOSED;

    // Give up once a node created after the first MAXNODES has been expanded.
    if (current->id > MAXNODES) {
      printf("Expanded more than the maximum allowable nodes. Terminating\n");
      astarFreeList(openList, freeNode);
      astarFreeList(closedList, freeNode);
      return NULL;
    }
  }

  // OPEN ran dry without reaching a goal; only CLOSED still holds nodes.
  printf("No path to goal\n");
  astarFreeList(closedList, freeNode);
  return NULL;
}
