module Obs.NormalForm.Normalise

import Data.Bool
import Data.So
import Decidable.Equality

import Obs.Logging
import Obs.NormalForm
import Obs.Sort
import Obs.Substitution

import Text.PrettyPrint.Prettyprinter

%default total

-- Aliases ---------------------------------------------------------------------

||| A constructor over context `ctx`, produced inside the logging monad.
public export 0 LogConstructor : Type -> Unsorted.Family Bool
LogConstructor ann ctx = Logging ann (Constructor ctx)

||| A normal form of sort index `b` over context `ctx`, produced inside the
||| logging monad.
public export 0 LogNormalForm : Type -> Sorted.Family Bool
LogNormalForm ann b ctx = Logging ann (NormalForm b ctx)

||| The codomain of a substitution: either a (logged) term to put in place of
||| a variable (`Left`), or just a new position for that variable (`Right`).
0 LogNormalForm' : Type -> Sorted.Family Bool
LogNormalForm' ann b ctx = Either (Logging ann (NormalForm b ctx)) (Elem b ctx)

-- Copied and specialised from Obs.Substitution

||| Push a substitution under the binders listed in `ctx`: each newly bound
||| variable is mapped to itself, and every entry of `f` is shifted past it
||| (terms are renamed with `There`, variable positions are wrapped in `There`).
lift : (ctx : List (b ** (String, (s ** isSet s = b)))) ->
  Map ctx' (LogNormalForm' ann) ctx'' ->
  Map (map DPair.fst ctx ++ ctx') (LogNormalForm' ann) (map DPair.fst ctx ++ ctx'')
lift [] f = f
lift ((s ** y) :: ctx) f =
  add (LogNormalForm' ann)
    (Left $ pure $ point y Here)
    (\i => bimap (\t => pure (rename !t There)) There (lift ctx f i))

-- Normalisation ---------------------------------------------------------------

||| Apply a substitution to a normal form, re-normalising wherever a
||| substituted term lands in an elimination position. Defined at the end of
||| this module; declared here so the evaluators below can call it.
subst' : NormalForm ~|> Hom (LogNormalForm' ann) (LogNormalForm ann)

||| Apply a function to an argument.
|||
||| A neutral head gives a stuck `App`; a lambda is beta-reduced by
||| substituting the argument for its bound variable (after checking the
||| argument's sort index matches the binder's); an irrelevant function stays
||| irrelevant. Any other constructor is reported with `fatal`.
export
doApp : {b' : _} -> NormalForm b ctx -> NormalForm b' ctx -> LogNormalForm ann b ctx
doApp (Ntrl t) u = pure (Ntrl $ App _ t u)
doApp (Cnstr (Lambda s var t)) u = inScope "doApp" $ do
  trace $ pretty {ann} "substituting" <++> pretty u <+> softline <+>
          pretty "for \{var} in" <++> pretty t
  let Yes Refl = decEq b' (isSet s)
    | No _ => fatal "internal sort mismatch"
  subst' t (add (LogNormalForm' ann) (Left $ pure u) Right)
doApp (Cnstr t) u = inScope "wrong constructor for apply" $ fatal t
doApp Irrel u = pure Irrel

||| First projection of a pair. The result has sort index `b`; when `b` is
||| `False` the result is `Irrel` without inspecting the pair.
export
doFst : (b, b' : _) -> NormalForm (b || b') ctx -> LogNormalForm ann b ctx
doFst False b' t = pure Irrel
doFst True b' (Ntrl t) = pure (Ntrl $ Fst b' t)
doFst True b' (Cnstr (Pair (Set _) s' prf t u)) = pure t
doFst True b' (Cnstr t) = inScope "wrong constructor for fst" $ fatal t

||| Second projection of a pair. The result has sort index `b'`; when `b'` is
||| `False` the result is `Irrel` without inspecting the pair.
export
doSnd : (b, b' : _) -> NormalForm (b || b') ctx -> LogNormalForm ann b' ctx
doSnd b False t = pure Irrel
doSnd b True t =
  let t' : NormalForm True ctx
      t' = rewrite sym $ orTrueTrue b in t
  in case t' of
    Ntrl t => pure (Ntrl $ Snd b t)
    Cnstr (Pair _ (Set _) prf t u) => pure u
    Cnstr t => inScope "wrong constructor for snd" $ fatal t

||| Boolean elimination: `if t then u else v`. Reduces on the literals `True`
||| and `False`, gets stuck on a neutral scrutinee, and is `Irrel` when the
||| branches are irrelevant.
export
doIf : {b : _} ->
  NormalForm True ctx ->
  NormalForm b ctx ->
  NormalForm b ctx ->
  LogNormalForm ann b ctx
doIf {b = False} t u v = pure Irrel
doIf {b = True} (Ntrl t) u v = pure (Ntrl $ If t u v)
doIf {b = True} (Cnstr True) u v = pure u
doIf {b = True} (Cnstr False) u v = pure v
doIf {b = True} (Cnstr t) u v = inScope "wrong constructor for case" $ fatal t

||| The eliminator for an absurd proof at sort index `b`.
export
doAbsurd : (b : _) -> NormalForm b ctx
doAbsurd False = Irrel
doAbsurd True = Ntrl Absurd

||| Cast `t` from type `a` to type `b`. Irrelevant terms cast to `Irrel`; a
||| neutral source type gives a stuck `CastL`; otherwise dispatch on the source
||| type's constructor.
export
doCast : (b' : _) -> (a, b : NormalForm True ctx) -> NormalForm b' ctx -> LogNormalForm ann b' ctx

||| Cast a relevant `t` from the constructor type `a` to the type `b`.
|||
||| * Pi: build a lambda that casts its argument back from `a'` to `a`,
|||   applies `t`, then casts the result from `b[x]` to `b'[y]`.
||| * Sigma: cast each component, substituting the already-cast first
|||   component into the target family for the second.
||| * Mismatched (or unhandled) constructors give a stuck `CastStuck`.
doCastR : (a : Constructor ctx) -> (b : NormalForm True ctx) -> NormalForm True ctx -> LogNormalForm ann True ctx
doCastR a (Ntrl b) t = pure (Ntrl $ CastR a b t)
doCastR (Sort _) (Cnstr (Sort _)) t = pure t
doCastR ty@(Pi s s'@(Set _) var a b) (Cnstr ty'@(Pi l l' _ a' b')) t =
  -- `y` is the fresh bound variable of the result lambda; it lives in `a'`.
  let y : NormalForm (isSet s) (isSet s :: ctx)
      y = point (var, (s ** Refl)) Here
  in do
    let Yes Refl = decEq (s, s') (l, l')
      | No _ => pure (Ntrl $ CastStuck ty ty' t)
    x <- assert_total (doCast (isSet s) (weaken [isSet s] a') (weaken [isSet s] a) y)
    b <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure x) (Right . There)))
    b' <- assert_total (subst' b' (add (LogNormalForm' ann) (Left $ pure y) (Right . There)))
    t <- assert_total (doApp (Sorted.weaken [isSet s] t) x)
    t <- assert_total (doCast True b b' t)
    pure (Cnstr $ Lambda s var t)
doCastR ty@(Sigma s@(Set k) s' var a b) (Cnstr ty'@(Sigma l l' _ a' b')) t = do
  let Yes Refl = decEq (s, s') (l, l')
    | No _ => pure (Ntrl $ CastStuck ty ty' t)
  t1 <- doFst True (isSet s') t
  -- Cast the first projection `t1` (which inhabits `a`), not the whole pair.
  u1 <- assert_total (doCast True a a' t1)
  b <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure t1) Right))
  b' <- assert_total (subst' b' (add (LogNormalForm' ann) (Left $ pure u1) Right))
  t2 <- doSnd True (isSet s') t
  u2 <- assert_total (doCast (isSet s') b b' t2)
  pure (Cnstr $ Pair (Set k) s' Oh u1 u2)
doCastR ty@(Sigma Prop s'@(Set k) var a b) (Cnstr ty'@(Sigma Prop l' _ a' b')) t = do
  let Yes Refl = decEq s' l'
    | No _ => pure (Ntrl $ CastStuck ty ty' t)
  -- The first component is irrelevant, so both families are instantiated
  -- at `Irrel`.
  b <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure Irrel) Right))
  b' <- assert_total (subst' b' (add (LogNormalForm' ann) (Left $ pure Irrel) Right))
  t2 <- doSnd False True t
  u2 <- assert_total (doCast True b b' t2)
  pure (Cnstr $ Pair Prop (Set k) Oh Irrel u2)
doCastR Bool (Cnstr Bool) t = pure t
doCastR a (Cnstr b) t = pure (Ntrl $ CastStuck a b t)

doCast False a b t = pure Irrel
doCast True (Ntrl a) b t = pure (Ntrl $ CastL a b t)
doCast True (Cnstr a) b t = doCastR a b t

||| The proposition `t = u` where `t` and `u` have type `a`, computed by
||| inspecting `a`. Equality of irrelevant terms is `Top`.
export
doEqual : (b : _) ->
  (a : NormalForm True ctx) ->
  NormalForm b ctx ->
  NormalForm b ctx ->
  LogNormalForm ann True ctx

-- Relies heavily on typing
||| Equality of two types, where the left-hand one is a constructor.
|||
||| * Pi:    (a' = a) /\ (forall x : a'. b[cast a' a x] = b'[x])
||| * Sigma: (a = a') /\ (forall x : a.  b[x] = b'[cast a a' x])
||| * Identical base types are equal (`Top`); any other pair of constructors
|||   gives a stuck `EqualStuck`.
doEqualR : (a : Constructor ctx) -> (b : NormalForm True ctx) -> LogNormalForm ann True ctx
doEqualR a (Ntrl b) = pure (Ntrl $ EqualR a b)
doEqualR (Sort _) (Cnstr (Sort s)) = pure (Cnstr Top)
doEqualR ty@(Pi s s' var a b) (Cnstr ty'@(Pi l l' _ a' b')) =
  -- `u` is the bound variable of the resulting Pi; it inhabits `a'`.
  let u : NormalForm (isSet s) (isSet s :: ctx)
      u = point (var, (s ** Refl)) Here
  in do
    let Yes Refl = decEq (s, s') (l, l')
      | No _ => pure (Ntrl $ EqualStuck ty ty')
    eq1 <- assert_total (doEqual True (cast s) a' a)
    t <- doCast (isSet s) (weaken [isSet s] a') (weaken [isSet s] a) u
    b <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure t) (Right . There)))
    b' <- assert_total (subst' b' (add (LogNormalForm' ann) (Left $ pure u) (Right . There)))
    eq2 <- assert_total (doEqual True (cast s') b b')
    -- Quantify over `a'`: `u` is cast *from* `a'` above and fed to `b'`.
    pure (Cnstr $ Sigma Prop Prop "_" eq1 (Cnstr $ Unsorted.weaken [False] $ Pi s Prop var a' eq2))
doEqualR ty@(Sigma s s' var a b) (Cnstr ty'@(Sigma l l' _ a' b')) =
  -- `t` is the bound variable of the resulting Pi; it inhabits `a`.
  let t : NormalForm (isSet s) (isSet s :: ctx)
      t = point (var, (s ** Refl)) Here
  in do
    let Yes Refl = decEq (s, s') (l, l')
      | No _ => pure (Ntrl $ EqualStuck ty ty')
    eq1 <- assert_total (doEqual True (cast s) a a')
    u <- doCast (isSet s) (weaken [isSet s] a) (weaken [isSet s] a') t
    b <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure t) (Right . There)))
    b' <- assert_total (subst' b' (add (LogNormalForm' ann) (Left $ pure u) (Right . There)))
    eq2 <- assert_total (doEqual True (cast s') b b')
    pure (Cnstr $ Sigma Prop Prop "_" eq1 (Cnstr $ Unsorted.weaken [False] $ Pi s Prop var a eq2))
doEqualR Bool (Cnstr Bool) = pure (Cnstr Top)
doEqualR Top (Cnstr Top) = pure (Cnstr Top)
doEqualR Bottom (Cnstr Bottom) = pure (Cnstr Top)
doEqualR a (Cnstr b) = pure (Ntrl $ EqualStuck a b)

||| Equality of two types (terms of a `Set` universe).
export
doEqualSet : (a, b : NormalForm True ctx) -> LogNormalForm ann True ctx
doEqualSet (Ntrl a) b = pure (Ntrl $ EqualL a b)
doEqualSet (Cnstr a) b = doEqualR a b

doEqual False a t u = pure (Cnstr Top)
doEqual True (Ntrl a) t u = pure (Ntrl $ Equal a t u)
doEqual True (Cnstr (Sort Prop)) t u =
  -- Two propositions are equal when each implies the other.
  pure (Cnstr $ Sigma Prop Prop ""
    (Cnstr $ Pi Prop Prop "" t (Sorted.weaken [False] u))
    (Cnstr $ Unsorted.weaken [False] $ Pi Prop Prop "" u (Sorted.weaken [False] t)))
doEqual True (Cnstr (Sort (Set _))) t u = doEqualSet t u
doEqual True (Cnstr (Pi s (Set k) var a b)) t u =
  -- Functions are equal when they agree on a fresh variable `x`.
  let x : NormalForm (isSet s) (isSet s :: ctx)
      x = point (var, (s ** Refl)) Here
  in do
    t <- assert_total (doApp (weaken [isSet s] t) x)
    u <- assert_total (doApp (weaken [isSet s] u) x)
    eq <- doEqual True b t u
    pure (Cnstr $ Pi s Prop var a eq)
doEqual True (Cnstr (Sigma s@(Set _) s' var a b)) t u = do
  t1 <- doFst True (isSet s') t
  u1 <- doFst True (isSet s') u
  t2 <- doSnd True (isSet s') t
  u2 <- doSnd True (isSet s') u
  eq1 <- doEqual True a t1 u1
  bt1 <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure t1) Right))
  bu1 <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure u1) Right))
  -- Transport the second component of `t` into the fibre over `u1`.
  t2' <- doCast (isSet s') bt1 bu1 t2
  eq2 <- doEqual (isSet s') (assert_smaller b bu1) t2' u2
  pure (Cnstr $ Sigma Prop Prop "_" eq1 (Sorted.weaken [False] eq2))
doEqual True (Cnstr (Sigma Prop (Set k) var a b)) t u = do
  t2 <- doSnd False True t
  u2 <- doSnd False True u
  bt1 <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure $ Irrel) Right))
  bu1 <- assert_total (subst' b (add (LogNormalForm' ann) (Left $ pure $ Irrel) Right))
  t2' <- doCast True bt1 bu1 t2
  eq2 <- doEqual True (assert_smaller b bu1) t2' u2
  pure (Cnstr $ Sigma Prop Prop "_" (Cnstr Top) (Sorted.weaken [False] eq2))
doEqual True (Cnstr Bool) t u = do
  true <- doIf u (Cnstr Top) (Cnstr Bottom)
  false <- doIf u (Cnstr Bottom) (Cnstr Top)
  doIf t true false
doEqual True (Cnstr a) t u = inScope "wrong constructor for equal" $ fatal a

||| Substitute into a constructor, lifting the substitution under binders.
substCnstr : Constructor ~|> Hom (LogNormalForm' ann) (LogConstructor ann)
substCnstr (Sort s) f = pure (Sort s)
substCnstr (Pi s s' var a b) f = do
  a <- subst' a f
  b <- subst' b (lift [(_ ** (var, (s ** Refl)))] f)
  pure (Pi s s' var a b)
substCnstr (Lambda s var t) f = do
  t <- subst' t (lift [(_ ** (var, (s ** Refl)))] f)
  pure (Lambda s var t)
substCnstr (Sigma s s' var a b) f = do
  a <- subst' a f
  b <- subst' b (lift [(_ ** (var, (s ** Refl)))] f)
  pure (Sigma s s' var a b)
substCnstr (Pair s s' prf t u) f = do
  t <- subst' t f
  u <- subst' u f
  pure (Pair s s' prf t u)
substCnstr Bool f = pure Bool
substCnstr True f = pure True
substCnstr False f = pure False
substCnstr Top f = pure Top
substCnstr Bottom f = pure Bottom

||| Substitute into a neutral term and re-run the corresponding evaluator, so
||| that a substituted head can unblock a stuck elimination.
substNtrl : Neutral ~|> Hom (LogNormalForm' ann) (LogNormalForm ann True)
substNtrl (Var var sort prf i) f = case f i of
  Left t => t
  Right j => pure (Ntrl $ Var var sort prf j)
substNtrl (App b t u) f = do
  t <- substNtrl t f
  u <- subst' u f
  assert_total (doApp t u)
substNtrl (Fst b t) f = do
  t <- substNtrl t f
  doFst True b t
substNtrl (Snd b t) f = do
  t <- substNtrl t f
  doSnd b True $ rewrite orTrueTrue b in t
substNtrl (If t u v) f = do
  t <- substNtrl t f
  u <- subst' u f
  v <- subst' v f
  doIf t u v
substNtrl Absurd f = pure (doAbsurd True)
substNtrl (Equal a t u) f = do
  a <- substNtrl a f
  t <- subst' t f
  u <- subst' u f
  doEqual _ a t u
substNtrl (EqualL a b) f = do
  a <- substNtrl a f
  b <- subst' b f
  doEqualSet a b
substNtrl (EqualR a b) f = do
  a <- substCnstr a f
  b <- substNtrl b f
  doEqualR a b
substNtrl (EqualStuck a b) f = do
  a <- substCnstr a f
  b <- substCnstr b f
  pure (Ntrl $ EqualStuck a b)
substNtrl (CastL a b t) f = do
  a <- substNtrl a f
  b <- subst' b f
  t <- subst' t f
  doCast _ a b t
substNtrl (CastR a b t) f = do
  a <- substCnstr a f
  b <- substNtrl b f
  t <- subst' t f
  doCastR a b t
substNtrl (CastStuck a b t) f = do
  a <- substCnstr a f
  b <- substCnstr b f
  t <- subst' t f
  pure (Ntrl $ CastStuck a b t)

subst' (Ntrl t) f = substNtrl t f
subst' (Cnstr t) f = pure $ Cnstr !(substCnstr t f)
subst' Irrel f = pure Irrel

||| Substitute a (logged) term for every variable of the context.
export
subst : NormalForm ~|> Hom (LogNormalForm ann) (LogNormalForm ann)
subst t f = subst' t (Left . f)

-- Utilities -------------------------------------------------------------------

||| Substitute `t` for the innermost bound variable of `u`; all other
||| variables are kept in place.
export
subst1 : {s' : _} -> NormalForm s ctx -> NormalForm s' (s :: ctx) -> LogNormalForm ann s' ctx
subst1 t u = subst' u (add (LogNormalForm' ann) (Left $ pure t) Right)