/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package org.joone.engine;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import java.util.Date;
import org.joone.engine.learning.TeachingSynapse;
import org.joone.net.NeuralNet;
import org.joone.structure.NodesAndWeights;
import org.joone.util.MonitorPlugin;

/**
 * A plugin listener that applies the RTRL (real time recurrent learning)
 * algorithm to a network.  On start it attaches a synapse to the output
 * layer to collect error patterns, feeds each one to an {@link RTRL}
 * instance, and lets the RTRL apply a weight update once per cycle.
 *
 * @author mg
 */
public class RTRLLearnerPlugin extends MonitorPlugin
{
	/** Version id (uppercase L suffix so it cannot be misread as 11) */
	private static final long serialVersionUID = 1L;

	/** The RTRL that will do the training */
	protected transient RTRL rtrl;

	/** Network cycle counter, incremented once per completed cycle */
	protected transient int cycleCount = 0;

	/** The most recently seen error, one entry per output node */
	protected transient double lastError[];

	/** The synapse we will use to catch the errors */
	protected transient Synapse errorPatternListener;

	/** The global error observed on the previous cycle */
	protected transient double prevError;

	/** The internal learning rate handed to the RTRL every cycle */
	protected transient double learningRate;

	/** The minimum allowable learning rate when adapting */
	protected transient double minLearningRate = 1E-15;

	/** True if we are allowed to adapt the internal learning rate */
	protected transient boolean adaptLearningRate = false;

	/** Number of processors to use, 1 or less on uniprocessors */
	protected int processorCount = 0;

	/**
	 * The constructor.
	 *
	 * @param adaptLearningRate true to let the plugin grow and shrink its
	 *                          internal learning rate with the error trend
	 * @param processorCount    number of processors the RTRL may use
	 */
	public RTRLLearnerPlugin ( boolean adaptLearningRate, int processorCount )
	{
		super ();
		setRate ( 1 );
		this.adaptLearningRate = adaptLearningRate;
		this.processorCount = processorCount;
	}

	/**
	 * Attach a synapse to the output layer to calculate the error pattern and
	 * update the RTRL.
	 */
	protected void attachErrorPatternListener ()
	{
		// An unbuffered memory output synapse can also be used
		// Not sure, but hope a direct synapse will speed up
		// performance
		errorPatternListener = new DirectSynapse ()
		{
			/** Use to calculate the latest error */
			public void fwdPut ( Pattern output )
			{
				// Get the desired pattern from the teacher
				if (  ! output.isMarkedAsStoppingPattern () )
				{
					boolean ready = false;

					// We either get the error from the teacher or calculate it ourselves
					if ( getNeuralNet ().getTeacher () instanceof TeachingSynapse )
					{
						if ( ( ( TeachingSynapse ) getNeuralNet ().getTeacher () ).getTheTeacherSynapse ().isLastErrorPatternReady () )
						{
							lastError = ( ( TeachingSynapse ) getNeuralNet ().getTeacher () ).getTheTeacherSynapse ().getLastErrorPattern ();
							ready = true;
						}
					}
					else
					{
						Pattern desired = getNeuralNet ().getTeacher ().getDesired ().fwdGet ();

						if ( output.getCount () == desired.getCount () )
						{
							// Allocate the error buffer lazily, then reuse it
							if ( lastError == null )
							{
								lastError = new double[ desired.getValues ().length ];
							}
							for ( int i = output.getValues ().length - 1; i >= 0; i -- ) // Do it this way for speed
							{
								lastError[i] = desired.getValues ()[i] - output.getValues ()[i];
							}
							ready = true;
						}
						else
						{
							System.err.println ( "RTRL : Pattern mismatch " + output.getCount () + " <> " + desired.getCount () );
						}
					}

					// Proceed if we have a valid pattern
					if ( ready )
					{
						rtrl.update ( lastError );
					}
				}

				// Make sure this does not wait for anything
				items = 0;
			}

		};

		// Configure and attach the synapse
		errorPatternListener.setName ( "RTRL" );
		errorPatternListener.setOutputDimension ( getNeuralNet ().getOutputLayer ().getRows () );
		getNeuralNet ().getOutputLayer ().addOutputSynapse ( errorPatternListener );
	}

	/** Detach the error listener from the output layer, if it was ever attached. */
	private void detachErrorPatternListener ()
	{
		if ( errorPatternListener != null )
		{
			getNeuralNet ().getOutputLayer ().removeOutputSynapse ( errorPatternListener );
		}
	}

	@Override
	protected void manageStop ( Monitor mon )
	{
		detachErrorPatternListener ();
	}

	@Override
	protected void manageCycle ( Monitor mon )
	{
		if ( adaptLearningRate )
		{
			if ( cycleCount == 1 )
			{
				// Seed the adaptive state from the monitor on the first full cycle
				prevError = mon.getGlobalError ();
				learningRate = mon.getLearningRate ();
			}
			else if ( cycleCount > 1 )
			{
				double currError = mon.getGlobalError ();

				// We are making progress along the current gradient
				if ( currError < prevError )
				{
					// Since things are going well, increase the step size
					learningRate *= 1.5;
				}
				else
				{
					// We are running away, so make the learning rate smaller
					learningRate = Math.max ( minLearningRate, Math.min ( learningRate * 0.5, mon.getLearningRate () ) );
				}

				// Prepare for the next round
				prevError = currError;

				// Shrink learning rate every now and again
				if ( cycleCount % 50 == 0 )
				{
					learningRate *= 0.1;
				}
			}
			// System.err.println ( "\t\tRTRL LR " + learningRate );
		}
		else if ( cycleCount == 1 )
		{
			learningRate = mon.getLearningRate ();
		}

		// Hand the (possibly adapted) rate to the RTRL.  On the very first
		// call (cycleCount == 0) the rate is still the field default of 0.0;
		// this matches the original behaviour and is only a warm-up cycle.
		rtrl.updateCycle ( learningRate );
		cycleCount ++;
	}

	@Override
	protected void manageStart ( Monitor mon )
	{
		// An optional net parameter caps the weight magnitude; instanceof
		// already implies non-null, so one lookup covers both checks
		Object maxMagnitude = getNeuralNet ().getParam ( "maximumWeightMagnitude" );

		if ( maxMagnitude instanceof Double )
		{
			rtrl = new RTRL ( new NodesAndWeights ( getNeuralNet (), ( Double ) maxMagnitude, false ) );
		}
		else
		{
			rtrl = new RTRL ( new NodesAndWeights ( getNeuralNet (), 0, false ) );
		}
		rtrl.setProcessorCount ( processorCount );
		rtrl.nodesAndWeights.printWeights ( System.err );
		attachErrorPatternListener ();
	}

	@Override
	protected void manageError ( Monitor mon )
	{
		// No per-pattern error handling needed; the attached synapse does it
	}

	@Override
	protected void manageStopError ( Monitor mon, String msgErr )
	{
		detachErrorPatternListener ();
	}

	/**
	 * Deserialize a network from the given file, making sure the returned
	 * network carries a monitor (RTRL-trained nets are saved without one).
	 * The stream is closed even if reading fails.
	 *
	 * @param fileName the serialized network to load
	 * @return the deserialized network, never without a monitor
	 * @throws Exception if the file cannot be read or deserialized
	 */
	private static NeuralNet loadNetwork ( String fileName ) throws Exception
	{
		ObjectInputStream in = new ObjectInputStream ( new FileInputStream ( fileName ) );
		try
		{
			NeuralNet network = ( NeuralNet ) in.readObject ();

			// Sanity check - needed since we set the monitor to null before we save our RTRL trained networks
			if ( network.getMonitor () == null )
			{
				network.setMonitor ( new Monitor () );
			}
			return network;
		}
		finally
		{
			in.close ();
		}
	}

	/** Clear and fix the bias of the named layer, if the network has it. */
	private static void clearAndFixBias ( NeuralNet network, String layerName )
	{
		if ( network.getLayer ( layerName ) != null )
		{
			network.getLayer ( layerName ).getBias ().clear ();
			network.getLayer ( layerName ).getBias ().fixAll ();
		}
	}

	/**
	 * Serialize the trained network.  The stream is closed even when the
	 * write fails, so a failed save no longer leaks a file handle.
	 *
	 * @param network  the network to save
	 * @param fileName where to save it
	 * @throws Exception if the network cannot be written
	 */
	private static void saveNetwork ( NeuralNet network, String fileName ) throws Exception
	{
		ObjectOutputStream out = new ObjectOutputStream ( new FileOutputStream ( fileName ) );
		try
		{
			out.writeObject ( network );
			out.flush ();
		}
		finally
		{
			out.close ();
		}
	}

	/**
	 * Test the RTRL learner.  Each argument names a serialized network which
	 * is loaded, trained and saved again postfixed with "-joone.rtrl".
	 *
	 * @param args the network files to train
	 */
	public static void testRTRL ( String args[] )
	{
		if ( args.length == 0 )
		{
			System.err.println ( "Please specify the net to train on the command line" );
			return;
		}
		for ( String arg : args )
		{
			System.err.println ( "Now training " + arg );
			try
			{
				final NeuralNet network = loadNetwork ( arg );

				// Custom network settings

				// E.g. clear and fix the bias nodes in the test network as below
				clearAndFixBias ( network, "Input layer" );
				clearAndFixBias ( network, "Hidden layer" );
				clearAndFixBias ( network, "Output layer" );

				// Keep 'em eyes peeled
				network.addNeuralNetListener ( new NeuralNetListener ()
				{
					public void netStarted ( NeuralNetEvent e )
					{
						System.err.println ( "\tNetwork started" );
					}

					public void cicleTerminated ( NeuralNetEvent e )
					{
						System.err.println ( "\tCycle " + ( network.getMonitor ().getTotCicles () - network.getMonitor ().getCurrentCicle () ) + " error " + network.getMonitor ().getGlobalError () );
					}

					public void netStopped ( NeuralNetEvent e )
					{
						System.err.println ( "\tNetwork stopped" );
					}

					public void errorChanged ( NeuralNetEvent e )
					{
					}

					public void netStoppedError ( NeuralNetEvent e, String error )
					{
						System.err.println ( "\tNetwork error : " + error );
					}

				} );

				// Learner plugin
				RTRLLearnerPlugin rtrl = new RTRLLearnerPlugin ( true, 1 );
				rtrl.setNeuralNet ( network );
				network.addNeuralNetListener ( rtrl );

				// Just cycle the net, the plugin will do the rest
				network.getMonitor ().setLearningRate ( 0.00001 );   // Does need a learning rate
				network.getMonitor ().setTotCicles ( 1000 );
				network.getMonitor ().setLearning ( false );
				network.getMonitor ().setTrainingPatterns ( 0 );
				network.getMonitor ().setValidation ( true );
				network.getMonitor ().setValidationPatterns ( 1000 );

				// Now train the network in single thread mode
				network.go ( true, false );
				network.join ();
				rtrl.rtrl.nodesAndWeights.printWeights ( System.err );

				// Done, save network
				System.err.println ( "\tDone! Saving network" );

				// First remove the listener
				network.removeAllListeners ();

				// Also reset the monitor
				network.setMonitor ( null );

				// Remove any redundant listeners
				network.removeAllListeners ();

				// Now save it
				saveNetwork ( network, arg + "-joone.rtrl" );
			}
			catch ( Exception e )
			{
				System.err.println ( "Failed : " + e.getLocalizedMessage () );
				e.printStackTrace ();
			}
		}

	}

	/**
	 * Execute one of the tests.
	 *
	 * @param args the action ("rtrl") followed by the networks to train
	 */
	public static void main ( String args[] )
	{
		System.err.println ( "Please specify the action to take" );
		System.err.println ();
		System.err.println ( "\trtrl <network 1> <network 2> ... <network n>" );
		System.err.println ();
		System.err.println ( "\t\tWill optimise network 1 .. n using the RTRL algorithm" );
		System.err.println ( "\t\tPlease check and set pattern and cycle counts in code." );
		System.err.println ( "\t\tTrained networks are saved postfixed with the string -joone.rtrl" );
		System.err.println ();
		System.err.println ( "MG Ferreira" );
		System.err.println ( "2008" );
		System.err.println ();

		if ( args.length > 1 )
		{
			System.out.println ( "Started " + new Date () );
			long start = System.currentTimeMillis ();

			if ( args[ 0 ].equals ( "rtrl" ) )
			{
				testRTRL ( Arrays.copyOfRange ( args, 1, args.length ) );
			}

			// Break the elapsed time down into h:mm:ss'mmm for the report
			long elapsed = System.currentTimeMillis () - start;
			int sss = ( int ) ( elapsed % 1000l );
			int s = ( int ) ( ( elapsed / 1000l ) % 60l );
			int m = ( int ) ( ( ( elapsed / 1000l ) / 60l ) % 60l );
			int h = ( int ) ( ( ( elapsed / 1000l ) / 60l ) / 60l );

			System.out.println ( "Done " + new Date () );
			System.out.println ( "Time taken " + ( h < 10 ? "0" : "" ) + h + ":" + ( m < 10 ? "0" : "" ) + m + ":" + ( s < 10 ? "0" : "" ) + s + "'" + sss );
			System.out.println ( "MG Ferreira" );
			System.out.println ( "2008" );
		}
	}

}