;;;; Program for training a perceptron network (single layer of linear
;;;; threshold units).

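#|
How to use this program:

(load "...-data")            ; load a data set first
(init-perceptron)            ; build the net; all weights start at zero
(train-perceptron)           ; optional arg: max passes per output unit (default 1000)
(test-perceptron ...)        ; argument: an input attribute vector
(show-perceptron-weights)    ; print the weights of every output unit
|#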

;;; Global state.  INIT-PERCEPTRON sets *PERCEPTRON-DATA* and *WEIGHTS*;
;;; *WEIGHTS* holds one weight list per output unit, and TRAIN-WTS writes
;;; the learned weights back into it.
(defvar *show-perceptron-detail* nil)	; when true, TRAIN-WTS prints each weight update
(defvar *perceptron-data*)		; examples; TRAIN-WTS reads (FIRST p) as input, (SECOND p) as targets
;; NOTE(review): *NBR-INPUTS* and *NBR-OUTPUTS* are used by this file but
;; their DEFVARs are commented out.  Presumably they are defined by the
;; loaded data file or the data-conversion code -- confirm.
;(defvar *nbr-inputs*)			; includes bias input
;(defvar *nbr-outputs*)
(defvar *weights*)

;;; User routines.

(defun init-perceptron ()
  "Convert the loaded data into network form and give every output unit an
all-zero weight list, then report the size of the net.  Returns DONE."
  (setf *perceptron-data* (convert-data-to-nnet-form)
        *weights* (make-list-of-lists *nbr-outputs* *nbr-inputs* 0))
  ;; *NBR-INPUTS* counts the bias input, so one is subtracted for display.
  (format t "~%Building net with ~a input unit~:p (plus bias) ~
             and ~a output unit~:p.~%Weights initialized to zero."
          (1- *nbr-inputs*) *nbr-outputs*)
  'done)

(defun train-perceptron (&optional (max-passes 1000))
  "Train each output unit in turn, 0 .. *NBR-OUTPUTS* - 1, allowing at most
MAX-PASSES sweeps through the data per unit.  Returns no values."
  (loop for unit from 0 below *nbr-outputs*
        do (train-wts unit max-passes))
  (values))

(defun test-perceptron (input-att-vec)
  "Run the trained net on the attribute vector INPUT-ATT-VEC.
If CHECK-ATT-VECTOR rejects the vector, return NIL without running the net.
Otherwise show the input and output layers and return the result of
PRETTIFY-OUTPUT on the decoded output."
  (unless (check-att-vector input-att-vec)
    (return-from test-perceptron nil))
  (let* ((input (convert-input-to-euclid-vec input-att-vec))
         ;; One thresholded response per output unit, in *WEIGHTS* order.
         (output (loop for wts in *weights*
                       collect (threshold-logic wts input))))
    (show-layer "Input: " input)
    (show-layer "Output:" output)
    (prettify-output (convert-output-to-att-vec output))))

(defun show-perceptron-weights ()
  "Print every output unit's index, counting from 0, followed by its weight
list.  Returns no values."
  (loop for wts in *weights*
        for unit from 0
        do (format t "~%Output unit ~a:~%~a" unit wts))
  (values))

;;; Auxiliaries

(defun train-wts (i max-passes)
  "Apply the perceptron learning rule to output unit I.
Training sweeps repeatedly over *PERCEPTRON-DATA* and stops after a sweep
that makes no weight change, or once MAX-PASSES sweeps have been made.  At
least one sweep is always made, and the reported pass count includes the
final change-free sweep.  Progress goes to *STANDARD-OUTPUT*.  The learned
weights are stored back into position I of *WEIGHTS*.  Returns NIL."
  (let ((wts (nth i *weights*))
        (passes 0)
        (wt-changes 0)
        (semantics (nth i *output-semantics*)))
    (format t "~%Learning to recognize ~a" (second semantics))
    (when (> (length *output-ranges*) 1)
      (format t " for output attribute number ~a" (first semantics)))
    (flet ((sweep ()
             ;; One pass over the training data.  Returns true iff some
             ;; example was misclassified and so caused a weight update.
             (let ((changed nil))
               (dolist (example *perceptron-data* changed)
                 (let* ((input (first example))
                        (err (- (nth i (second example))
                                (threshold-logic wts input))))
                   (unless (zerop err)
                     (when *show-perceptron-detail*
                       (format t "~%~a ~a ~a"
                               wts (if (plusp err) "+" "-") input))
                     (setf wts (new-weights wts input err))
                     (when *show-perceptron-detail*
                       (format t " --> ~a" wts))
                     (setf changed t)
                     (incf wt-changes)))))))
      (loop
        (let ((changed (sweep)))
          (incf passes)
          (cond ((not changed)
                 (format t "~%Converged after ")
                 (return))
                ((>= passes max-passes)
                 (format t "~%Not converged after ")
                 (return))))))
    (setf (nth i *weights*) wts)	;update global data structure
    ;; ~:P reuses the preceding argument to choose "pass"/"passes" and
    ;; "change"/"changes".
    (format t "~a pass~:p through data.~%~a weight change~:p made."
            passes wt-changes)))

(defun make-list-of-lists (nlists nelts-per-list elt)
  "Return a list of NLISTS separately consed lists, each NELTS-PER-LIST long
and filled with ELT.  Returns NIL when NLISTS is not positive."
  (loop repeat nlists
        collect (make-list nelts-per-list :initial-element elt)))

(defun new-weights (wts input error)
  "Perceptron weight update with unit learning rate.
If ERROR is positive, return WTS plus INPUT element-wise; otherwise return
WTS minus INPUT.  Only the sign of ERROR matters.  The result is a fresh list
as long as the shorter of WTS and INPUT."
  (let ((combine (if (plusp error) #'+ #'-)))
    (loop for w in wts
          for x in input
          collect (funcall combine w x))))

(defun threshold-logic (wts input)
  "Output (1 or 0) of a linear threshold unit with weights WTS on INPUT."
  (let ((net-input (dot-prod wts input)))
    (threshold net-input)))

(defun threshold (x)
  "Hard-limiting activation: 1 when X is strictly positive, else 0."
  (cond ((plusp x) 1)
        (t 0)))