-- Copyright (c) 2017 Utrecht University
-- Author: Koen Wermer

-- Helper functions for the Java data structure
module Javawlp.Engine.HelperFunctions where

import Javawlp.Engine.Folds 

import Language.Java.Syntax
import Language.Java.Pretty
import Data.Maybe
import Data.List
import System.IO.Unsafe
import Data.IORef
import Debug.Trace
import Text.Read (readMaybe)


-- | Type environment: an association list from names to their types.
type TypeEnv    = [(Name, Type)]
-- | Call counts: an association list from a method identifier to a counter.
type CallCount  = [(Ident, Int)]

-- | Renders a type environment with one 'show'n @(name, type)@ pair per line.
prettyprintTypeEnv :: TypeEnv -> String
prettyprintTypeEnv env = intercalate "\n" [ show binding | binding <- env ]

-- | Retrieves the type of a (possibly qualified) name from the environment.
--
-- * A name whose first identifier starts with @'$'@ is a generated return
--   variable of a function: its type comes from 'getReturnVarType' and the
--   remaining identifiers are resolved as field accesses.
-- * A name whose first identifier starts with @'#'@ is a variable introduced
--   by handling operators: the result is @PrimType undefined@, so the
--   primitive type inside must never be inspected.
-- * Otherwise the first identifier is looked up in the environment and the
--   remaining identifiers are resolved as field accesses. Names that are not
--   in the environment are assumed to be library variables of type int.
-- * The empty name also gets the int fallback; previously it could crash on
--   'head' of an empty list.
lookupType :: [TypeDecl] -> TypeEnv -> Name -> Type
lookupType decls _ (Name (Ident s@('$':_) : idents)) = getFieldType decls (getReturnVarType decls s) (Name idents)
lookupType _ _ (Name (Ident ('#':_) : _)) = PrimType undefined
lookupType decls env (Name (ident : idents)) =
    case lookup (Name [ident]) env of
        Just t  -> getFieldType decls t (Name idents)
        Nothing -> PrimType IntT -- For now we assume library variables to be ints
lookupType _ _ (Name []) = PrimType IntT -- Nothing to look up; same fallback as an unknown name
                                        
-- | Resolves a chain of field names against a starting type.
--
-- An empty chain yields the starting type itself; a chain consisting of just
-- @length@ yields int. For a class type, the first field is looked up in the
-- matching class declaration and the rest of the chain is resolved against
-- that field's type.
getFieldType :: [TypeDecl] -> Type -> Name -> Type
getFieldType _ t (Name []) = t
getFieldType _ _ (Name [Ident "length"]) = PrimType IntT
getFieldType typeDecls (RefType (ClassRefType classType)) (Name (field:rest)) =
    getFieldType typeDecls (fieldTypeOf (classDeclFor classType) field) (Name rest)
    where
        -- Type of the first field declaration in the class body that declares the wanted variable
        fieldTypeOf :: ClassDecl -> Ident -> Type
        fieldTypeOf (ClassDecl _ _ _ _ _ (ClassBody members)) wanted =
            case [ t | MemberDecl (FieldDecl _ t vars) <- members, VarDecl varId _ <- vars, getId varId == wanted ] of
                t : _ -> t
                []    -> error "getFieldTypeFromMemberDecls"

        -- First class declaration whose identifier matches the (single-component) class type
        classDeclFor :: ClassType -> ClassDecl
        classDeclFor t@(ClassType [(ident, _)]) =
            case [ decl | ClassTypeDecl decl@(ClassDecl _ ident' _ _ _ _) <- typeDecls, ident == ident' ] of
                decl : _ -> decl
                []       -> error ("fieldType: " ++ show t)
        classDeclFor t = error ("fieldType: " ++ show t)
        
-- | Gets the type of the class in which the method is defined.
--
-- The first class (in declaration order) that declares a method with the
-- given identifier wins. Constructor identifiers, which 'getMethod'' encodes
-- as the class name prefixed with @'#'@, resolve to the class that declares
-- the constructor.
--
-- Raises an error naming the method when no class defines it (previously
-- this surfaced as an anonymous 'head' failure).
getMethodClassType :: [TypeDecl] -> Ident -> Type
getMethodClassType decls id =
    case concatMap classTypesDefining decls of
        t : _ -> t
        []    -> error ("getMethodClassType: no class defines method " ++ show id)
    where
        -- The class type, if this declaration is a class that defines the method
        classTypesDefining :: TypeDecl -> [Type]
        classTypesDefining (ClassTypeDecl (ClassDecl _ className _ _ _ (ClassBody members)))
            | any definesMethod members = [RefType (ClassRefType (ClassType [(className, [])]))]
        classTypesDefining _ = []

        definesMethod :: Decl -> Bool
        definesMethod (MemberDecl (MethodDecl _ _ _ id' _ _ _ _))        = id' == id
        definesMethod (MemberDecl (ConstructorDecl _ _ className _ _ _)) = constructorId className == id
        definesMethod _                                                  = False

        -- Same encoding as used by getMethod': '#' marks a constructor id
        constructorId :: Ident -> Ident
        constructorId (Ident s) = Ident ('#' : s)
        
-- | Extends a type environment with the special variables for a method.
--
-- @*obj@ is always added, typed with the class defining the method. When the
-- method has a return type, @returnValue@ and @returnValueVar@ are put in
-- front of it: typed 'returnValueType' for reference return types, and with
-- the return type itself otherwise.
extendEnv :: TypeEnv -> [TypeDecl] -> Ident -> TypeEnv
extendEnv env decls methodId = returnEntries ++ (objEntry : env)
    where
        objEntry = (Name [Ident "*obj"], getMethodClassType decls methodId)

        returnEntries = case getMethodType decls methodId of
            Nothing          -> []
            Just (RefType _) -> returnVars returnValueType
            Just t           -> returnVars t

        returnVars t = [(Name [Ident "returnValue"], t), (Name [Ident "returnValueVar"], t)]

        
-- | The special class type, named @ReturnValueType@, that stands for a return value.
returnValueType :: Type
returnValueType = RefType (ClassRefType placeholderClass)
    where
        placeholderClass = ClassType [(Ident "ReturnValueType", [])]

-- | Gets the type of a generated return variable.
--
-- The method name is read from the variable name by dropping its first
-- character and taking everything up to the next @'$'@. When the method is
-- unknown (or has no return type) the result is @PrimType undefined@: kind of
-- a hack, since for library functions it doesn't matter what type we return.
getReturnVarType :: [TypeDecl] -> String -> Type
getReturnVarType decls s = fromMaybe (PrimType undefined) (getMethodType decls methodId)
    where
        methodId = Ident (takeWhile (/= '$') (tail s))
        
-- Increments the call count of a method; an unseen method is appended with count 1
incrCallCount :: CallCount -> Ident -> CallCount
incrCallCount counts methodId =
    case break ((methodId ==) . fst) counts of
        (before, (found, c) : after) -> before ++ (found, c + 1) : after
        (before, [])                 -> before ++ [(methodId, 1)]


-- Looks up the call count of a method; methods without an entry count as 0
getCallCount :: CallCount -> Ident -> Int
getCallCount counts methodId = fromMaybe 0 (lookup methodId counts)
        
-- Gets the identifier of a variable declarator, looking through array declarators
getId :: VarDeclId -> Ident
getId declId = case declId of
    VarId ident        -> ident
    VarDeclArray inner -> getId inner

-- Unwraps a name into its list of identifiers
fromName :: Name -> [Ident]
fromName name = case name of
    Name idents -> idents

-- Gets the ident of the method from a name: the last identifier of the name
getMethodId :: Name -> Ident
getMethodId (Name idents) = last idents

-- Gets the statement(-block) defining a method, if the method is found
getMethod :: [TypeDecl] -> Ident -> Maybe Stmt
getMethod classTypeDecls methodId = fmap bodyOf (getMethod' classTypeDecls methodId)
    where
        bodyOf (body, _, _) = body

-- Gets the return type of a method; Nothing when the method is not found or
-- when the found method has no return type
getMethodType :: [TypeDecl] -> Ident -> Maybe Type
getMethodType classTypeDecls methodId =
    case getMethod' classTypeDecls methodId of
        Just (_, returnType, _) -> returnType
        Nothing                 -> Nothing

-- Gets the parameter declarations of a method, if the method is found
getMethodParams :: [TypeDecl] -> Ident -> Maybe [FormalParam]
getMethodParams classTypeDecls methodId = fmap paramsOf (getMethod' classTypeDecls methodId)
    where
        paramsOf (_, _, params) = params

-- Finds a method definition: its body, return type and parameters.
-- This function assumes all methods are named differently; the first match in
-- declaration order is returned. Nothing means no definition was found
-- (library function call). Methods without a body never match.
getMethod' :: [TypeDecl] -> Ident -> Maybe (Stmt, Maybe Type, [FormalParam])
getMethod' classTypeDecls methodId = listToMaybe
    [ found
    | ClassTypeDecl (ClassDecl _ _ _ _ _ (ClassBody members)) <- classTypeDecls
    , found <- mapMaybe matchDecl members
    ]
  where
    matchDecl (MemberDecl (MethodDecl _ _ t id params _ _ (MethodBody (Just b))))
        | methodId == id = Just (StmtBlock b, t, params)
    -- A constructor is typed with its own class
    matchDecl (MemberDecl (ConstructorDecl _ _ id params _ (ConstructorBody _ b)))
        | methodId == toConstrId id = Just (StmtBlock (Block b), Just (RefType (ClassRefType (ClassType [(id, [])]))), params)
    matchDecl _ = Nothing
    -- Adds a '#' to indicate the id refers to a constructor method
    toConstrId (Ident s) = Ident ('#' : s)

-- Gets the statement(-block) defining the method named "main"; errors with
-- "getMainMethod" when there is none
getMainMethod :: [TypeDecl] -> Stmt
getMainMethod classTypeDecls = fromJust' "getMainMethod" mainBody
    where
        mainBody = getMethod classTypeDecls (Ident "main")

-- Gets a list of all method Idents (except constructor methods), in declaration order
getMethodIds :: [TypeDecl] -> [Ident]
getMethodIds classTypeDecls =
    [ methodId
    | ClassTypeDecl (ClassDecl _ _ _ _ _ (ClassBody members)) <- classTypeDecls
    , MemberDecl (MethodDecl _ _ _ methodId _ _ _ _) <- members
    ]

-- Gets the type declarations of a compilation unit
getDecls :: CompilationUnit -> [TypeDecl]
getDecls unit = case unit of
    CompilationUnit _ _ typeDecls -> typeDecls
    
-- Checks if the var is introduced. Introduced variable names start with '$' for return variables of methods and '#' for other variables
isIntroducedVar :: Name -> Bool
isIntroducedVar name = case name of
    Name (Ident (marker : _) : _) -> marker == '#' || marker == '$'
    _                             -> False

-- Gets the variable that represents the return value of an invocation of a method.
-- The name is the method name, then "___retval", then the number obtained
-- from getIncrVarMethodInvokesCount for that method.
getReturnVar :: MethodInvocation -> Ident
getReturnVar invocation =
    let methodid@(Ident name) = invocationToId invocation
        invocationNr          = getIncrVarMethodInvokesCount methodid
    in Ident (name ++  "___retval" ++ show invocationNr)

-- Gets the method Id from an invocation.
-- Only plain method calls and method calls on a primary expression are
-- supported; any other invocation form raises an error that shows the
-- offending invocation (instead of a bare 'undefined').
invocationToId :: MethodInvocation -> Ident
invocationToId (MethodCall name _) = getMethodId name
invocationToId (PrimaryMethodCall _ _ id _) = id
invocationToId invocation = error ("invocationToId: unsupported method invocation: " ++ show invocation)

-- Gets the type of the elements of an array. Nested array types are unwrapped
-- all the way down (multi-dimensional arrays); a non-array type is returned as is
arrayContentType :: Type -> Type
arrayContentType ty = case ty of
    RefType (ArrayType elemTy) -> arrayContentType elemTy
    _                          -> ty

-- Gets a new unique variable: an identifier made of '#' followed by the value
-- returned by getIncrVarPointer.
-- NOTE(review): this is a top-level value without arguments, so it is
-- presumably evaluated at most once and every use may then see the same
-- identifier -- confirm whether callers rely on a fresh name per use.
getVar :: Ident
getVar = Ident ('#' : show (getIncrVarPointer ()))
    
-- Gets multiple new unique variables: n identifiers, each made of '#' followed
-- by a value taken from getIncrVarPointer.
-- A count of zero or less yields the empty list (a negative count used to
-- recurse forever because it never reached the 0 base case).
getVars :: Int -> [Ident]
getVars n
    | n <= 0    = []
    | otherwise = Ident ('#' : show (getIncrVarPointer ())) : getVars (n-1)

-- The number of new variables introduced ; also used to assign new variable names.
-- This is a global mutable cell created with unsafePerformIO; the NOINLINE
-- pragma keeps the compiler from inlining the definition, which could
-- otherwise create more than one cell.
varPointer :: IORef Int
varPointer = unsafePerformIO $ newIORef 0
{-# NOINLINE varPointer #-}

-- Resets the variable counter to 0 and returns the value read back from it
resetVarPointer :: IO Int
resetVarPointer = writeIORef varPointer 0 >> readIORef varPointer

-- | Gets the current var-pointer and increases the pointer by 1.
-- Don't drop the dummy () argument; this is to force a re-evaluation of the expression. Otherwise we will
-- get the same integer every time.
--
-- The read and the increment happen in one 'atomicModifyIORef'' step, and the
-- function is marked NOINLINE as recommended for functions that call
-- 'unsafePerformIO'.
getIncrVarPointer :: () -> Int
getIncrVarPointer () = unsafePerformIO $ atomicModifyIORef' varPointer (\p -> (p + 1, p))
{-# NOINLINE getIncrVarPointer #-}
    
-- | To keep track of the number of times each method is invoked; also used to assign unique return-value
-- name to each method invocation.
-- This is a global mutable cell created with unsafePerformIO; the NOINLINE
-- pragma keeps the compiler from inlining the definition, which could
-- otherwise create more than one cell.
varMethodInvokesCount :: IORef CallCount
varMethodInvokesCount =  unsafePerformIO $ newIORef []
{-# NOINLINE varMethodInvokesCount #-}
    
-- Clears the per-method invocation counts and returns the value read back afterwards
resetVarMethodInvokesCount :: IO CallCount
resetVarMethodInvokesCount = writeIORef varMethodInvokesCount [] >> readIORef varMethodInvokesCount

-- | Increments the invocation count of the given method in
-- 'varMethodInvokesCount' and returns the count after the increment (so the
-- first call for a method returns 1).
--
-- The update happens in one 'atomicModifyIORef'' step, and the function is
-- marked NOINLINE as recommended for functions that call 'unsafePerformIO'.
getIncrVarMethodInvokesCount :: Ident -> Int
getIncrVarMethodInvokesCount methodid = unsafePerformIO $
    atomicModifyIORef' varMethodInvokesCount $ \callcount ->
        let callcount' = incrCallCount callcount methodid
        in (callcount', getCallCount callcount' methodid)
{-# NOINLINE getIncrVarMethodInvokesCount #-}
    
-- Used for debugging: like fromJust, but fails with the given message on Nothing
fromJust' :: String -> Maybe a -> a
fromJust' s ma = fromMaybe (error s) ma
        
-- The boolean literal expression for True
true :: Exp
true = Lit $ Boolean True

-- The boolean literal expression for False
false :: Exp
false = Lit $ Boolean False
    
    
-- Logical operators for expressions:
-- Conditional-and of two expressions
(&*) :: Exp -> Exp -> Exp
(&*) lhs rhs = BinOp lhs CAnd rhs

-- Conditional-or of two expressions
(|*) :: Exp -> Exp -> Exp
(|*) lhs rhs = BinOp lhs COr rhs

-- Logical negation of an expression
neg :: Exp -> Exp
neg e = PreNot e

-- Implication, encoded as (not premise) or conclusion
imp :: Exp -> Exp -> Exp
imp premise conclusion = BinOp (PreNot premise) COr conclusion

-- Equality of two expressions
(==*) :: Exp -> Exp -> Exp
(==*) lhs rhs = BinOp lhs Equal rhs

-- Inequality, encoded as the negation of equality
(/=*) :: Exp -> Exp -> Exp
(/=*) lhs rhs = PreNot (BinOp lhs Equal rhs)

-- Gets the value from an array. Indexing an array-creation expression yields
-- the initial value of the created array's type, whatever the indices are;
-- any other array expression becomes an ordinary array access
arrayAccess :: Exp -> [Exp] -> Exp
arrayAccess a i
    | ArrayCreate t _ _     <- a = getInitValue t
    | ArrayCreateInit t _ _ <- a = getInitValue t
    | otherwise                  = ArrayAccess (ArrayIndex a i)

-- Accesses fields of fields: builds nested field accesses on the expression,
-- one per identifier of the name, from left to right
fieldAccess :: Exp -> Name -> FieldAccess
fieldAccess e (Name idents) = case idents of
    [field]      -> PrimaryFieldAccess e field
    field : rest -> fieldAccess (FieldAccess (PrimaryFieldAccess e field)) (Name rest)
    []           -> error "FieldAccess without field name"
                    
-- | Gets the initial value for a given type: false, zero or the NUL character
-- for primitive types, and null for every reference type.
--
-- NOTE(review): byte maps to a 'Word' literal and long to an 'Int' literal;
-- in language-java 'Word' looks like the constructor for long literals, so
-- these two may be swapped -- confirm against how the literals are consumed.
getInitValue :: Type -> Exp
getInitValue ty = case ty of
    RefType _         -> Lit Null
    PrimType BooleanT -> Lit (Boolean False)
    PrimType ByteT    -> Lit (Word 0)
    PrimType ShortT   -> Lit (Int 0)
    PrimType IntT     -> Lit (Int 0)
    PrimType LongT    -> Lit (Int 0)
    PrimType CharT    -> Lit (Char '\NUL')
    PrimType FloatT   -> Lit (Float 0)
    PrimType DoubleT  -> Lit (Double 0)

-- Counts expression complexity, which is the number of logical operators in
-- the expression. A binary operator that is not logical counts as 1 and its
-- operands are not inspected; a conditional counts as 1 plus its three parts.
exprComplexity :: Exp -> Int
exprComplexity (PreNot inner) = exprComplexity inner + 1
exprComplexity (BinOp lhs op rhs)
    | op `elem` [And, Or, Xor, CAnd, COr] = exprComplexity lhs + exprComplexity rhs + 1
    | otherwise                           = 1
exprComplexity (Cond g e1 e2) = exprComplexity e1 + exprComplexity e2 + exprComplexity g + 1
-- other cases are 0
exprComplexity _ = 0
    


-- Finds the start position (counting from 0) of the first occurrence of the
-- first list inside the second one. The empty list is found at position 0.
findInfix :: (Eq a, Num b, Ord b, Enum b) => [a] -> [a] -> Maybe b
findInfix needle haystack = fmap fst (find (isPrefixOf needle . snd) positionedSuffixes)
    where
    -- every suffix of the haystack, paired with the position where it starts
    positionedSuffixes = zip [0..] (tails haystack)
    
-- Reads the invocation number from a return-value variable, i.e. the number
-- that follows "___retval" in a single-identifier name expression.
-- Returns -1 for every other expression, for names without the marker, and
-- for names where the text after the marker is not a number (this last case
-- used to crash in 'read').
getCallCount__ :: (Read a, Num a) => Exp -> a
getCallCount__ (ExpName (Name [Ident s])) =
    case findInfix key s of
        Just pos -> fromMaybe (-1) (readMaybe (drop (pos + length key) s))
        Nothing  -> -1
    where
    key = "___retval"
getCallCount__ _ = -1

-- Replaces the invocation number of a return-value variable: keeps the part
-- of the name before "___retval" and appends the marker and the new count.
-- Only defined for single-identifier name expressions that contain the marker.
adjustCallCount__ :: Show a => Exp -> a -> Exp
adjustCallCount__ (ExpName (Name [Ident s])) newCount = ExpName (Name [Ident renamed])
    where
    marker     = "___retval"
    methodPart = take (fromJust (findInfix marker s)) s
    renamed    = methodPart ++ marker ++ show newCount

-- Gets the method ident from a return-value variable: the part of the name
-- before "___retval". Only defined for single-identifier name expressions
-- that contain the marker.
getMethodId__ :: Exp -> Ident
getMethodId__ (ExpName (Name [Ident s])) = Ident (take markerPos s)
    where
    markerPos = fromJust (findInfix "___retval" s)

-- | Renumbers the return-value variables (@<method>___retval<k>@) in an expression.
--
-- For every method, the distinct invocation numbers occurring in the
-- expression are counted; each variable's number @k@ is then replaced by
-- @k `mod` m@, where @m@ is that count for the variable's method.
--
-- Only name expressions reachable through unary operators, casts, binary
-- operators, conditionals and @instanceof@ are visited; every other
-- expression form is left unchanged. Lambda expressions are not supported and
-- raise an error.
normalizeInvocationNumbers :: Exp -> Exp
normalizeInvocationNumbers e = normalize e
   where
   -- for every method, the invocation numbers of its return-value variables in e
   calls_ :: [(Ident,[Int])]
   calls_ = countCallVars [] e

   -- for every method, the number of distinct invocation numbers in e
   callscount :: CallCount
   callscount = [ (i,length (nub instances)) | (i,instances) <- calls_ ]

   -- renumbers a single return-value variable; other names are kept
   normalizeName ename@(ExpName (Name [_])) = if k>=0 then adjustCallCount__ ename (k `mod` m) else ename
        where
        k = getCallCount__ ename
        -- m is at least 1 whenever k >= 0, because the variable was counted in calls_
        m = getCallCount callscount (getMethodId__ ename)

   normalizeName other = other


   normalize expr = case expr of
       -- base case
       ename@(ExpName _) -> normalizeName ename

       -- recursions
       PrePlus x -> PrePlus (normalize x)
       PreMinus x -> PreMinus (normalize x)
       PreBitCompl x -> PreBitCompl (normalize x)
       PreNot x -> PreNot (normalize x)
       Cast t x -> Cast t (normalize x)
       BinOp e1 op e2 -> BinOp (normalize e1) op (normalize e2)
       Cond g e1 e2 -> Cond (normalize g) (normalize e1) (normalize e2)
       InstanceOf x refType -> InstanceOf (normalize x) refType

       -- not supported:
       Lambda _ _ -> error "normalizeInvocationNumbers: lambda expressions are not supported"

       -- unchanged:
       other -> other


   -- records invocation number k for the method, adding the method if it is new
   registerInvocation (methodid,k) calls = worker calls
      where
      worker [] =  [ (methodid, [k]) ]
      worker ((mx,z):therest)
          | methodid == mx = (mx, k:z) : therest
          | otherwise      = (mx,z) : worker therest


   -- collects the invocation numbers of all return-value variables in an expression
   countCallVars calls expr = case expr of
       -- base case
       ename@(ExpName _) ->
          let
          k = getCallCount__ ename
          methodid = getMethodId__ ename
          in
          if k >= 0 then registerInvocation (methodid,k) calls
                    else calls

       -- recursions
       PrePlus x      -> countCallVars calls x
       PreMinus x     -> countCallVars calls x
       PreBitCompl x  -> countCallVars calls x
       PreNot x       -> countCallVars calls x
       Cast _ x       -> countCallVars calls x
       BinOp e1 _ e2  -> countCallVars (countCallVars calls e1) e2
       Cond g e1 e2   -> countCallVars (countCallVars (countCallVars calls g) e1) e2
       InstanceOf x _ -> countCallVars calls x

       -- not supported:
       Lambda _ _ -> error "normalizeInvocationNumbers: lambda expressions are not supported"

       -- ignored (should not contain method calls):
       _ -> calls