module Verifier where

import Language.Java.Syntax
import Language.Java.Pretty
import Z3.Monad
import System.IO.Unsafe

import Folds
import HelperFunctions
import Settings


-- | Checks whether the negation of the expression is unsatisfiable.
--   Implemented by wrapping the expression in 'PreNot' and delegating to 'isFalse'.
isTrue :: TypeEnv -> Exp -> Z3 Bool
isTrue env = isFalse env . PreNot
            
          
-- | Checks whether the expression is unsatisfiable.
--   The expression is translated with 'expAssertAlgebra', asserted and checked;
--   the solver is reset afterwards so the assertion does not leak into later queries.
--   Only an 'Unsat' answer yields 'True'; every other result yields 'False'.
isFalse :: TypeEnv -> Exp -> Z3 Bool
isFalse env e = do
    formula <- foldExp expAssertAlgebra e env
    verdict <- assert formula >> check
    solverReset
    return (isUnsat verdict)
  where
    -- Anything that is not a definite 'Unsat' counts as "not proven false".
    isUnsat Unsat = True
    isUnsat _     = False
        
-- | Unsafe version of 'isTrue': runs the query through 'evalZ3' and escapes 'IO'
--   with 'unsafePerformIO', so the result can be used from pure code.
unsafeIsTrue :: TypeEnv -> Exp -> Bool
unsafeIsTrue env e = unsafePerformIO (evalZ3 (isTrue env e))

-- | Unsafe version of 'isFalse': runs the query through 'evalZ3' and escapes 'IO'
--   with 'unsafePerformIO', so the result can be used from pure code.
unsafeIsFalse :: TypeEnv -> Exp -> Bool
unsafeIsFalse env e = unsafePerformIO (evalZ3 (isFalse env e))

-- | Encodes a string as a Z3 bit-vector: every character contributes one 8-bit
--   chunk (built from its 'fromEnum' value) and the sequence is closed by an
--   8-bit zero, so the empty string is just that single zero chunk.
--   NOTE(review): code points above 255 presumably do not fit in 8 bits -- confirm
--   how Z3's int2bv treats them.
stringToBv :: String -> Z3 AST
stringToBv = foldr prepend (byte 0)
  where
    -- An integer numeral converted to an 8-bit bit-vector.
    byte :: Int -> Z3 AST
    byte v = mkIntNum v >>= mkInt2bv 8

    -- Builds the chunk for the current character first, then the remainder,
    -- and concatenates them with the character in front.
    prepend ch rest = do
        chunk <- byte (fromEnum ch)
        suffix <- rest
        mkConcat chunk suffix
                        

-- | Defines the conversion from an expression to a Z3 AST so that Z3 can check satisfiability.
--   This is used to fold expressions generated by the WLP transformer, so not all valid Java expressions need to be handled.
--   Expression forms that are not translated raise an 'error' that names the offending construct.
expAssertAlgebra :: ExpAlgebra (TypeEnv -> Z3 AST)
expAssertAlgebra = (fLit, fClassLit, fThis, fThisClass, fInstanceCreation, fQualInstanceCreation, fArrayCreate, fArrayCreateInit, fFieldAccess, fMethodInv, fArrayAccess, fExpName, fPostIncrement, fPostDecrement, fPreIncrement, fPreDecrement, fPrePlus, fPreMinus, fPreBitCompl, fPreNot, fCast, fBinOp, fInstanceOf, fCond, fAssign, fLambda, fMethodRef) where
    -- Raised for every expression form the verifier does not translate.
    unsupported :: String -> a
    unsupported construct = error ("Verifier: unsupported expression: " ++ construct)

    fLit lit _ = case lit of
        Int n     -> mkInteger n
        Word n    -> mkInteger n
        Float d   -> mkRealNum d
        Double d  -> mkRealNum d
        Boolean b -> mkBool b
        -- Characters are represented by their 'fromEnum' value as an integer.
        Char c    -> do
            sort <- mkIntSort
            mkInt (fromEnum c) sort
        String s  -> stringToBv s
        -- null is represented by the integer 0.
        Null      -> do
            sort <- mkIntSort
            mkInt 0 sort

    fClassLit = unsupported "ClassLit"
    fThis = unsupported "This"
    fThisClass = unsupported "ThisClass"
    fInstanceCreation = unsupported "InstanceCreation"
    fQualInstanceCreation = unsupported "QualInstanceCreation"
    fArrayCreate = unsupported "ArrayCreate"
    fArrayCreateInit = unsupported "ArrayCreateInit"

    -- Field accesses become integer variables named after the pretty-printed field.
    fFieldAccess fieldAccess _ = case fieldAccess of
        PrimaryFieldAccess _ _             -> unsupported "PrimaryFieldAccess"
        SuperFieldAccess ident             -> mkStringSymbol (prettyPrint (Name [ident])) >>= mkIntVar
        ClassFieldAccess (Name name) ident -> mkStringSymbol (prettyPrint (Name (name ++ [ident]))) >>= mkIntVar

    -- Only the pseudo-method "*length" (array expression, literal index n) is translated.
    fMethodInv invocation env = case invocation of
        MethodCall (Name [Ident "*length"]) [a, Lit (Int n)] ->
            let lengthOf arr = MethodInv (MethodCall (Name [Ident "*length"]) [arr, Lit (Int n)])
            in case a of
                ArrayCreate _ dims _ ->
                    let idx = fromEnum n
                        -- Pick the n-th expression of the creation; any index outside
                        -- the list (including a negative one, which would make '!!'
                        -- crash) falls back to the literal 0.
                        dimLen
                            | idx >= 0 && idx < length dims = dims !! idx
                            | otherwise                     = Lit (Int 0)
                    in foldExp expAssertAlgebra dimLen env
                ArrayCreateInit _ _ _ -> mkInteger 0
                -- The length of a named array is an integer variable of its own.
                ExpName name -> do
                    symbol <- mkStringSymbol ("*length(" ++ prettyPrint name ++ ", " ++ show n ++ ")")
                    mkIntVar symbol
                -- Push the length query into both branches of a conditional array.
                Cond g a1 a2 -> foldExp expAssertAlgebra (Cond g (lengthOf a1) (lengthOf a2)) env
                Lit Null -> mkInteger (-1)
                _ -> error ("length of non-array: " ++ prettyPrint a)
        _ -> error (prettyPrint invocation)

    fArrayAccess arrayIndex env = case arrayIndex of
        -- Elements of a freshly created array are translated as the initial value for the type.
        ArrayIndex (ArrayCreate t _ _) _ -> foldExp expAssertAlgebra (getInitValue t) env
        ArrayIndex (ArrayCreateInit t _ _) _ -> foldExp expAssertAlgebra (getInitValue t) env
        -- An element of a named array is a variable named "name[index]"; its sort is
        -- chosen from the array's content type, defaulting to an integer variable.
        ArrayIndex (ExpName name) i -> do
            symbol <- mkStringSymbol (prettyPrint name ++ "[" ++ show i ++ "]")
            case fmap arrayContentType (lookup name env) of
                Just (PrimType BooleanT) -> mkBoolVar symbol
                Just (PrimType FloatT)   -> mkRealVar symbol
                Just (PrimType DoubleT)  -> mkRealVar symbol
                _                        -> mkIntVar symbol
        -- Push the index into both branches of a conditional array.
        ArrayIndex (Cond g a1 a2) i -> foldExp expAssertAlgebra (Cond g (ArrayAccess (ArrayIndex a1 i)) (ArrayAccess (ArrayIndex a2 i))) env
        ArrayIndex e _ -> foldExp expAssertAlgebra e env

    -- A name becomes a Z3 variable whose sort follows the type found in the environment.
    fExpName name env = do
        symbol <- mkStringSymbol (prettyPrint name)
        case lookup name env of
            Just (PrimType BooleanT) -> mkBoolVar symbol
            Just (PrimType FloatT)   -> mkRealVar symbol
            Just (PrimType DoubleT)  -> mkRealVar symbol
            Just (PrimType IntT)     -> mkIntVar symbol
            Just (RefType _)         -> mkIntVar symbol
            -- For now, we assume library methods return ints. Fixing this would require type information of library methods.
            t                        -> if ignoreLibMethods then mkStringSymbol "libMethodCall" >>= mkIntVar else error ("Verifier: Type of " ++ prettyPrint name ++ " unknown or not implemented: " ++ show t)

    fPostIncrement = unsupported "PostIncrement"
    fPostDecrement = unsupported "PostDecrement"
    fPreIncrement = unsupported "PreIncrement"
    fPreDecrement = unsupported "PreDecrement"

    fPrePlus e env = e env

    -- Unary minus is encoded as (0 - e).
    fPreMinus e env = do
        ast <- e env
        zero <- mkInteger 0
        mkSub [zero, ast]

    fPreBitCompl = unsupported "PreBitCompl"

    fPreNot e env = e env >>= mkNot

    fCast = unsupported "Cast"

    -- Both operands are translated first (left, then right); only the final
    -- Z3 constructor depends on the operator.
    fBinOp e1 op e2 env = do
        ast1 <- e1 env
        ast2 <- e2 env
        case op of
            Mult    -> mkMul [ast1, ast2]
            Div     -> mkDiv ast1 ast2
            Rem     -> mkRem ast1 ast2
            Add     -> mkAdd [ast1, ast2]
            Sub     -> mkSub [ast1, ast2]
            -- NOTE(review): the shifts use bit-vector operations, while names and
            -- integer literals above are built as integer terms -- confirm that
            -- shift expressions reaching this point are sort-correct.
            LShift  -> mkBvshl ast1 ast2
            RShift  -> mkBvashr ast1 ast2
            RRShift -> mkBvlshr ast1 ast2
            LThan   -> mkLt ast1 ast2
            GThan   -> mkGt ast1 ast2
            LThanE  -> mkLe ast1 ast2
            GThanE  -> mkGe ast1 ast2
            Equal   -> mkEq ast1 ast2
            NotEq   -> mkEq ast1 ast2 >>= mkNot
            -- The bitwise and the conditional variants share one encoding.
            And     -> mkAnd [ast1, ast2]
            Or      -> mkOr [ast1, ast2]
            Xor     -> mkXor ast1 ast2
            CAnd    -> mkAnd [ast1, ast2]
            COr     -> mkOr [ast1, ast2]

    fInstanceOf = unsupported "InstanceOf"

    -- A conditional expression becomes an if-then-else term.
    fCond g e1 e2 env = do
        astg <- g env
        ast1 <- e1 env
        ast2 <- e2 env
        mkIte astg ast1 ast2

    fAssign = unsupported "Assign"
    fLambda = unsupported "Lambda"
    fMethodRef = unsupported "MethodRef"