(load-file "util.ath")

;; The method eq-symmetry below (defined after the two small term
;; accessors get-left and get-right) takes two terms t1 and t2 such that
;; the equality (= t1 t2) is in the assumption base, and derives the
;; theorem (= t2 t1).

;; Returns the first argument of a binary application (op lhs rhs).
;; Any other shape of t is not matched by the single clause below.
(define (get-left t)
  (match t
    (((some-symbol op) lhs _) lhs)))

;; Returns the second argument of a binary application (op lhs rhs).
;; Any other shape of t is not matched by the single clause below.
(define (get-right t)
  (match t
    (((some-symbol op) _ rhs) rhs)))

;; eq-symmetry: from (= t1 t2) in the assumption base, derive (= t2 t1).
(define eq-symmetry
  (method (t1 t2)
    (dlet ((premise $(= t1 t2))
           (x (fresh-var [t1 t2]))
           ;; Leibniz on the schema (= x t1) gives
           ;; (= t1 t1) <==> (= t2 t1).
           (bicond (!leibniz t1 t2 (= x t1) x))
           (reflex==>goal (!left-iff bicond))
           (reflex (!eq-reflex t1)))
      (!mp reflex==>goal reflex))))


;; eq-symmetry2: given an equality P = (= lhs rhs) in the assumption base,
;; derive (= rhs lhs). It unpacks P and delegates to eq-symmetry.
(define (eq-symmetry2 P)
  (dmatch P
    ((= lhs rhs)
      (!eq-symmetry lhs rhs))))

;; This method takes three terms t1, t2, and t3 such that
;; t1 = t2 and t2 = t3 hold, and derives the equality t1 = t3.

;; eq-tran: from (= t1 t2) and (= t2 t3) in the assumption base,
;; derive (= t1 t3).
(define eq-tran
  (method (t1 t2 t3)
    (dlet ((left-eq $(= t1 t2))
           (right-eq $(= t2 t3))
           (x (fresh-var [t1 t2 t3]))
           ;; Leibniz on the schema (= x t3) gives
           ;; (= t1 t3) <==> (= t2 t3).
           (bicond (!leibniz t1 t2 (= x t3) x))
           (step (!right-iff bicond)))
      (!mp step right-eq))))

;; eq-transitivity: given equalities eq1 = (= a b) and eq2 = (= b c)
;; (the middle terms must coincide: b occurs twice in the pattern),
;; derive (= a c) via eq-tran.
(define (eq-transitivity eq1 eq2)
  (dmatch [eq1 eq2]
    ([(= a b) (= b c)]
      (!eq-tran a b c))))
           

;; The method substitute-equals below takes a term t1, a theorem P, and a
;; term t2, where the equality (= t1 t2) holds, and returns the proposition
;; obtained from P by replacing every occurrence of t1 by t2. The two
;; procedures preceding it perform the purely syntactic replacement.

;; Procedure: returns t with every occurrence of the term t1 replaced by t2.
;; If t is t1 itself the result is t2; compound terms are rebuilt from their
;; root symbol and rewritten arguments; anything else is returned unchanged.
(define (replace-term-in-term t1 t t2)
  (match t
    ((val-of t1) t2)
    (((some-symbol op) (some-list kids))
      (make-term op
                 (map (function (kid) (replace-term-in-term t1 kid t2))
                      kids)))
    (other other)))

;; Procedure: returns P with every occurrence of the term t1 replaced by t2.
;; Atoms are rewritten as terms; negations, binary connectives and
;; quantifier bodies are rewritten recursively. Bound variables are left
;; as they are, so callers should rename P first (as substitute-equals does).
(define (replace-term-in-prop t1 P t2)
  (match P
    ((some-atom A)
      (replace-term-in-term t1 A t2))
    ((not body)
      (not (replace-term-in-prop t1 body t2)))
    (((some-prop-con con) lhs rhs)
      (con (replace-term-in-prop t1 lhs t2)
           (replace-term-in-prop t1 rhs t2)))
    (((some-quant q) x body)
      (q x (replace-term-in-prop t1 body t2)))))


;; substitute-equals: from the theorem P and (= t1 t2), derive P with
;; every occurrence of t1 replaced by t2.
;; The template is (rename P) with a fresh variable x wherever t1 occurred;
;; leibniz then relates template{t1/x} (P up to renaming) to template{t2/x}.
(define (substitute-equals t1 P t2)
  (dlet ((x (fresh-var))
         (template (replace-term-in-prop t1 (rename P) x))
         (bicond (!leibniz t1 t2 template x))
         (P==>result (!left-iff bicond)))
    (!mp P==>result P)))

;; The following is a more selective, positional version of
;; substitute-equals. It takes a term t1, a theorem P, a position
;; pos (represented as a list of numeric terms, say [2 1 4]) and
;; a term t2, where the equality (= t1 t2) must hold. It returns
;; the proposition obtained from P by replacing the occurrence of
;; t1 in P at position pos by t2.

;; pos-substitute-equals: from the theorem P and (= t1 t2), derive P with
;; the subterm at position pos replaced by t2 (that subterm must be t1).
(define (pos-substitute-equals t1 P pos t2)
  (dlet ((premise $(= t1 t2))
         (x (fresh-var [t1 t2 P]))
         ;; P with a fresh variable plugged in at pos.
         (template (prop-pos-replace P pos x))
         (bicond (!leibniz t1 t2 template x))
         (P==>result (!left-iff bicond)))
    (!mp P==>result P)))


;; The method eq-congruence takes two terms t1 and t2, where
;; the equality (= t1 t2) must hold, a term t, and a variable v,
;; and returns the equality (= t1' t2'), where t1' is obtained from
;; t by replacing every occurrence of v by t1, and t2' is obtained
;; from t by replacing every occurrence of v by t2.

;; eq-congruence: from (= t1 t2) in the assumption base, a term t and a
;; variable v, derive (= t{t1/v} t{t2/v}).
;; v is first swapped for a fresh variable x, giving the context ctx.
(define eq-congruence
  (method (t1 t2 t v)
    (dlet ((premise $(= t1 t2))
           (x (fresh-var))
           (ctx (replace-var v x t))
           (ctx-t2 (replace-var x t2 ctx))
           (schema (= ctx ctx-t2))
           ;; Leibniz on the schema gives
           ;; (= ctx{t1/x} ctx-t2) <==> (= ctx-t2 ctx-t2).
           (bicond (!leibniz t1 t2 schema x))
           (reflex==>goal (!right-iff bicond))
           (reflex (!eq-reflex ctx-t2)))
      (!mp reflex==>goal reflex))))

;; The following method, pos-congruence, works with positions
;; instead of variables. It again takes two terms t1 and t2 such that
;; (= t1 t2) is a theorem, a term t, and a position pos, and returns
;; the equality (= t1' t2'), where t1' is obtained from t by plugging
;; t1 in at position pos, and t2' is obtained from t by plugging
;; t2 in at position pos.

;; pos-congruence: puts a fresh variable at position pos of t and
;; delegates to eq-congruence with that variable.
(define (pos-congruence t1 t2 t pos)
  (dlet ((x (fresh-var [t1 t2 t])))
    (!eq-congruence t1 t2 (term-replace t pos x) x)))

;;==============================================================================
;;                            FUNCTION CONGRUENCE 
;;==============================================================================

;; The method fun-cong below takes a function symbol f (of arbitrary arity)
;; and two lists of terms, s-terms = [s1 ... sn] and t-terms = [t1 ... tn],
;; such that s_i = t_i is in the assumption base for every i = 1,...,n,
;; and derives the equality f(s1,...,sn) = f(t1,...,tn).

;; fun-cong: from s_i = t_i (i = 1..n) in the assumption base, derive
;; (= (f s1 ... sn) (f t1 ... tn)).
;; Works left to right. After k steps the derived equality is
;;   (= (f s1 ... sn) (f t1 ... tk s_{k+1} ... sn)),
;; starting from reflexivity for k = 0. Step k+1 builds the schema
;;   (= (f s1 ... sn) (f t1 ... tk v s_{k+2} ... sn))
;; with a fresh variable v, and uses leibniz on s_{k+1} = t_{k+1} to move
;; from its s_{k+1} instance (the current equality) to its t_{k+1} instance.
(define (fun-cong f s-terms t-terms)
  (dlet ((v (fresh-var (join (get-all-vars s-terms) (get-all-vars t-terms))))
         ;; The left-hand side never changes, so build it only once.
         (lhs (make-term f s-terms)))
    (dletrec ((do-args (method (done-t rem-s rem-t eq)
                         (dmatch [rem-s rem-t]
                           ([[] []] (!claim eq))
                           ([(list-of s more-s) (list-of t more-t)]
                             (dlet ((F (= lhs (make-term f (join done-t (join [v] more-s)))))
                                    (bi-cond (!leibniz s t F v))
                                    (new-eq (!mp (!left-iff bi-cond) eq)))
                               (!do-args (join done-t [t]) more-s more-t new-eq)))))))
      (!do-args [] s-terms t-terms (!eq-reflex lhs)))))
                                             

;;==============================================================================
;;                            RELATION CONGRUENCE 
;;==============================================================================

;; The method rel-cong below takes an atomic theorem P of the form (R s1 ... sn),
;; the terms s-terms [s1...sn], and the terms t-terms [t1...tn], where si = ti 
;; must be in the assumption base for all i, and returns the theorem (R t1 ... tn).

;; rel-cong: from the atomic theorem P = (R s1 ... sn) and the equalities
;; si = ti in the assumption base, derive (R t1 ... tn).
;; Substitution is positional. After k steps the theorem is
;; (R t1 ... tk s_{k+1} ... sn). Step k+1 builds the schema
;; (R t1 ... tk v s_{k+2} ... sn) with a fresh variable v and uses leibniz
;; on s_{k+1} = t_{k+1}, so only argument position k+1 is rewritten.
;; (Rewriting every occurrence of si, as substitute-equals does, goes wrong
;; for repeated or swapped arguments and when some ti contains a later sj.)
;; s-terms must be the full argument list of P, in order.
(define (rel-cong P s-terms t-terms)
  (dlet ((rel (root P))
         (v (fresh-var (join [P] (join s-terms t-terms)))))
    (dletrec ((do-args (method (done-t rem-s rem-t theorem)
                         (dmatch [rem-s rem-t]
                           ([[] []] (!claim theorem))
                           ([(list-of s more-s) (list-of t more-t)]
                             (dlet ((F (make-term rel (join done-t (join [v] more-s))))
                                    (bi-cond (!leibniz s t F v))
                                    (new-theorem (!mp (!left-iff bi-cond) theorem)))
                               (!do-args (join done-t [t]) more-s more-t new-theorem)))))))
      (!do-args [] s-terms t-terms P))))

; (declare R (-> (Number Number) Boolean))

; (assert (R ?a1 ?a2))

; (assert (= ?a1 ?b1) (= ?a2 ?b2))

; (!rel-cong (R ?a1 ?a2) [?a1 ?a2] [?b1 ?b2])

;; The argument list s-terms in the method rel-cong above is superfluous, since
;; it can be extracted from the atomic theorem P. Hence rel-cong-2 below simply
;; takes a theorem P, which again must be of the form (R s1 ... sn), and a list
;; of terms t-terms [t1 ... tn], where si = ti must be in the assumption base,
;; and derives the theorem (R t1 ... tn).

;; Helper for prim-rel-cong: true iff the two argument lists have the same
;; length and, position by position, (= s t) holds in the assumption base.
(define (prim-rel-cong-args-equal? s-terms t-terms)
  (match [s-terms t-terms]
    ([[] []] true)
    ([(list-of s more-s) (list-of t more-t)]
      (&& (holds? (= s t))
          (prim-rel-cong-args-equal? more-s more-t)))
    (_ false)))

;; Primitive relation congruence: given atoms P = (R s1 ... sn) and
;; Q = (R t1 ... tn), returns Q provided P and every (= si ti) hold.
;; Being primitive, this rule is trusted, so it must check every premise
;; itself:
;;   - both arguments are atoms;
;;   - P holds (without this, any atom Q was derivable from Q's own
;;     reflexive equalities);
;;   - the root symbols agree;
;;   - the argument lists have the same length (zip would silently
;;     truncate the longer one).
(primitive-method (prim-rel-cong P Q)
  (check ((&& (atom? P)
              (atom? Q)
              (holds? P)
              (equal? (root P) (root Q))
              (prim-rel-cong-args-equal? (children P) (children Q)))
          Q)))

           


;; rel-cong-2: rel-cong with s-terms taken from the atom P itself.
(define (rel-cong-2 P t-terms)
  (dcheck ((atom? P)
           (dlet ((s-terms (children P)))
             (!rel-cong P s-terms t-terms)))))

;; rel-cong3: rel-cong with s-terms taken from the atom P and t-terms
;; taken from the arguments of Q.
(define (rel-cong3 P Q)
  (dcheck ((atom? P)
           (dlet ((s-terms (children P))
                  (t-terms (children Q)))
             (!rel-cong P s-terms t-terms)))))

;;================================================================================
;;                              RECURSIVE CONGRUENCE
;;================================================================================

;; This is a powerful recursive congruence method. If any
;; subterms of t1 and t2 in corresponding positions are equal
;; (with the equality in the assumption base), everything else 
;; being the same, then the theorem (= t1 t2) is returned.

;; rec-cong: tries to prove (= t1 t2) by recursive congruence.
;;   - If t1 and t2 are identical, returns (= t1 t1) by reflexivity.
;;   - Otherwise, if (= t1 t2) or (= t2 t1) is in the assumption base
;;     (looked up with fetch), returns (= t1 t2). eq-symmetry is used
;;     when only the reversed equality is present.
;;   - Otherwise t1 and t2 must have the same root symbol. Each pair of
;;     corresponding arguments is proved equal by a recursive call, and
;;     fun-cong then combines them into (= t1 t2).
;; If the roots differ there is no matching dmatch clause, so the method fails.
(define (rec-cong t1 t2)
    (dmatch (equal? t1 t2)
      (true (!eq-reflex t1))
      ;; Look for the equality, in either orientation, in the assumption base.
      (_ (dmatch (fetch (function (P)  
                          (|| (equal? P (= t1 t2))
                              (equal? P (= t2 t1)))))
           ;; NOTE(review): this clause assumes fetch returns () when no
           ;; proposition satisfies the predicate -- confirm against util/fetch.
           (() (dlet ((root1 (root t1))
                      (root2 (root t2))
                      (args1 (children t1))
                      (args2 (children t2)))
                 ;; Only the `true` case is handled: differing roots fail here.
                 (dmatch (equal? root1 root2)
                   (true (dletrec ((do-args 
                                     (method (s-terms t-terms) 
                                       (dmatch [s-terms t-terms]
                                         ;; All argument equalities done: fun-cong needs
                                         ;; each s_i = t_i available (presumably left in the
                                         ;; assumption base by the dbegin steps below).
                                         ([[] []] (!fun-cong root1 args1 args2))
                                         ;; The pattern rebinds t1 locally to the current
                                         ;; argument of t2; it does not refer to the outer t1.
                                         ([(list-of s1 more) (list-of t1 rest)] 
                                           (dbegin 
                                             (!rec-cong s1 t1) 
                                             (!do-args more rest)))))))
                           (!do-args args1 args2))))))
           ;; P is either (= t1 t2), claimed directly, or (= t2 t1), flipped.
           (P (dmatch P
                ((= (val-of t1) (val-of t2)) (!claim P))
                (_ (!eq-symmetry t2 t1))))))))
                      
;; (assume (= ?x (+ 1 2)) (!rec-cong (* 7 ?x) (* 7 (+ 1 2))))

;; (assert (= ?x ?y))

;; (define s (+ (+ 1 (+ ?x 9)) (+ ?y 88)))

;; (define t (+ (+ 1 (+ ?y 9)) (+ ?x 88)))

;; (define u (+ (* 1 (+ ?y 9)) (+ ?x 88)))

;; This should work:

;; (!rec-cong s t)

;; And this should fail: 

;; (!rec-cong s u)


;; FINISH THIS: 

; (define (replace-equals P Q)
;  (dmatch [P Q]
;    ([(not P') (not Q')] (!not-congruence (!repace-equals P' 
;   (some-prop-con op)




;(define t1 (+ ?x 1))

;(define t2 (+ ?y 2))

;(assert (= t1 t2))

;(define P (and (= ?x ?z) (forall ?x (not (= ?x (+ ?x 1))))))

;(define P (and (= (succ ?x) ?z) (forall ?x (not (= ?x (succ ?x))))))

;(!se t1 P t2)


;; rec-rel-cong: given atoms P = (R s1 ... sn) and Q = (R t1 ... tn),
;; proves each si = ti with rec-cong (via map-method over the zipped
;; argument lists), then concludes Q with prim-rel-cong.
(define (rec-rel-cong P Q)
  (dlet ((pairs (zip (children P) (children Q)))
         (prove-pair (method (pair)
                       (dmatch pair
                         ([s t] (!rec-cong s t)))))
         (finish (method (eqs)
                   (!prim-rel-cong P Q))))
    (!map-method prove-pair pairs finish)))
