-- Automatic calculator from Bird (1998), added by Zipheir on Wed Oct 17 03:01:20 2018.

-- Prover/calculator from Richard Bird, _Introduction to Functional
-- Programming Using Haskell_ (Prentice Hall, 1998), ch. 12.
module Calculator where

import Data.List
import Control.Applicative
import Parsers

-- Variable names are single characters; keeping them to one letter is
-- what lets the parser tell a variable from a constant.
type VarName = Char

-- Constant (function) names are arbitrary strings.
type ConName = String

-- An expression is a variable, a named constant applied to a list of
-- argument expressions, or a composition of expressions (written with ".").
data Expr
  = Var VarName
  | Con ConName [Expr]
  | Compose [Expr]
  deriving Eq

-- Expr has two datatype invariants which are maintained by compose.
--   (i)  The expression Compose xs is valid only if the length of
--        xs is at least two.
--   (ii) No expression of the form Compose xs contains an element
--        that is itself of the form Compose ys.
--
-- compose builds a composition from a list of expressions. It collapses
-- a one-element list to that element (i) and splices the factors of any
-- nested composition into the result (ii).
--
-- The one-element case is matched directly instead of calling the local
-- 'singleton'. On GHC >= 9.0, Data.List (imported whole by this module)
-- also exports a 'singleton', so an unqualified use is ambiguous. The
-- match is also O(1), where the length test was O(n).
-- NOTE(review): compose [] still yields Compose [], which violates (i);
-- the callers visible in this module never pass an empty list.
compose :: [Expr] -> Expr
compose [x] = x
compose xs  = Compose (concatMap decompose xs)

-- True exactly when the list has one element.
-- Pattern matching inspects at most two cons cells, so this is O(1) and
-- terminates on infinite lists. A length-based test is O(n) and diverges
-- on an infinite list.
singleton :: [a] -> Bool
singleton [_] = True
singleton _   = False

-- The factors of an expression: the elements of a composition, or a
-- one-element list holding the expression itself. compose concatenates
-- these to flatten nested compositions.
decompose :: Expr -> [Expr]
decompose (Compose xs) = xs
decompose other        = [other]

-- A crude size measure: the number of factors in a composition, and 1
-- for anything else. basicLaw compares the two sides of a law with it.
complexity :: Expr -> Int
complexity (Compose xs) = length xs
complexity _            = 1

-- Render an expression in the calculator's concrete syntax.
--   * A variable prints as its letter.
--   * A constant with no arguments prints as its name.
--   * A constant whose only argument is a variable or nullary constant
--     prints without parentheses (e.g. "map f").
--   * Any other constant prints as name(arg1, arg2, ...).
--   * A composition prints its factors joined by ".".
printExpr :: Expr -> String
printExpr expression = case expression of
  Var v -> [v]
  Con f args
    | null args                  -> f
    | [a] <- args, simpleton a   -> f ++ " " ++ printExpr a
    | otherwise                  ->
        f ++ "(" ++ intercalate ", " (map printExpr args) ++ ")"
  Compose factors -> intercalate "." (map printExpr factors)

-- An argument list is "simple" when it is exactly one simpleton. printExpr
-- then writes the argument without parentheses (e.g. "map f").
-- The one-element case is matched directly. On GHC >= 9.0 an unqualified
-- call to the local 'singleton' is ambiguous with Data.List.singleton,
-- which this module imports.
simple :: [Expr] -> Bool
simple [x] = simpleton x
simple _   = False

-- An expression that can stand as an unparenthesised argument: a
-- variable, or a constant applied to no arguments.
simpleton :: Expr -> Bool
simpleton e = case e of
  Var _      -> True
  Con _ args -> null args
  Compose _  -> False


-- 12.2.1 Parsing expressions

-- Parse a string as an expression using the expr parser. How leftover or
-- malformed input is treated is up to applyParser in the Parsers module.
parseExpr :: String -> Expr
parseExpr input = applyParser expr input

-- An expression is a sequence of terms separated by "." (as accepted by
-- somewith). They are combined with compose, so the result is already
-- flattened.
expr :: Parser Expr
expr = compose <$> somewith (symbol ".") term

-- A term starts with a letter, possibly after leading space. A lone
-- letter is a variable. A letter followed by alphanumerics is a constant,
-- whose arguments are then read by argument.
term :: Parser Expr
term = do
  space
  first <- letter
  rest <- many alphanum
  case rest of
    [] -> return (Var first)
    _  -> Con (first : rest) <$> argument

-- The arguments of a constant. Alternatives are tried in order:
--   1. a parenthesised tuple;
--   2. a single bare variable or nullary constant;
--   3. no arguments at all.
-- foldr keeps the same right-nested grouping:
--   tuple <|> (notuple <|> return []).
argument :: Parser [Expr]
argument = foldr (<|>) (return []) [tuple, notuple]

-- A parenthesised, comma-separated list of expressions; only the list is
-- returned.
tuple :: Parser [Expr]
tuple = symbol "(" *> somewith (symbol ",") expr <* symbol ")"

-- A single unparenthesised argument. A lone letter becomes a variable; a
-- longer alphanumeric name becomes a constant applied to nothing.
notuple :: Parser [Expr]
notuple = do
  space
  first <- letter
  rest <- many alphanum
  return [if null rest then Var first else Con (first : rest) []]

-- Parse a string of the form "lhs = rhs" into its two sides.
parseEqn :: String -> (Expr, Expr)
parseEqn text = applyParser eqn text

-- An equation: optional leading space, an expression, "=", and a second
-- expression. The two sides are returned as a pair.
eqn :: Parser (Expr, Expr)
eqn = space *> ((,) <$> expr <* symbol "=" <*> expr)


-- 12.2.2  Laws

-- A law's name is shown as the justification of each step that uses it.
type LawName = String

-- A law is a named equation: (name, left-hand side, right-hand side).
type Law = (LawName, Expr, Expr)

-- A law is basic when its right-hand side has strictly lower complexity
-- than its left-hand side, so rewriting left-to-right simplifies. proveEqn
-- hands the basic laws to calculate ahead of the rest.
basicLaw :: Law -> Bool
basicLaw (_, before, after) = complexity after < complexity before

-- Parse a string of the form "name: lhs = rhs" into a Law.
parseLaw :: String -> Law
parseLaw text = applyParser law text

-- A law is written "name: lhs = rhs". The name is everything up to the
-- first ':' after any leading space, so names cannot contain ':'.
law :: Parser Law
law = do
  space
  name <- some (sat (/= ':'))
  _ <- symbol ":"
  (lhs, rhs) <- eqn
  pure (name, lhs, rhs)


-- 12.2.3  Calculations

-- A step records the name of the law that was applied and the expression
-- it produced.
type Step = (LawName, Expr)

-- A calculation is a starting expression followed by a series of steps.
type Calculation = (Expr, [Step])

-- The final expression of a calculation: the result of its last step, or
-- the starting expression when there are no steps.
conclusion :: Calculation -> Expr
conclusion (start, steps) = last (start : map snd steps)

-- Join a calculation of an equation's left-hand side with one of its
-- right-hand side. The left calculation runs forwards. Next comes a
-- failure marker if the two conclusions differ (see link). Last comes
-- the right calculation reversed, so the whole ends at the right-hand
-- side's starting expression.
paste :: Calculation -> Calculation -> Calculation
paste lhc rhc = (start, lsteps ++ bridge ++ shuffle rhc)
  where
    (start, lsteps) = lhc
    bridge          = link (conclusion lhc) (conclusion rhc)

-- Bridge the conclusions of the two halves of a proof. The result is
-- empty when they agree; otherwise it is one failure step leading to y.
link :: Expr -> Expr -> [Step]
link x y = if x == y then [] else [("... ??? ...", y)]

-- shuffle reverses the steps of a calculation, taking
--
-- (x, [(r₁, y₁), …, (rₙ, yₙ)]) to [(rₙ, yₙ₋₁), …, (r₂, y₁), (r₁, x)].
--
-- Recall that x is the starting expression. Each law name is paired with
-- the expression that preceded it. The reversed steps therefore read as a
-- calculation running from the conclusion back to x.
--
-- The fold carries the most recently seen expression together with the
-- reversed steps built so far. foldl' (rather than foldl) forces that
-- accumulator to weak head normal form at each step, so no chain of
-- unevaluated thunks builds up. The lambda's variable is named prev so
-- that it no longer shadows the outer x.
shuffle :: Calculation -> [Step]
shuffle (x, ss) = snd (foldl' shunt (x, []) ss)
  where shunt (prev, rs) (r, y) = (y, (r, prev) : rs)

-- Render a calculation: a blank line, the starting expression indented on
-- its own line, then each step in turn.
printCalc :: Calculation -> String
printCalc (x, ss) = concat ("\n  " : printExpr x : "\n" : map printStep ss)

-- Render one step: the justification line "=  {law}", then the resulting
-- expression indented on the next line.
printStep :: Step -> String
printStep (why, x) = concat ["=  {", why, "}\n", "  ", printExpr x, "\n"]


-- 12.3.1  Substitutions

-- A substitution is a list of bindings from variable names to
-- expressions; the empty list is the identity substitution. If a variable
-- were bound more than once, the earliest binding would be used.
type Subst = [(VarName, Expr)]

-- Look up a variable in a substitution. An unbound variable maps to itself.
binding :: Subst -> VarName -> Expr
binding s v = maybe (Var v) id (lookup v s)

-- Apply a substitution throughout an expression. Compositions are rebuilt
-- with compose, because a variable may be replaced by a composition that
-- has to be flattened into its parent.
applySub :: Subst -> Expr -> Expr
applySub s = go
  where
    go (Var v)      = binding s v
    go (Con f xs)   = Con f (map go xs)
    go (Compose xs) = compose (map go xs)

-- Try to add the binding v := x to a substitution. The result is a
-- singleton list holding the (possibly unchanged) substitution when the
-- binding is compatible, and [] when v is already bound to something else.
extend :: Subst -> (VarName, Expr) -> [Subst]
extend s (v, x) = case binding s v of
  y | y == x     -> [s]            -- already bound to x (or x is Var v)
    | y == Var v -> [(v, x) : s]   -- v is unbound: record the binding
    | otherwise  -> []             -- conflicting binding


-- 12.3.2  Matching

-- Match a pattern (first component) against an expression (second),
-- starting from the identity substitution. The result lists every
-- substitution that makes them equal.
match :: (Expr, Expr) -> [Subst]
match pair = xmatch [] pair


-- xmatch s (p, e) extends the substitution s in every way that makes
-- pattern p equal to expression e. It returns all resulting
-- substitutions, or [] when there is no match.
--   * A variable matches anything consistent with its binding in s.
--   * Two constants match when their names and numbers of arguments agree
--     and their arguments match pairwise. The arity check matters:
--     without it zip silently drops the extra arguments, so the pattern
--     foo(x) would match foo(a, b), and a nullary constant would match any
--     application of the same name. Both give unsound rewrites.
--   * A composition of n patterns matches a composition of at least n
--     factors by splitting the factors into n contiguous groups (align).
--   * Every other pairing of constructors fails.
xmatch :: Subst -> (Expr, Expr) -> [Subst]
xmatch s (Var v, x) = extend s (v, x)
xmatch s (Con f xs, Con g ys)
  | f == g && length xs == length ys = xmatchlist s (zip xs ys)
xmatch s (Compose xs, Compose ys) = concatMap (xmatchlist s) (align xs ys)
xmatch _ _ = []


-- Every way of pairing each pattern in xs with a contiguous, non-empty
-- group of the expressions ys. Each group is rebuilt as one expression
-- with compose.
align :: [Expr] -> [Expr] -> [[(Expr, Expr)]]
align xs ys = map (zip xs . map compose) (parts (length xs) ys)

-- parts n xs lists every way of splitting xs into exactly n non-empty,
-- contiguous groups, in order. For example
--   parts 2 "abc" == [["a","bc"], ["ab","c"]]
-- The head element either starts a new group (new) or joins the first
-- group of a split of the tail (glue).
parts :: Int -> [a] -> [[[a]]]
parts n xs = case (n, xs) of
  (0, [])     -> [[]]
  (0, _ : _)  -> []
  (_, [])     -> []
  (_, y : ys) -> map (new y) (parts (n - 1) ys) ++ map (glue y) (parts n ys)

-- Start a fresh group containing just x.
new :: a -> [[a]] -> [[a]]
new x = ([x] :)

-- Prepend x to the first group. Partial on an empty list of groups, which
-- parts never produces at the point where it applies glue.
glue :: a -> [[a]] -> [[a]]
glue x (grp : grps) = (x : grp) : grps

-- Match a list of (pattern, expression) pairs left to right. Each
-- substitution produced by one pair is threaded into the matching of the
-- remaining pairs (list-monad bind).
xmatchlist :: Subst -> [(Expr, Expr)] -> [Subst]
xmatchlist s []            = [s]
xmatchlist s (pair : rest) = xmatch s pair >>= \t -> xmatchlist t rest


-- 12.4.0  Subexpressions

-- Where a subexpression sits within an expression:
--   All      the whole expression;
--   Seg j k  the k consecutive factors of a composition starting at
--            index j;
--   Pos j l  location l inside the j-th argument of a constant, or inside
--            the j-th factor of a composition.
data Location
  = All
  | Seg Int Int
  | Pos Int Location

-- A subexpression paired with its location.
type SubExpr = (Location, Expr)

-- All subexpressions of an expression with their locations, the whole
-- expression first. For a composition, the proper segments come before
-- the subexpressions of the individual factors.
subexprs :: Expr -> [SubExpr]
subexprs e = (All, e) : inner e
  where
    inner (Var _)      = []
    inner (Con _ xs)   = subterms xs
    inner (Compose xs) = segments xs ++ subterms xs

-- The subexpressions of each element of xs, in order. Every location is
-- prefixed with Pos j, where j is the element's 0-based index. Indices
-- come from zip instead of looking each element up with (!!); the
-- lookups made this quadratic in the number of elements.
subterms :: [Expr] -> [SubExpr]
subterms xs = [(Pos j loc, y) | (j, x) <- zip [0 ..] xs,
                                (loc, y) <- subexprs x]

-- The proper segments of a composition's factors: every contiguous run of
-- at least two but fewer than all of them. Runs are ordered by length,
-- then by starting index. The whole composition is not included, since
-- subexprs already yields it at location All.
segments :: [Expr] -> [SubExpr]
segments xs =
  [ (Seg j k, Compose (slice j k)) | k <- [2 .. n - 1], j <- [0 .. n - k] ]
  where
    n         = length xs
    slice j k = take k (drop j xs)

-- replace x loc y returns x with the subexpression at loc replaced by y.
--   All        replaces the whole expression;
--   Pos j loc  descends into the j-th (0-based) argument of a constant or
--              factor of a composition;
--   Seg j k    replaces the k factors of a composition starting at j.
-- Compositions are rebuilt with compose, so splicing in a composition
-- keeps the result flat.
-- A location is only meaningful for the expression subexprs computed it
-- from. Any other combination is a programming error and now fails with
-- a descriptive message rather than a bare pattern-match failure.
replace :: Expr -> Location -> Expr -> Expr
replace _ All y = y
replace (Con f xs) (Pos j loc) y
  = Con f (take j xs ++ [replace (xs !! j) loc y] ++ drop (j + 1) xs)
replace (Compose xs) (Pos j loc) y
  = compose (take j xs ++ [replace (xs !! j) loc y] ++ drop (j + 1) xs)
replace (Compose xs) (Seg j k) y
  = compose (take j xs ++ [y] ++ drop (j + k) xs)
replace _ _ _
  = error "replace: location does not exist in this expression"


-- 12.4.1  Rewriting

-- Calculate forwards from x, taking the first available rewrite at each
-- stage until no law applies.
-- NOTE(review): this does not terminate if the laws allow an endless
-- chain of rewrites.
calculate :: ([Law], [Law]) -> Expr -> Calculation
calculate pls x = (x, repeatedly step x)
  where step = rewrites pls

-- Every single-step rewrite of x, in priority order.
-- Laws in the first list (the basic laws, as proveEqn supplies them) are
-- tried law by law across all subexpressions. Laws in the second list
-- are tried subexpression by subexpression across all laws.
rewrites :: ([Law], [Law]) -> Expr -> [Step]
rewrites (llaws, rlaws) x = concat (firstPass ++ secondPass)
  where
    subs       = subexprs x
    firstPass  = [rewrite rule sub x | rule <- llaws, sub <- subs]
    secondPass = [rewrite rule sub x | sub <- subs, rule <- rlaws]

-- Apply one law at one subexpression (loc, y) of x. There is one step for
-- each way the law's left-hand side matches y. Each step carries the law's
-- name and x with the instantiated right-hand side placed at loc.
rewrite :: Law -> SubExpr -> Expr -> [Step]
rewrite (name, lhs, rhs) (loc, y) x = map toStep (match (lhs, y))
  where toStep s = (name, replace x loc (applySub s rhs))

-- Repeatedly take the first step offered by rws, continuing from that
-- step's result, and stop as soon as no step is offered.
repeatedly :: (Expr -> [Step]) -> Expr -> [Step]
repeatedly rws x = case rws x of
  []         -> []
  step : _   -> step : repeatedly rws (snd step)


-- User interface

-- Parse an equation, attempt to prove it with the given laws, and print
-- the resulting calculation to standard output.
prove :: [Law] -> String -> IO ()
prove laws input = putStr (printCalc (proveEqn laws (parseEqn input)))

-- Calculate from each side of the equation separately, with the basic
-- laws given priority, then paste the two calculations together. If the
-- sides do not reach the same conclusion, the result contains link's
-- failure step.
proveEqn :: [Law] -> (Expr, Expr) -> Calculation
proveEqn laws (lhs, rhs) = paste (calc lhs) (calc rhs)
  where calc = calculate (partition basicLaw laws)