{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE UndecidableInstances   #-}

------------------------------------------------------------------------------
-- | This module provides new implementations for '(>>=)', '(>>)', 'pure' and
-- 'return' so that they will work simultaneously with both regular and indexed
-- monads.
--
-- Intended usage:
--
-- @
--   {-# LANGUAGE RebindableSyntax #-}
--
--   import Language.Haskell.DoNotation
--   import Prelude hiding (Monad (..), pure)
-- @
module Language.Haskell.DoNotation
  ( BindSyntax (..)
  , PureSyntax (..)
  , P.Monad ()
  , IxMonad ()
  ) where

import           Control.Monad.Indexed
import           Data.Coerce
import           Data.Kind (Constraint)
import           Data.Kind (Type)
import qualified Prelude as P
import           Prelude hiding (Monad (..), pure)
import           System.IO (IOMode (..))


------------------------------------------------------------------------------
-- | Typeclass that provides 'pure' and 'return'.
--
-- Each method's default is defined in terms of the other, so an instance
-- must implement at least one of them; the @MINIMAL@ pragma makes GHC
-- enforce that instead of letting an empty instance loop at runtime.
class PureSyntax (x :: Type -> Type) where
  {-# MINIMAL pure | return #-}

  -- | Lift a value into @x@.
  pure :: a -> x a
  pure = return

  -- | Synonym for 'pure'.
  return :: a -> x a
  return = pure


-- | Any 'Applicative' gets 'pure' (and, through the class default,
-- 'return') from the Prelude's 'P.pure'.
--
-- Marked @INCOHERENT@ because its head, a bare @m@, overlaps every other
-- 'PureSyntax' instance, including the one for indexed monads @m i j@.
instance {-# INCOHERENT #-}
      Applicative m => PureSyntax m where
  pure = P.pure

-- | Indexed monads get 'pure' from 'ireturn', whose result has the shape
-- @m i i a@ (both indices equal).
--
-- The head is @m i j@ with @j ~ i@ in the context, so the instance matches
-- any pair of indices and the equality then forces them to coincide.
instance (IxMonad m, j ~ i) => PureSyntax (m i j) where
  pure = ireturn


------------------------------------------------------------------------------
-- | Typeclass that provides '(>>=)' and '(>>)'.
--
-- @x@ is the type of the action being bound, @y@ the type of the
-- continuation, and @z@ the type of the combined result. The functional
-- dependencies state that any two of them determine the third.
class BindSyntax (x :: Type -> Type)
                 (y :: Type -> Type)
                 (z :: Type -> Type)
      | x y -> z
      , x z -> y
      , y z -> x where
  -- Same fixities as the Prelude operators these replace; without this
  -- declaration the methods would default to infixl 9 and, e.g.,
  -- @m >>= f . g@ would fail to parse.
  infixl 1 >>=, >>

  -- | Sequentially compose two actions, passing the result of the first to
  -- the second.
  (>>=) :: x a -> (a -> y b) -> z b

  -- | Sequentially compose two actions, discarding the result of the first.
  (>>) :: x a -> y b -> z b
  a >> b = a >>= const b

-- | Every ordinary 'P.Monad' binds through the Prelude's '(P.>>=)'.
--
-- The head is @BindSyntax m x m@ with @x ~ m@ in the context (rather than
-- @BindSyntax m m m@), so the instance matches whenever the first and third
-- arguments agree, and the equality then unifies the continuation's type
-- with @m@.
instance  (P.Monad m, x ~ m) => BindSyntax m x m where
  (>>=) = (P.>>=)

-- | Indexed monads bind through '(>>>=)', threading the indices: an
-- @m i j@ action followed by an @m j k@ continuation yields an @m i k@
-- result.
--
-- The head is fully general (@x y z@) and its structure is recovered only
-- through the equality constraints, so it overlaps the 'P.Monad' instance;
-- it is marked @INCOHERENT@ for that reason.
instance {-# INCOHERENT #-}
      ( IxMonad m
      , x ~ m i j
      , y ~ m j k
      , z ~ m i k
      ) => BindSyntax x y z where
  (>>=) = (>>>=)