Skip to content
Snippets Groups Projects
SaltMarsh.hs 11.9 KiB
Newer Older
{-# LANGUAGE BangPatterns              #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE GADTs                     #-}
{-# LANGUAGE StandaloneDeriving        #-}
{-# LANGUAGE TypeFamilies              #-}
{-# LANGUAGE TypeOperators             #-}
{-# LANGUAGE ViewPatterns              #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

{-
 - Salt marsh simulation, code ported to Haskell/Accelerate from Johan van der
 - Koppel's Python & OpenCL implementation
 -}

module SaltMarsh where

-- Standard library
import           Codec.Picture
import           Prelude                                  as P
import           System.Console.AsciiProgress
import           System.Environment
import           Text.Printf
import           Text.Read                                ( readMaybe )

-- Accelerate
import           Data.Array.Accelerate
import           Data.Array.Accelerate.LLVM.Native
import           Data.Array.Accelerate.System.Random.MWC
import qualified Data.Array.Accelerate          as A

-- Module
import           SaltMarsh.Constants
import           SaltMarsh.Derived

-- | Model state of a single grid cell, stored as a flat 5-tuple in the order
-- @(u, v, h, s, d)@ (the same order the projections 'u', 'v', 'h', 's', 'd'
-- unpack). u and v are combined into the flow magnitude shown in the
-- "waterflow" images, and s is what the "sediment" images show.
-- NOTE(review): h is presumably the water layer and d the (vegetation)
-- density with logistic growth in 'simulateUV' — confirm with the model docs.
type DataPoint = (Float, Float, Float, Float, Float)

-- Precomputation: spatially uniform values used by 'main' to initialise
-- u, h and s in every cell.

-- | Uniform water depth: simply the constant h0.
h_homo = h0

-- | Uniform flow speed from the balance between downslope acceleration and
-- friction: sqrt slope * h_homo^(2/3) / nn (a Manning-type relation).
u_homo = (sqrt slope * h_homo ** (2.0 / 3.0)) / nn

-- | Uniform s value, written as a fraction whose final divisor is e0.
s_homo = unscaled / e0
  where
    -- depth above the critical depth hCrit
    excess   = h0 - hCrit
    unscaled = s_in * u_homo * u_homo * excess / (qs + h0 - hCrit)

-- | Diffusion coefficient for the s-flux between two neighbouring cells whose
-- d values are @d1@ and @d2@, evaluated in grid column @column@.
--
-- The result is @d0x * exp (-dm * pDx / kk)@ where
--
--   * @dm@  is the mean of @d1@ and @d2@,
--   * @d0x@ is @d0@, or (if 'gradient_D0') a log-linear profile from @d0min@
--     at column 0 to @d0max@ at column 'grid_Width',
--   * @pDx@ is @pD@, or (if 'gradient_pD') a log-linear profile from @pDmin@
--     to @pDmax@ in the same way.
difco :: Exp Float -> Exp Float -> Exp Float -> Exp Float
difco d1 d2 column = d0x * ds
  where
    dm = (d1 + d2) / 2

    -- Log-linear interpolation across the grid: lo at column 0, hi at column
    -- grid_Width. This is the same form used for the other column gradients
    -- (hv, hIn, difD) in 'simulateUV'.
    alongColumn :: Exp Float -> Exp Float -> Exp Float
    alongColumn lo hi = exp ((log hi - log lo) * column / grid_Width + log lo)

    -- NOTE(review): the previous version additionally divided the exponent by
    -- (grid_Width + log pDmin), which does not interpolate between pDmin and
    -- pDmax; confirm against the OpenCL kernel.
    pDx :: Exp Float
    pDx | gradient_pD = alongColumn pDmin pDmax
        | otherwise   = pD

    d0x :: Exp Float
    d0x | gradient_D0 = alongColumn d0min d0max
        | otherwise   = d0

    -- damping factor, decreasing with the mean d of the two cells
    ds :: Exp Float
    ds = exp (-dm * pDx / kk)

-- | Component projections out of an 'Exp' 'DataPoint'. They are named after
-- the corresponding arrays of the OpenCL code; the tuple order is
-- @(u, v, h, s, d)@.
u, v, s, d, h :: Exp DataPoint -> Exp Float
u point = case unliftPoint point of (x, _, _, _, _) -> x
v point = case unliftPoint point of (_, x, _, _, _) -> x
h point = case unliftPoint point of (_, _, x, _, _) -> x
s point = case unliftPoint point of (_, _, _, x, _) -> x
d point = case unliftPoint point of (_, _, _, _, x) -> x

-- | Split an embedded 'DataPoint' into its five scalar expressions; the
-- explicit result type is what selects the right 'unlift' instance.
unliftPoint :: Exp DataPoint
            -> (Exp Float, Exp Float, Exp Float, Exp Float, Exp Float)
unliftPoint = unlift


-- | One explicit Euler time step for a single grid cell.
--
-- The first stencil holds the static pair @(column, b)@ of each cell (see
-- 'b'); the second holds the model state @(u, v, h, s, d)@. Only the centre
-- cell and its four direct neighbours are read — the corner entries of both
-- 3x3 stencils are ignored. Neighbour suffixes: T = top (row above),
-- B = bottom (row below), L = left, R = right, C = centre.
--
-- Unlike the OpenCL code, 'b' is not returned, because the step never changes
-- it. The column index is bundled with 'b' only because 'stencil2' (and no
-- @stencil3@) is predefined; it is needed for the optional column gradients of
-- hv, hIn and difD here (and pD, D0 inside 'difco').
simulateUV :: Stencil3x3 (Float, Float) -> Stencil3x3 DataPoint -> Exp DataPoint
simulateUV ((_  , cbT, _  ),
            (cbL, cbC, cbR),
            (_  , cbB, _  ))
           ((_ , pT, _ ),
            (pL, pC, pR),
            (_ , pB, _ ))
  = lift (uNext, vNext, hNext, sNext, dNext)
  where
    -- Static data: this cell's column and the neighbours' b values.
    column = A.fst cbC
    bT     = A.snd cbT
    bL     = A.snd cbL
    bR     = A.snd cbR
    bB     = A.snd cbB

    -- Log-linear profile across the grid: lo at column 0, hi at grid_Width.
    alongColumn :: Exp Float -> Exp Float -> Exp Float
    alongColumn lo hi = exp ((log hi - log lo) * column / grid_Width + log lo)

    -- Parameters that optionally vary with the column.
    hvx, drh, difDx :: Exp Float
    hvx   = if gradient_Hv   then alongColumn hvmin   hvmax   else hv
    drh   = if gradient_Hin  then alongColumn hInMin  hInMax  else hIn
    difDx = if gradient_DifD then alongColumn difDmin difDmax else difD

    -- Central differences and the 5-point Laplacian of a state component.
    ddx, ddy, laplacian :: (Exp DataPoint -> Exp Float) -> Exp Float
    ddx f       = (f pR - f pL) / 2.0 / dX
    ddy f       = (f pB - f pT) / 2.0 / dY
    laplacian f = (f pL + f pR - 2.0 * f pC) / dX / dX
                + (f pT + f pB - 2.0 * f pC) / dY / dY

    -- h clipped from below at hCrit, and the part of it above hCrit.
    hC   = A.max (h pC) hCrit
    hEff = hC - hCrit

    speed = sqrt (u pC * u pC + v pC * v pC)

    ct :: Exp Float
    ct = sqrt (1.0 / (1.0 / (cb * cb) + 1.0 / (2.0 * g) * cd * d pC * hvx))
       + sqrt g / kv * log (A.max (hC / hvx) 1.0)

    -- Shared factor g / ct^2 of the friction and erosion terms.
    gct2 = g / (ct * ct)

    -- Gradient of the level h + s + b.
    level p bNeighbour = h p + s p + bNeighbour
    dO_dx = (level pR bR - level pL bL) / 2.0 / dX
    dO_dy = (level pB bB - level pT bT) / 2.0 / dY

    -- Momentum equations for u and v.
    uRate = negate (g * dO_dx)
          - u pC * ddx u
          - v pC * ddy u
          - gct2 * speed * u pC / hC
          + difU * laplacian u
    vRate = negate (g * dO_dy)
          - u pC * ddx v
          - v pC * ddy v
          - gct2 * speed * v pC / hC
          + difU * laplacian v

    -- Divergence of the (u*h, v*h) flux, with the sign flipped.
    hRate = negate ((u pR * h pR - u pL * h pL) / 2.0 / dX)
          - (v pB * h pB - v pT * h pT) / 2.0 / dY

    -- Diffusion of s with d-dependent coefficients on each cell face.
    sDiffusion =
        difco (d pR) (d pC) column / dX / dX * (s pR - s pC)
      - difco (d pC) (d pL) column / dX / dX * (s pC - s pL)
      + difco (d pB) (d pC) column / dY / dY * (s pB - s pC)
      - difco (d pC) (d pT) column / dY / dY * (s pC - s pT)

    sRate = s_in * hEff / (qs + hEff)
          - e0 * (1 - pE * d pC / kk) * s pC * speed * speed * gct2
          + sDiffusion

    -- Logistic growth, removal proportional to speed^2 * g/ct^2, diffusion.
    dRate = rr * d pC * (1 - d pC / kk) * qq / (qq + hEff)
          - ec * d pC * speed * speed * gct2
          + difDx * laplacian d

    phi = 1.0

    -- Forward Euler updates.
    uNext = u pC + uRate * dT
    vNext = v pC + vRate * dT
    hNext = hC + dT * (hRate + drh)
    sNext = s pC + dT * sRate * phi
    dNext = d pC + dT * dRate * phi







-- | Advance the whole grid by one time step.
--
-- 'simulateUV' is applied as a 3x3 stencil over the static array 'b' and the
-- state (both with clamped borders). Afterwards 'fixBoundaries' overwrites the
-- outermost rows and columns. Adjusting the boundaries after the step is
-- still not exactly the same as handling them properly inside the stencil.
simulationStep :: Acc (Array DIM2 DataPoint) -> Acc (Array DIM2 DataPoint)
simulationStep arr = fixBoundaries $ stencil2 simulateUV clamp b clamp arr
  where
    -- Apply the boundary rules to edge cells and copy interior cells unchanged.
    -- Where two edges meet, the first matching rule wins: top, bottom, left,
    -- right.
    -- NOTE(review): assumes the grid is at least 3x3, otherwise the indices
    -- below fall outside the array.
    fixBoundaries :: Acc (Array DIM2 DataPoint) -> Acc (Array DIM2 DataPoint)
    fixBoundaries grid = generate (shape grid) boundaryFn
      where
        (uarr, varr, harr, sarr, darr) = A.unzip5 grid

        boundaryFn :: Exp DIM2 -> Exp DataPoint
        boundaryFn i =
          -- top row: u, v linearly extrapolated from rows 1 and 2 of the same
          -- column; h, d copied from row 1; s set to 0
          cond (r A.== 0)
            (lift ( 2 * uarr A.! row1 - uarr A.! row2 :: Exp Float
                  , 2 * varr A.! row1 - varr A.! row2 :: Exp Float
                  , harr A.! neumannTop               :: Exp Float
                  , 0                                 :: Exp Float
                  , darr A.! neumannTop               :: Exp Float
                  )
            ) $
          -- bottom row: copied from the row above, with v reflected
          cond (r A.== grid_Height - 1)
            (lift (   uarr A.! neumannBot
                  , - varr A.! neumannBot
                  ,   harr A.! neumannBot
                  ,   sarr A.! neumannBot
                  ,   darr A.! neumannBot
                  )
            ) $
          -- left column: copied from column 1, with u reflected
          cond (c A.== 0)
            (lift ( - uarr A.! neumannLeft
                  ,   varr A.! neumannLeft
                  ,   harr A.! neumannLeft
                  ,   sarr A.! neumannLeft
                  ,   darr A.! neumannLeft
                  )
            ) $
          -- right column: copied from the column to its left, with u reflected
          cond (c A.== grid_Width - 1)
            (lift ( - uarr A.! neumannRight
                  ,   varr A.! neumannRight
                  ,   harr A.! neumannRight
                  ,   sarr A.! neumannRight
                  ,   darr A.! neumannRight
                  )
            ) $
          -- interior: unchanged
          lift ( uarr A.! i
               , varr A.! i
               , harr A.! i
               , sarr A.! i
               , darr A.! i
               )
          where
            (_ :. r :. c) = unlift i :: (Exp Z :. Exp Int :. Exp Int)

            -- Extrapolation sources for the top row: same column c, one and
            -- two rows into the grid. (Previously these were pinned to
            -- column 0 for every column.)
            row1, row2 :: Exp DIM2
            row1 = lift (Z :. (r + 1) :. c)
            row2 = lift (Z :. (r + 2) :. c)

            -- Nearest interior cell of each edge.
            neumannTop, neumannBot, neumannLeft, neumannRight :: Exp DIM2
            neumannTop   = lift (Z :. (1               :: Exp Int) :. c)
            neumannBot   = lift (Z :. (grid_Height - 2 :: Exp Int) :. c)
            neumannLeft  = lift (Z :. r :. (1              :: Exp Int))
            neumannRight = lift (Z :. r :. (grid_Width - 2 :: Exp Int))

-- | Static per-cell data @(column, b)@ on a grid_Height x grid_Width grid.
-- The first component is the column index as a Float. The second falls
-- linearly from @slope * lengthX@ at column 0 towards 0 at column grid_Width
-- and does not depend on the row.
b :: Acc (Array DIM2 (Float, Float))
b = generate extent cellInfo
  where
    extent :: Exp DIM2
    extent = lift (Z :. (grid_Height :: Int) :. (grid_Width :: Int))

    cellInfo :: Exp DIM2 -> Exp (Float, Float)
    cellInfo ix =
      let (_ :. _ :. col) = unlift ix :: (Exp Z :. Exp Int :. Exp Int)
          x               = A.fromIntegral col :: Exp Float
          bed             = (1.0 - x / grid_Width) * slope * lengthX :: Exp Float
      in  lift (x, bed)


-- | Repeatedly apply 'simulationStep' to the grid while the scalar counter is
-- positive, decrementing the counter once per step. Starting from
-- @(n, grid)@ with @n > 0@, this performs n steps and returns the counter at
-- 0. With @n <= 0@, the input is returned unchanged.
whileCalcFrame :: Acc (Array DIM0 Int, Array DIM2 DataPoint) -> Acc (Array DIM0 Int, Array DIM2 DataPoint)
whileCalcFrame = awhile stepsRemain stepOnce
  where
    stepsRemain :: Acc (Array DIM0 Int, Array DIM2 DataPoint) -> Acc (Scalar Bool)
    stepsRemain st = unit (A.the (afst st) A.> 0)

    stepOnce :: Acc (Array DIM0 Int, Array DIM2 DataPoint) -> Acc (Array DIM0 Int, Array DIM2 DataPoint)
    stepOnce st =
      let counter = afst st
          grid    = asnd st
      in  lift (A.map (\n -> n - 1) counter, simulationStep grid)



-- | Entry point. Initialises the grid, then computes numFrames frames of
-- stepsPerFrame simulation steps each. After every frame it writes a
-- sediment and a water-flow heat map PNG into @output/@ and advances a
-- console progress bar.
main :: IO ()
main = do
  -- Command-line handling of the step count is currently disabled:
  -- args <- getArgs
  -- let steps = if P.null args
  --               then 10
  --               else case (readMaybe $ head args)  of
  --                      Just s   ->  s
  --                      _        -> 10

  -- set up initial conditions (random 0/1 d field, uniform u, h, s)
  ds     <- randomArray rand (Z :. grid_Height :. grid_Width)
  let dp0 = runN dataPoints ds
      go  = runN whileCalcFrame

      -- Compute and write frames i .. numFrames - 1.
      loop :: ProgressBar -> Int -> Matrix DataPoint -> IO ()
      loop !pg !i !dp
        | i P.>= numFrames  = complete pg -- strictly speaking should not be necessary
        | otherwise         = do
            let -- compute next frame (stepsPerFrame simulation steps)
                (_, !dp') = go (fromList Z [stepsPerFrame], dp)

            -- let fl = P.map (\(_,_,_,s,_) -> s) $ A.toList dp'
            -- writeFile "saltMarsh.txt" $ show fl

            -- or, to write png:
            writePng (printf "output/sediment-%04d.png"  i) $ heatMap grid_Width grid_Height sedimentSelector  dp'
            writePng (printf "output/waterflow-%04d.png" i) $ heatMap grid_Width grid_Height waterflowSelector dp'

            tick pg
            loop pg (i+1) dp'

  -- run simulation
  displayConsoleRegions $ do
    printf "Running saltMarsh simulation with %d frames of %d steps each\n" (numFrames::Int) stepsPerFrame
    pg    <- newProgressBar def { pgTotal        = numFrames
                                , pgOnCompletion = Just "Complete after :elapsed seconds"
                                }
    loop pg 0 dp0

  where
    stepsPerFrame :: Int
    stepsPerFrame = P.floor $ endTime / (dT  *  numFrames)

    -- Initial state: uniform u, h, s, zero v, and the given d values.
    dataPoints :: Acc (Array DIM2 Float) -> Acc (Array DIM2 DataPoint)
    dataPoints = A.map f
      where
        f :: Exp Float -> Exp DataPoint
        f x = lift ( u_homo      :: Exp Float
                   , 0.0         :: Exp Float
                   , h_homo      :: Exp Float
                   , s_homo      :: Exp Float
                   , x
                   )

    sedimentSelector  = \(_,_,_,s,_) -> s
    waterflowSelector = \(u,v,_,_,_) -> sqrt (u*u + v*v)

    -- Render one scalar field of the grid as a width x height image, where
    -- pixel (x, y) shows grid cell (row y, column x). Values are scaled by 1133
    -- and clamped to [0, 510], so the blue -> green -> red ramp covers roughly
    -- 0 .. 0.45. The image size now follows the arguments; it used to be fixed
    -- at 512x512, which indexed out of bounds for any other grid size.
    heatMap :: Int -> Int -> (DataPoint -> Float) -> Array DIM2 DataPoint -> Image PixelRGB8
    heatMap width height sel values = generateImage heatPixel width height
       where
         heatPixel px py = colour (level py px)

         -- Clamped at 0 too, so negative values no longer wrap around in Word8.
         level :: Int -> Int -> Int
         level row column =
           P.max 0 (P.min 510 (P.floor (sel (indexArray values (Z :. row :. column)) * 1133)))

         colour lvl = PixelRGB8
                        (P.fromIntegral $ P.max 0 (lvl - 255))
                        (P.fromIntegral $ P.min lvl 255)
                        (P.fromIntegral $ 255 - P.min 255 lvl)

    -- Random initial d field: 1 with probability 0.02, otherwise 0.
    rand :: DIM2 :~> Float
    rand ix gen = do
      v     <- uniformR (0,1 :: Float) ix gen
      return $ if v P.< 0.02 then 1 else 0