/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */

package org.joone.engine;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.joone.structure.NodesAndWeights;

/**
 * A RTRL implementation.
 *
 * Based mostly on http://www.willamette.edu/~gorr/classes/cs449 and a few others.
 *
 * A partial RTRL implementation. Network weights are optimised using an offline RTRL
 * implementation. The initial states of context nodes are not optimised, but could
 * easily be added. For now, initial states are simply assumed to be what they are
 * set to be in the context layer itself.
 *
 * RTRL does not rely on a backpropagated error and this can and should be turned
 * off, in order to speed things up. Functionality for this is included in the
 * Monitor class and turned off whenever the setMonitor message is called.
 *
 * In order to speed things up, this includes an experimental lineseek approach where
 * firstly the gradient is calculated using the offline RTRL algorithm. Then a step is
 * taken along the gradient for as long as the sum of squared errors decreases in a
 * typical lineseek type fashion. As soon as a step results in an increased sum of
 * squared errors, a step back is taken, typically smaller than the step forward,
 * and the gradient is once again updated. The stepping up and down is scaled in the
 * spirit of the RPROP algorithm, so that the learning rate is adjusted after each cycle.
 *
 * Weights can also be randomised in the spirit of simulated annealing at the end of
 * each cycle. As with the above lineseek approach, see the constructor for more details.
 * These two features were really added to try and speed up convergence - if at all! Their
 * practical benefit remains highly suspect at best.
 *
 * This class has a main method which also serves as a demo of the RTRL. Please refer to
 * that. A suitable net can easily be created using the GUI and then trained using the
 * main method, with a few alterations to the code based on the number of patterns for
 * example, which, amongst others, is currently hard coded.
 *
 * The main method also shows how to save and restore a network trained via RTRL. While
 * this class does implement the Serializable interface, it is highly suspect and not
 * meant to be serialised together with the network.
 *
 * This implementation is highly academic at present. Any good examples where this can
 * be applied will be much appreciated. I am still looking for them. The initial conditions
 * as well as the learning rate seem to have such a high impact on the convergence of
 * this as to make it of almost no practical use it seems. Also, strangely, it often
 * seems that a higher rather than lower learning rate is better for convergence.
 *
 * Support for multiprocessors has now been added.
 * 
 * @author mg, http://www.ferra4models.com
 */
public class RTRL
{
	/** The network whose weights are being trained */
	protected NodesAndWeights nodesAndWeights;

	/** The p matrix, p [ k ][ ij ] is node k's (in U) derivative with respect to weight ij */
	protected double p [][];

	/** Scratch matrix used while updating p; swapped with p after each update step */
	protected double updateP [][];

	/** Number of patterns presented so far via update ( double [] ) */
	protected int patternCount = 0;

	/** Number of processors to use, 1 or less on a uniprocessor */
	protected int processorCount = 0;

	/** List of list of nodes that will be updated by each processor, null when single-threaded */
	protected List < List < NodesAndWeights.Node > > nodesList = null;

	/** List of list of weights that will be updated by each processor, null when single-threaded */
	protected List < List < NodesAndWeights.Weight > > weightsList = null;

	/**
	 * Create a new instance of RTRL
	 * @param nodesAndWeights the network to be optimised's structure
	 */
	public RTRL ( NodesAndWeights nodesAndWeights )
	{
		this.nodesAndWeights = nodesAndWeights;
		init ();
	}

	/** Initialise the p and updateP matrices (Java zero-fills them, which is what we want) */
	protected void init ()
	{
		p = new double [ nodesAndWeights.U.size () ][ nodesAndWeights.weights.size () ];
		updateP = new double [ nodesAndWeights.U.size () ][ nodesAndWeights.weights.size () ];
	}

	/**
	 * Set the number of processors to use.
	 *
	 * A count above 1 partitions the nodes and weights round-robin between the
	 * processors; a count of 1 or less reverts to single-threaded operation.
	 *
	 * @param processorCount number of worker threads to use per update
	 */
	public void setProcessorCount ( int processorCount )
	{
		this.processorCount = processorCount;

		// Allocate nodes and weights to each processor
		nodesList = null;
		weightsList = null;
		if ( processorCount > 1 )
		{
			// Allocate the list of lists
			nodesList = new ArrayList < List < NodesAndWeights.Node > > ();
			weightsList = new ArrayList < List < NodesAndWeights.Weight > > ();

			// Allocate the lists
			for ( int i = 0; i < processorCount; i ++ )
			{
				nodesList.add ( new ArrayList < NodesAndWeights.Node > () );
				weightsList.add ( new ArrayList < NodesAndWeights.Weight > () );
			}

			// Deal the nodes round-robin so each processor gets a similar load
			for ( NodesAndWeights.Node node : nodesAndWeights.U )
			{
				nodesList.get ( node.K % processorCount ).add ( node );
			}

			// Deal the weights round-robin likewise
			for ( NodesAndWeights.Weight weight : nodesAndWeights.weights )
			{
				weightsList.get ( weight.w % processorCount ).add ( weight );
			}
		}
	}

	/** @return the number of processors in use */
	public int getProcessorCount ()
	{
		return processorCount;
	}

	/**
	 * Update the p matrix for the given list of nodes.
	 *
	 * Results are written into updateP (not p) so that concurrent workers only
	 * ever read the previous pattern's values from p.
	 *
	 * @param nodes the subset of U this call is responsible for
	 */
	protected void updateP ( List < NodesAndWeights.Node > nodes )
	{
		// All k in the set of nodes
		for ( NodesAndWeights.Node k : nodes )
		{
			double df = k.getDerivative ();

			// All applicable i and j
			for ( NodesAndWeights.Weight ij : nodesAndWeights.weights )
			{
				double sum = 0;

				// Recurrent term: incoming weights I of k, I in U
				for ( NodesAndWeights.Weight I : k.I )
				{
					sum += I.getWeight () * p [ I.J ][ ij.w ];
				}

				// Kronecker delta portion - only when weight ij feeds node k
				if ( ij.I == k.K )
				{
					sum += nodesAndWeights.Z.get ( ij.J ).getValue ();
				}

				// Apply derivative and store back into the updated p
				updateP [ k.K ][ ij.w ] = sum * df;
			}
		}
	}

	/**
	 * Update the p matrix - called after a pattern has been presented to the
	 * network. Uses one worker thread per processor when a processor count
	 * above 1 has been configured via setProcessorCount.
	 */
	protected void updateP ()
	{
		if ( nodesList == null )
		{
			// Single update step on single processor for all nodes in U
			updateP ( nodesAndWeights.U );
		}
		else
		{
			// Each processor gets a list to work on
			List < Thread > threadList = new ArrayList < Thread > ();
			for ( final List < NodesAndWeights.Node > nodeList : this.nodesList )
			{
				Thread thread = new Thread ( new Runnable ()
				{
					public void run ()
					{
						updateP ( nodeList );
					}
				} );
				threadList.add ( thread );
				thread.start ();
			}

			// Wait until all the threads complete
			joinAll ( threadList );
		}

		// Now we swap p and updateP around so the fresh values become current
		double t [][] = p;
		p = updateP;
		updateP = t;
	}

	/**
	 * Update the given weights' deltas.
	 *
	 * This is called once a pattern has been presented to the network.
	 * 
	 * @param error most recently seen error pattern, indexed by node
	 * @param weights the subset of weights this call is responsible for
	 */
	protected void updateDeltas ( double error [], List < NodesAndWeights.Weight > weights )
	{
		// Update each weight's delta
		for ( NodesAndWeights.Weight ij : weights )
		{
			double delta = 0;

			// Accumulate error * sensitivity over the target nodes T
			for ( NodesAndWeights.Node k : nodesAndWeights.T )
			{
				delta += error [ k.K ] * p [ k.K ][ ij.w ];
			}

			ij.addDelta ( delta );
		}
	}

	/**
	 * Update the weights' deltas.
	 *
	 * This is called once a pattern has been presented to the network.
	 * 
	 * @param error most recently seen error pattern
	 */
	protected void updateDeltas ( final double error [] )
	{
		// Only one call on a uniprocessor
		if ( weightsList == null )
		{
			updateDeltas ( error, nodesAndWeights.weights );
		}
		else
		{
			// Each processor gets a list to work on
			List < Thread > threadList = new ArrayList < Thread > ();
			for ( final List < NodesAndWeights.Weight > weightList : weightsList )
			{
				Thread thread = new Thread ( new Runnable ()
				{
					public void run ()
					{
						updateDeltas ( error, weightList );
					}
				} );
				threadList.add ( thread );
				thread.start ();
			}

			// Wait until all the threads complete
			joinAll ( threadList );
		}
	}

	/**
	 * Wait for all the given worker threads to complete.
	 *
	 * If this thread is interrupted while waiting, the interrupt status is
	 * restored so callers further up the stack can still observe it.
	 *
	 * @param threadList the threads to join
	 */
	protected void joinAll ( List < Thread > threadList )
	{
		for ( Thread thread : threadList )
		{
			try
			{
				thread.join ();
			}
			catch ( InterruptedException ex )
			{
				System.err.println ( "Interrupted" );
				ex.printStackTrace ();

				// Preserve the interrupt status - swallowing it would hide the
				// interrupt from whoever requested it
				Thread.currentThread ().interrupt ();
			}
		}
	}

	/** Reset the p matrix in preparation for the next cycle - called at the end of a cycle */
	protected void resetP ()
	{
		// k in U
		for ( NodesAndWeights.Node k : nodesAndWeights.U )
		{
			// All ij in W
			Arrays.fill (  p [ k.K ], 0 );
		}
	}

	/**
	 * Update RTRL
	 * 
	 * Call this with the most recent error pattern as soon as one
	 * becomes available.
	 * 
	 * @param error the most recently seen error pattern
	 */
	public void update ( double error [] )
	{
		updateP ();
		updateDeltas ( error );
		patternCount ++;
	}

	/**
	 * Update the weights
	 * 
	 * Call this once a full set of patterns were presented to the network
	 * to update the weights
	 *
	 * @param learningRate step size applied to each weight's accumulated delta
	 */
	public void updateCycle ( double learningRate )
	{
		// Update all weights and then reset the deltas
		for ( NodesAndWeights.Weight weight : nodesAndWeights.weights )
		{
			weight.addWeight ( learningRate * weight.getDelta () );
			weight.setDelta ( 0 );
		}

		// Reset the p matrix
		resetP ();
	}

	/**
	 * Helper to print out the p matrix
	 * @param out stream to which to dump the matrix
	 */
	public void printP ( PrintStream out )
	{
		out.println ( "p matrix" );
		for ( NodesAndWeights.Weight weight : nodesAndWeights.weights )
		{
			out.println ( "\tColumn " + ( weight.w + 1 ) + " refers to connection to " + weight.I + " from node " + weight.J );
		}
		for ( int k = 0; k < p.length; k ++ )
		{
			out.println ( "\tNode " + k + " " + Arrays.toString ( p [ k ] ) );
		}
	}
}