// File: hand.cp
//
// Description: Handwriting recognition system.
//              This demo uses the TNeuralNet class.
//

#include "applib.h"
#include "tneural.h"
#include "tplot.h"
#include "int_vect.h"   // Use smart vector to store points
#include "double_v.h"   // Use smart vector to store points

// Points of the digit currently being drawn (appended by mouse_down).
int nump = 0;    // count the number of points stored
int_vect xp(500);
int_vect yp(500);

// Error-history plot: drawn by update_display() and refreshed while training.
static TPlotClass *tp = (TPlotClass *)NULL;
const int MAX_ERROR = 30;
static double xp_error[MAX_ERROR];  // plot x values (entry index)
static double yp_error[MAX_ERROR];  // plot y values (error returned by auto_train)
static int num_error = 0;           // next yp_error slot to write

//  USER DEFINED CALLBACK(S):

// Keep track of current mode of operation:
//  0:  do nothing
//  1:  recording training data
//  2:  training (set by menu item 2 once a training run has completed)
//  3:  testing recognition

static int current_mode = 0;

// Data for neural network: 7x9 input grid, 10 output neurons (one per digit)

TNeuralNet nnet(63,6,10); // 63 inputs, 6 hidden, 10 output neurons

// Training-set storage. Each example contributes 63 input values to
// input_examples and 10 target values to output_examples. Both buffers are
// sized from the same example count so they run out of room together
// (previously 1000 inputs held only 15 examples while 200 outputs held 20).
// One "Record training data" session adds 10 examples.
const int MAX_TRAINING_EXAMPLES = 20;
double_vect input_examples(63 * MAX_TRAINING_EXAMPLES);
double_vect output_examples(10 * MAX_TRAINING_EXAMPLES);
static int num_training_examples = 0;
static int training_input_counter = 0;   // next free slot in input_examples
static int training_output_counter = 0;  // next free slot in output_examples

double in_vect[63];  // current drawing as a 7x9 grid, filled by convert_to_2D

char *symbols[] = {"0","1","2","3","4","5","6","7","8","9"};

// Utility to convert captured points that were drawn with the
// mouse into a small 63 element vector (63 = 9x7):
//
// The bounding box of the nump points in xp/yp is scaled onto a 7-column
// by 9-row grid, and the result is written to the global in_vect. A cell
// hit by at least one point becomes 0.9; every other cell is 0.1. Cell
// (x, y) is stored at in_vect[x*9 + y]. Each hit cell is also drawn on tw
// as a one-pixel line at (10 + x, 70 + y).
void convert_to_2D(TAppWindow *tw)
{
    int min_x = 9999;
    int max_x = -1;
    int min_y = 9999;
    int max_y = -1;
    // Declared once so that it is valid under both pre-standard and ISO
    // for-scope rules (the loops below all reuse it).
    int i;

    for (i = 0; i < nump; i++) {
        if (min_x > xp[i]) min_x = xp[i];
        if (max_x < xp[i]) max_x = xp[i];
        if (min_y > yp[i]) min_y = yp[i];
        if (max_y < yp[i]) max_y = yp[i];
    }
    // Guarantee a non-zero extent (single point or perfectly straight
    // stroke) so the divisions below are never by zero.
    if (min_x >= max_x) min_x = max_x - 1;
    if (min_y >= max_y) min_y = max_y - 1;

    for (i = 0; i < 63; i++) in_vect[i] = 0.1;

    for (i = 0; i < nump; i++) {
        int x = (int) (((float)(xp[i] - min_x) / (float)(max_x - min_x)) * 7.0);
        int y = (int) (((float)(yp[i] - min_y) / (float)(max_y - min_y)) * 9.0);
        // A point on the max edge scales to 7 / 9; clamp into the grid.
        if (x < 0) x = 0;
        if (x > 6) x = 6;
        if (y < 0) y = 0;
        if (y > 8) y = 8;
        in_vect[x*9+y] = 0.9;
        tw->plot_line(10+x,70+y,10+x,70+y);
    }
}

// Index of the digit being recorded: reset to 0 by menu item 1 and
// advanced by mouse_down after each saved drawing.
static int training_step = 0;

// Append the current drawing as one training example.
//
// The drawing is converted with convert_to_2D. Its 63 grid values are
// appended to input_examples, and a 10-value target (0.9 for digit
// training_step, 0.1 for all others) is appended to output_examples.
// num_training_examples is then incremented.
// NOTE(review): there is no check against the capacity of
// input_examples/output_examples; behaviour past the end depends on
// double_vect -- confirm in double_v.h.
static void save_training_data(TAppWindow *tw)
{
    convert_to_2D(tw);
    int i;  // shared by both loops; valid under old and ISO for-scope rules
    for (i = 0; i < 63; i++)
        input_examples[training_input_counter++] = in_vect[i];
    for (i = 0; i < 10; i++)
        output_examples[training_output_counter++] = (i == training_step) ? 0.9 : 0.1;
    num_training_examples++;
}

// Digit index chosen by the most recent recall_pattern() call.
static int recall_value = 0;

// Classify the current drawing.
//
// The drawing is converted with convert_to_2D and run through nnet. The
// index of the output neuron with the largest activation is stored in
// recall_value; on a tie, the lowest index wins.
void recall_pattern(TAppWindow *tw)
{
    convert_to_2D(tw);
    nnet.forward_pass(in_vect);
    // Arg-max over the 10 output activations.
    int max_index = 0;
    for (int i = 1; i < 10; i++)
        if (nnet.output_activations[i] > nnet.output_activations[max_index])
            max_index = i;
    recall_value = max_index;
}

// Idle callback: this demo performs no background work.
void TAppWindow::idle_proc()
{
}

// Redraw callback. A one-line status for the current mode is drawn at
// (2, 55), followed by the error plot once it exists.
void TAppWindow::update_display()
{
    switch (current_mode) {
    case 1: {
        // Show which digit the user should draw next.
        char label[64];
        sprintf(label, "Recording %s", symbols[training_step]);
        plot_string(2, 55, label);
        break;
    }
    case 2:
        plot_string(2, 55, "Training");
        break;
    case 3:
        plot_string(2, 55, "Testing");
        break;
    default:
        break;  // mode 0: no status line
    }

    if (tp == NULL)
        return;
    tp->plot();
}

// 1 while the mouse button is held down (set here, cleared by mouse_up).
static int down_flag = 0;

// Mouse-press handler; mouse_move also calls it while the button is held.
//
// It only acts in recording (1) or testing (3) mode. A press with
// x < 40 || y < 40 finishes the current drawing:
//   - recording: the drawing is saved as an example for training_step,
//     and the next digit is shown; after digit 9 the mode returns to 0.
//   - testing: the drawing is classified, the mode returns to 0, and the
//     recognised digit is shown.
// Any other position is appended to xp/yp as a stroke point and drawn.
void TAppWindow::mouse_down(int x, int y)
{
    // Capacity xp/yp are constructed with at file scope. Assumes int_vect
    // does not grow on out-of-range writes (TODO confirm in int_vect.h).
    // Dropping extra points is harmless for a 7x9 grid.
    const int max_points = 500;

    down_flag = 1;
    if (current_mode != 1 && current_mode != 3)
        return;

    if (x < 40 || y < 40) {
        // Treat a margin press as a single click. Otherwise, dragging on
        // in the margin with the button still held re-enters here via
        // mouse_move, which would save empty drawings and skip digits.
        down_flag = 0;
        if (current_mode == 1) {
            save_training_data(this);
            nump = 0;
//          for (long delay=0; delay<456000; delay++) ;
            clear_display();
            training_step += 1;
            if (training_step < 10)
                plot_string(30,70,symbols[training_step]);
            else
                current_mode = 0;
        } else {
            recall_pattern(this);
            nump = 0;
            current_mode = 0;
            clear_display();
            plot_string(50,70,symbols[recall_value]);
        }
        return;
    }

    if (nump < max_points) {
        xp[nump] = x;
        yp[nump] = y;
        nump++;
    }
    plot_line(x,y,x+1,y+1);
}

// Button released: stop treating mouse movement as drawing.
void TAppWindow::mouse_up(int /* x */, int /* y */)
{
    down_flag = 0;
}

// While the button is held, every movement is handled exactly like a
// press at the new position; movement with the button up is ignored.
void TAppWindow::mouse_move(int x, int y)
{
    if (!down_flag)
        return;
    mouse_down(x, y);
}

void TAppWindow::do_menu_action(int item_number)
{
    if (item_number == 1) { // Record training data
        current_mode = 1;
        training_step = 0;
        nump = 0;
        clear_display();
        plot_string(30,70,symbols[0]);

        num_error = 0;

        for (int k=0; k<MAX_ERROR; k++) {


            xp_error[k] = k;


            yp_error[k] = 0.0;

        }
    }
    if (item_number == 2) { // Train network and save weights
        for (int j=0; j<10; j++) {
            double error = nnet.auto_train(100,num_training_examples,
                            &(input_examples[0]), &(output_examples[0]));

    yp_error[num_error] = error;

    if (num_error < (MAX_ERROR - 2))  num_error++;

    tp->rescale();

    tp->plot();
        }
        char output_filename[255];
        if (choose_file_to_write("Output file name (ending with .net):",
                                  output_filename) != 0) {
            Warning("Could not open/write to file name");
        }  else  {
            nnet.save(output_filename);
        }
        current_mode = 2;
    }
    if (item_number == 3) { // Reload saved weights
    char input_filename[255];
        if (choose_file_to_read("Weight file:", "net", input_filename) != 0) {
            Warning("Could not open/read to file name");
        }  else  {
            nnet.restore(input_filename);
        }
    }
    if (item_number == 4) { // test recognition
        current_mode = 3;
    }
}

// Menu labels, listed in the order do_menu_action() numbers them (1-4).
static char *m_titles[] = {
    "Record training data",            // item 1
    "Train network and save weights",  // item 2
    "Reload saved weights",            // item 3
    "Test recognition"                 // item 4
};

// Program start-up. INIT_PROGRAM and RUN_PROGRAM are macros from applib.h,
// which is not visible here. Presumably INIT_PROGRAM opens the start-up
// function (the lone closing brace at the end of this block matches it)
// and registers the 4 menu titles, and RUN_PROGRAM enters the event loop.
// TODO(review): confirm against applib.h.
INIT_PROGRAM("Handwriting recognition", 4, m_titles)

   // Program-specific start-up code.
   // Seed the error plot with placeholder values (x = y = index); menu
   // item "Record training data" later resets the y values to 0.
   for (int k=0; k<MAX_ERROR; k++) {
       xp_error[k] = k;
       yp_error[k] = k;
   }
   // Create the error plot over xp_error/yp_error in the current window.
   // NOTE(review): the four trailing integers are presumably the plot's
   // position/size -- confirm against tplot.h. The plot is allocated
   // with new and is not deleted anywhere in this file.
   tp = new TPlotClass(current_window,"Error (RMS) Plot",
                       xp_error,yp_error,MAX_ERROR,


       200,400,300,45);
   
   RUN_PROGRAM;
}
