module Verifier where

import Language.Java.Syntax
import Z3.Monad

import Folds

-- Imported for the example:
import Control.Applicative
import Control.Monad ( join )
import Data.Maybe
import qualified Data.Traversable as T



-- | Presumably meant to decide whether an expression is (known to be) true.
--
--   NOTE(review): not implemented yet.  Calling it raises an error that names
--   this function, rather than an anonymous 'undefined' that gives no hint
--   about which stub was hit.
isTrue :: Exp -> Bool
isTrue _ = error "Verifier.isTrue: not implemented"

-- | Defines the conversion from an expression to a Z3 'AST' so that Z3 can
--   check its satisfiability.
--   This is used to fold expressions generated by the WLP transformer, so not
--   all valid Java expressions need to be handled.  Constructs that are not
--   translated raise an 'error' naming the construct, not a bare 'undefined'.
expAssertAlgebra :: ExpAlgebra (Z3 AST)
expAssertAlgebra =
    ( fLit, fClassLit, fThis, fThisClass, fInstanceCreation
    , fQualInstanceCreation, fArrayCreate, fArrayCreateInit, fFieldAccess
    , fMethodInv, fArrayAccess, fExpName, fPostIncrement, fPostDecrement
    , fPreIncrement, fPreDecrement, fPrePlus, fPreMinus, fPreBitCompl
    , fPreNot, fCast, fBinOp, fInstanceOf, fCond, fAssign, fLambda
    , fMethodRef
    )
  where
    -- Fail loudly, naming the Java construct this algebra does not translate.
    unsupported :: String -> a
    unsupported what = error ("expAssertAlgebra: unsupported " ++ what)

    -- Literals become Z3 numerals / booleans.
    fLit lit = case lit of
        Int n     -> mkInteger n
        -- A Java long literal is still an integer for Z3.
        Word n    -> mkInteger n
        Float d   -> mkRealNum d
        Double d  -> mkRealNum d
        Boolean b -> mkBool b
        -- Characters are encoded as integers via their code point.
        Char c    -> mkInteger (toInteger (fromEnum c))
        String _  -> unsupported "string literal"
        -- null is encoded as the integer 0.
        Null      -> mkInteger 0

    fClassLit             = unsupported "class literal"
    fThis                 = unsupported "this"
    fThisClass            = unsupported "qualified this"
    fInstanceCreation     = unsupported "instance creation"
    fQualInstanceCreation = unsupported "qualified instance creation"
    fArrayCreate          = unsupported "array creation"
    fArrayCreateInit      = unsupported "array creation with initializer"
    fFieldAccess          = unsupported "field access"
    fMethodInv            = unsupported "method invocation"
    fArrayAccess          = unsupported "array access"
    fExpName              = unsupported "expression name"
    fPostIncrement        = unsupported "post-increment"
    fPostDecrement        = unsupported "post-decrement"
    fPreIncrement         = unsupported "pre-increment"
    fPreDecrement         = unsupported "pre-decrement"
    fPrePlus              = unsupported "unary plus"
    fPreMinus             = unsupported "unary minus"
    fPreBitCompl          = unsupported "bitwise complement"
    fPreNot               = unsupported "logical not"
    fCast                 = unsupported "cast"

    -- Both operands are translated (left first) before the operator is applied.
    fBinOp e1 op e2 = do
        ast1 <- e1
        ast2 <- e2
        binOp op ast1 ast2

    -- Translate a Java binary operator applied to two already-translated operands.
    binOp op a b = case op of
        Mult    -> mkMul [a, b]
        -- NOTE(review): Z3's integer div/rem are believed to round differently
        -- from Java's / and % for negative operands; confirm this is acceptable.
        Div     -> mkDiv a b
        Rem     -> mkRem a b
        Add     -> mkAdd [a, b]
        Sub     -> mkSub [a, b]
        -- NOTE(review): the shifts use bit-vector primitives, while fLit builds
        -- Int/Real numerals; confirm operands reaching here are bit-vectors.
        LShift  -> mkBvshl a b
        RShift  -> mkBvashr a b
        RRShift -> mkBvlshr a b
        LThan   -> mkLt a b
        GThan   -> mkGt a b
        LThanE  -> mkLe a b
        GThanE  -> mkGe a b
        Equal   -> mkEq a b
        NotEq   -> mkNot =<< mkEq a b
        -- & / | and && / || map to the same connectives: short-circuiting has
        -- no counterpart in the resulting formula.
        And     -> mkAnd [a, b]
        Or      -> mkOr [a, b]
        Xor     -> mkXor a b
        CAnd    -> mkAnd [a, b]
        COr     -> mkOr [a, b]

    fInstanceOf = unsupported "instanceof"
    fCond       = unsupported "conditional expression"
    fAssign     = unsupported "assignment"
    fLambda     = unsupported "lambda"
    fMethodRef  = unsupported "method reference"




-- Example: run 'script' with model generation enabled and print the solution,
-- failing with an error when no model is returned.
main :: IO ()
main = do
  result <- evalZ3With Nothing opts script
  maybe (error "No solution found.") report result
  where
    opts = opt "MODEL" True +? opt "MODEL_COMPLETION" True
    report sol = putStr "Solution: " >> print sol


-- | The 4x4 instance of 'nQueens'; kept as the example's entry point.
script :: Z3 (Maybe [Integer])
script = nQueens 4

-- | Solve the N-queens puzzle with Z3.
--
--   Queen @i@ sits in row @i@ and the integer variable @qi@ holds its column.
--   The assertions keep every column in @[1, n]@, make the columns pairwise
--   distinct, and rule out diagonal attacks.  The result is the list of
--   columns in row order as extracted from the model by 'withModel'.
--
--   For tiny boards (@n <= 1@), empty conjunctions/disjunctions fall back to
--   @true@/@false@ instead of handing Z3 an empty argument list.  Distinctness
--   is only asserted when there are at least two queens.
nQueens :: Integer -> Z3 (Maybe [Integer])
nQueens n = do
  qs <- mapM (\i -> mkFreshIntVar ("q" ++ show i)) [1 .. n]
  _1 <- mkInteger 1
  _n <- mkInteger n
  -- the ith-queen is in the ith-row; qi is its column: 1 <= qi <= n
  assert =<< allOf =<< T.sequence
    (concat [ [mkLe _1 q, mkLe q _n] | q <- qs ])
  -- different columns
  if length qs > 1
    then assert =<< mkDistinct qs
    else return ()
  -- avoid diagonal attacks: |qj - qi| /= j - i for every pair of rows i < j
  let rows = zip [1 ..] qs
  assert =<< mkNot =<< anyOf =<< T.sequence
    [ diagonal (j - i) qi qj | (i, qi) <- rows, (j, qj) <- rows, i < j ]
  -- check and get solution
  fmap snd $ withModel $ \m ->
    catMaybes <$> mapM (evalInt m) qs
  where
    -- Conjunction / disjunction that tolerate an empty argument list.
    allOf [] = mkTrue
    allOf xs = mkAnd xs
    anyOf [] = mkFalse
    anyOf xs = mkOr xs
    -- |x| expressed with an if-then-else.
    mkAbs x = do
      _0 <- mkInteger 0
      join $ mkIte <$> mkLe _0 x <*> pure x <*> mkUnaryMinus x
    -- Queens in columns c and c', d rows apart, share a diagonal.
    diagonal d c c' =
      join $ mkEq <$> (mkAbs =<< mkSub [c', c]) <*> mkInteger d