diff options
author | Greg Brown <gmb60@cam.ac.uk> | 2021-04-18 11:50:06 +0100 |
---|---|---|
committer | Greg Brown <gmb60@cam.ac.uk> | 2021-04-30 14:45:07 +0100 |
commit | bf46a471fb268f7c0798a179740b295f8aaa1a31 (patch) | |
tree | 74aae5dd340d13d5de26b6e5365f0e700c0d5f24 | |
parent | d93ba45b0a952dea06f5cc5326eefb0818525912 (diff) |
Update AST and parser for lambda expressions.
-rw-r--r-- | src/chomp/ast/mod.rs | 104 | ||||
-rw-r--r-- | src/nibble/convert.rs | 210 | ||||
-rw-r--r-- | src/nibble/mod.rs | 291 |
3 files changed, 315 insertions, 290 deletions
diff --git a/src/chomp/ast/mod.rs b/src/chomp/ast/mod.rs index 6d547a3..d833da4 100644 --- a/src/chomp/ast/mod.rs +++ b/src/chomp/ast/mod.rs @@ -15,14 +15,12 @@ pub type Literal = String; #[derive(Clone, Debug)] pub struct Cat { pub first: Box<NamedExpression>, - pub punct: Option<Span>, - pub second: Box<NamedExpression>, pub rest: Vec<(Option<Span>, NamedExpression)>, } impl Display for Cat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "({} . {}", self.first, self.second)?; + write!(f, "({}", self.first)?; for (_, other) in &self.rest { write!(f, " . {}", other)?; } @@ -33,7 +31,6 @@ impl Display for Cat { impl PartialEq for Cat { fn eq(&self, other: &Self) -> bool { self.first == other.first - && self.second == other.second && self.rest.len() == other.rest.len() && self .rest @@ -48,14 +45,12 @@ impl Eq for Cat {} #[derive(Clone, Debug)] pub struct Alt { pub first: Box<NamedExpression>, - pub punct: Option<Span>, - pub second: Box<NamedExpression>, - pub rest: Vec<(Option<Span>, NamedExpression)>, + pub rest: Vec<(Option<Span>, NamedExpression)> } impl Display for Alt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "({} | {}", self.first, self.second)?; + write!(f, "({}", self.first)?; for (_, other) in &self.rest { write!(f, " | {}", other)?; } @@ -66,7 +61,6 @@ impl Display for Alt { impl PartialEq for Alt { fn eq(&self, other: &Self) -> bool { self.first == other.first - && self.second == other.second && self.rest.len() == other.rest.len() && self .rest @@ -78,29 +72,17 @@ impl PartialEq for Alt { impl Eq for Alt {} -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Fix { - pub arg: Option<Name>, pub inner: Box<NamedExpression>, } impl Display for Fix { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match &self.arg { - Some(arg) => write!(f, "[{}]({})", arg, self.inner), - None => write!(f, "[]({})", self.inner), - } - } -} - -impl PartialEq for Fix { - fn eq(&self, other: &Self) -> bool { - self.inner == other.inner + write!(f, "!{}", self.inner) } } -impl Eq for Fix {} - #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub struct Variable { pub index: usize, @@ -112,41 +94,40 @@ impl Display for Variable { } } -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -pub struct Parameter { - pub index: usize, -} - -impl Display for Parameter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "<{}>", self.index) - } -} - -/// A macro invocation. +/// A function invocation. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Call { - pub name: Name, + pub on: Box<NamedExpression>, pub args: Vec<NamedExpression>, } impl Display for Call { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.name)?; + write!(f, "({}", self.on)?; - let mut iter = self.args.iter(); + for arg in self.args { + write!(f, " {}", arg)?; + } - if let Some(arg) = iter.next() { - write!(f, "({}", arg)?; + write!(f, ")") + } +} - for arg in iter { - write!(f, ", {}", arg)?; - } +/// A function abstraction. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Lambda { + pub first: Name, + pub rest: Vec<Name>, + pub inner: Box<NamedExpression>, +} - write!(f, ")") - } else { - Ok(()) +impl Display for Lambda { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "/{}", self.first)?; + for name in self.rest { + write!(f, ", {}", name)?; } + write!(f, "/ {}", self.inner) } } @@ -164,10 +145,10 @@ pub enum Expression { Fix(Fix), /// A fixed point variable. Variable(Variable), - /// A formal parameter. - Parameter(Parameter), - /// A macro invocation. + /// A function invocation. Call(Call), + /// A function abstraction. + Lambda(Lambda), } impl Display for Expression { @@ -179,7 +160,7 @@ impl Display for Expression { Self::Alt(a) => a.fmt(f), Self::Fix(x) => x.fmt(f), Self::Variable(v) => v.fmt(f), - Self::Parameter(p) => p.fmt(f), + Self::Lambda(p) => p.fmt(f), Self::Call(c) => c.fmt(f), } } @@ -224,8 +205,8 @@ impl PartialEq for Expression { false } } - Self::Parameter(p) => { - if let Self::Parameter(them) = other { + Self::Lambda(p) => { + if let Self::Lambda(them) = other { p == them } else { false @@ -280,9 +261,9 @@ impl From<Variable> for Expression { } } -impl From<Parameter> for Expression { - fn from(param: Parameter) -> Self { - Self::Parameter(param) +impl From<Lambda> for Expression { + fn from(lambda: Lambda) -> Self { + Self::Lambda(lambda) } } @@ -317,11 +298,16 @@ impl PartialEq for NamedExpression { impl Eq for NamedExpression {} #[derive(Clone, Debug)] -pub struct Function { +pub struct Let { pub name: Name, - pub params: Vec<Option<Name>>, - pub expr: NamedExpression, - pub span: Option<Span>, + pub val: NamedExpression, + pub inner: Box<TopLevel>, +} + +#[derive(Clone, Debug)] +pub enum TopLevel { + Let(Let), + Goal(NamedExpression), } impl PartialEq for Function { diff --git a/src/nibble/convert.rs b/src/nibble/convert.rs index afebafe..1e1ea08 100644 --- a/src/nibble/convert.rs +++ b/src/nibble/convert.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fmt, mem}; +use std::{fmt, mem}; use proc_macro2::Span; use syn::{punctuated::Pair, Token}; @@ -8,62 +8,42 @@ use crate::chomp::{ Name, }; -use super::{Alt, Call, Cat, Fix, Ident, Labelled, ParenExpression, Term}; - -#[derive(Clone, Copy, Debug)] -pub enum Binding { - Variable(usize), - Parameter(usize), - Global, -} +use super::{Alt, Call, Cat, Expression, Fix, Ident, Labelled, Lambda, ParenExpression, Term}; #[derive(Debug, Default)] pub struct Context { - names: HashMap<String, Binding>, - vars: usize, + bindings: Vec<Name>, } impl Context { - pub fn new<I: IntoIterator<Item = Name>>(globals: &[Name], params: I) -> Self { - let mut names = HashMap::new(); - for global in globals { - names.insert(global.to_string(), Binding::Global); - } - - for (index, param) in params.into_iter().enumerate() { - names.insert(param.to_string(), Binding::Parameter(index)); - } - - Self { names, vars: 0 } + pub fn new() -> Self { + Self::default() } - pub fn lookup(&self, name: &Name) -> Option<Binding> { - // we make variable binding cheaper by inserting wrong and pulling right. - match self.names.get(&name.to_string()).copied() { - Some(Binding::Variable(index)) => Some(Binding::Variable(self.vars - index - 1)), - Some(Binding::Parameter(index)) => Some(Binding::Parameter(index)), - Some(Binding::Global) => Some(Binding::Global), - None => None, - } + pub fn lookup(&self, name: &Name) -> Option<usize> { + self.bindings + .iter() + .enumerate() + .rfind(|(_, n)| *n == name) + .map(|(idx, _)| idx) } - pub fn with_variable<F: FnOnce(&mut Self) -> R, R>(&mut self, name: &Name, f: F) -> R { - let old = self - .names - .insert(name.to_string(), Binding::Variable(self.vars)); + pub fn push_variable(&mut self, name: Name) { + self.bindings.push(name); + } - // we make variable binding cheaper by inserting wrong and pulling right. - // we should increment all values in names instead, but that's slow - self.vars += 1; + pub fn with_variable<F: FnOnce(&mut Self) -> R, R>(&mut self, name: Name, f: F) -> R { + self.bindings.push(name); let res = f(self); - self.vars -= 1; - - if let Some(val) = old { - self.names.insert(name.to_string(), val); - } else { - self.names.remove(&name.to_string()); - } + self.bindings.pop(); + res + } + pub fn with_variables<I: IntoIterator<Item = Name>, F: FnOnce(&mut Self) -> R, R>(&mut self, names: I, f: F) -> R { + let len = self.bindings.len(); + self.bindings.extend(names); + let res = f(self); + self.bindings.resize_with(len, || unreachable!()); res } } @@ -73,6 +53,8 @@ pub enum ConvertError { UndeclaredName(Box<Name>), EmptyCat(Option<Span>), EmptyAlt(Option<Span>), + EmptyCall(Option<Span>), + MissingArgs(Option<Span>), } impl From<ConvertError> for syn::Error { @@ -80,7 +62,10 @@ impl From<ConvertError> for syn::Error { let msg = e.to_string(); let span = match e { ConvertError::UndeclaredName(name) => name.span(), - ConvertError::EmptyCat(span) | ConvertError::EmptyAlt(span) => span, + ConvertError::EmptyCat(span) + | ConvertError::EmptyAlt(span) + | ConvertError::EmptyCall(span) + | ConvertError::MissingArgs(span) => span, }; Self::new(span.unwrap_or_else(Span::call_site), msg) @@ -99,6 +84,12 @@ impl fmt::Display for ConvertError { Self::EmptyAlt(_) => { write!(f, "alternation has no elements") } + Self::EmptyCall(_) => { + write!(f, "call has no function") + } + Self::MissingArgs(_) => { + write!(f, "call has no arguments") + } } } } @@ -114,49 +105,13 @@ impl Convert for Ident { let span = Some(self.span()); let name = self.into(); - let binding = context + let index = context .lookup(&name) .ok_or_else(|| ConvertError::UndeclaredName(Box::new(name.clone())))?; - Ok(match binding { - Binding::Variable(index) => NamedExpression { - name: Some(name), - expr: ast::Variable { index }.into(), - span, - }, - Binding::Parameter(index) => NamedExpression { - name: Some(name), - expr: ast::Parameter { index }.into(), - span, - }, - Binding::Global => NamedExpression { - name: None, - expr: ast::Call { - name, - args: Vec::new(), - } - .into(), - span, - }, - }) - } -} - -impl Convert for Call { - fn convert(self, context: &mut Context) -> Result<NamedExpression, ConvertError> { - let span = self.span(); - let args = self - .args - .into_iter() - .map(|arg| arg.convert(context)) - .collect::<Result<_, _>>()?; Ok(NamedExpression { - name: None, - expr: ast::Call { - name: self.name.into(), - args, - } - .into(), + name: Some(name), + expr: ast::Variable { index }.into(), span, }) } @@ -165,13 +120,10 @@ impl Convert for Call { impl Convert for Fix { fn convert(self, context: &mut Context) -> Result<NamedExpression, ConvertError> { let span = self.span(); - let expr = self.expr; - let arg = self.arg.into(); - let inner = context.with_variable(&arg, |context| expr.convert(context))?; + let inner = self.expr.convert(context)?; Ok(NamedExpression { name: None, expr: ast::Fix { - arg: Some(arg), inner: Box::new(inner), } .into(), @@ -200,17 +152,43 @@ impl Convert for Term { expr: l.value().into(), span: Some(l.span()), }), - Self::Call(c) => c.convert(context), Self::Fix(f) => f.convert(context), Self::Parens(p) => p.convert(context), } } } +impl Convert for Call { + fn convert(self, context: &mut Context) -> Result<NamedExpression, ConvertError> { + let span = self.span(); + let mut iter = self.0.into_iter(); + let on = iter + .next() + .ok_or_else(|| ConvertError::EmptyCall(span))? + .convert(context)?; + let args = iter + .map(|arg| arg.convert(context)) + .collect::<Result<Vec<_>, _>>()?; + if args.is_empty() { + Err(ConvertError::MissingArgs(span)) + } else { + Ok(NamedExpression { + name: None, + expr: ast::Call { + on: Box::new(on), + args, + } + .into(), + span, + }) + } + } +} + impl Convert for Cat { fn convert(self, context: &mut Context) -> Result<NamedExpression, ConvertError> { fn convert_pair( - pair: Pair<Term, Token![.]>, + pair: Pair<Call, Token![.]>, context: &mut Context, ) -> Result<(NamedExpression, Option<Span>), ConvertError> { match pair { @@ -227,18 +205,17 @@ impl Convert for Cat { .ok_or(ConvertError::EmptyCat(span)) .and_then(|pair| convert_pair(pair, context))?; - let mut rest = iter.map(|pair| { - convert_pair(pair, context).map(|(snd, p)| (mem::replace(&mut punct, p), snd)) - }); + let mut rest = iter + .map(|pair| { + convert_pair(pair, context).map(|(snd, p)| (mem::replace(&mut punct, p), snd)) + }) + .peekable(); - if let Some(res) = rest.next() { - let (punct, second) = res?; + if let Some(_) = rest.peek() { Ok(NamedExpression { name: None, expr: ast::Cat { first: Box::new(first), - punct, - second: Box::new(second), rest: rest.collect::<Result<_, _>>()?, } .into(), @@ -284,18 +261,17 @@ impl Convert for Alt { .ok_or(ConvertError::EmptyAlt(span)) .and_then(|pair| convert_pair(pair, context))?; - let mut rest = iter.map(|pair| { - convert_pair(pair, context).map(|(snd, p)| (mem::replace(&mut punct, p), snd)) - }); + let mut rest = iter + .map(|pair| { + convert_pair(pair, context).map(|(snd, p)| (mem::replace(&mut punct, p), snd)) + }) + .peekable(); - if let Some(res) = rest.next() { - let (punct, second) = res?; + if let Some(_) = rest.peek() { Ok(NamedExpression { name: None, expr: ast::Alt { first: Box::new(first), - punct, - second: Box::new(second), rest: rest.collect::<Result<_, _>>()?, } .into(), @@ -306,3 +282,29 @@ impl Convert for Alt { } } } + +impl Convert for Lambda { + fn convert(self, context: &mut Context) -> Result<NamedExpression, ConvertError> { + let span = self.span(); + let mut names = self.args.into_iter().map(Name::from); + let expr = self.expr; + let inner = context.with_variables(names.clone(), |ctx| expr.convert(ctx))?; + let first = names.next().unwrap(); + let rest = names.collect(); + Ok(NamedExpression { + name: None, + expr: ast::Lambda { first, rest, inner: Box::new(inner)}.into(), + span, + }) + + } +} + +impl Convert for Expression { + fn convert(self, context: &mut Context) -> Result<NamedExpression, ConvertError> { + match self { + Expression::Alt(a) => a.convert(context), + Expression::Lambda(l) => l.convert(context), + } + } +} diff --git a/src/nibble/mod.rs b/src/nibble/mod.rs index 3f5c892..dbb05b0 100644 --- a/src/nibble/mod.rs +++ b/src/nibble/mod.rs @@ -13,7 +13,7 @@ use syn::{ LitStr, Token, }; -use crate::chomp::{ast, Name}; +use crate::chomp::{Name, ast::{self, TopLevel}}; use convert::{Context, Convert, ConvertError}; @@ -69,63 +69,29 @@ impl<T: fmt::Debug> fmt::Debug for ArgList<T> { } } -#[derive(Clone, Debug)] -pub struct Call { - pub name: Ident, - pub args: ArgList<Expression>, -} - -impl Call { - pub fn span(&self) -> Option<Span> { - self.name.span().join(self.args.span()) - } -} - -impl Parse for Call { - fn parse(input: ParseStream<'_>) -> syn::Result<Self> { - let name = input.call(Ident::parse_any)?; - let args = input.parse()?; - Ok(Self { name, args }) - } -} - #[derive(Clone)] pub struct Fix { - bracket_token: Bracket, - pub arg: Ident, - paren_token: Paren, - pub expr: Expression, + bang_token: Token![!], + pub expr: Box<Term>, } impl Fix { pub fn span(&self) -> Option<Span> { - self.bracket_token.span.join(self.paren_token.span) + self.bang_token.span.join(self.expr.span()?) } } impl Parse for Fix { fn parse(input: ParseStream<'_>) -> syn::Result<Self> { - let arg; - let bracket_token = bracketed!(arg in input); - let arg = arg.call(Ident::parse_any)?; - let expr; - let paren_token = parenthesized!(expr in input); - let expr = expr.parse()?; - Ok(Self { - bracket_token, - arg, - paren_token, - expr, - }) + let bang_token = input.parse()?; + let expr = input.parse()?; + Ok(Self { bang_token, expr }) } } impl fmt::Debug for Fix { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Fix") - .field("arg", &self.arg) - .field("expr", &self.expr) - .finish() + f.debug_struct("Fix").field("expr", &self.expr).finish() } } @@ -157,7 +123,6 @@ pub enum Term { Epsilon(Epsilon), Ident(Ident), Literal(Literal), - Call(Call), Fix(Fix), Parens(ParenExpression), } @@ -168,7 +133,6 @@ impl Term { Self::Epsilon(e) => Some(e.span), Self::Ident(i) => Some(i.span()), Self::Literal(l) => Some(l.span()), - Self::Call(c) => c.span(), Self::Fix(f) => f.span(), Self::Parens(p) => Some(p.paren_token.span), } @@ -183,18 +147,12 @@ impl Parse for Term { input.parse().map(Self::Epsilon) } else if lookahead.peek(LitStr) { input.parse().map(Self::Literal) - } else if lookahead.peek(Bracket) { + } else if lookahead.peek(Token![!]) { input.parse().map(Self::Fix) } else if lookahead.peek(Paren) { input.parse().map(Self::Parens) } else if lookahead.peek(Ident::peek_any) { - let name = input.call(Ident::parse_any)?; - - if input.peek(Paren) { - input.parse().map(|args| Self::Call(Call { name, args })) - } else { - Ok(Self::Ident(name)) - } + input.call(Ident::parse_any).map(Self::Ident) } else { Err(lookahead.error()) } @@ -207,15 +165,47 @@ impl fmt::Debug for Term { Term::Epsilon(_) => write!(f, "Term::Epsilon"), Term::Ident(i) => write!(f, "Term::Ident({:?})", i), Term::Literal(l) => write!(f, "Term::Literal({:?})", l.value()), - Term::Call(c) => write!(f, "Term::Call({:?})", c), Term::Fix(x) => write!(f, "Term::Fix({:?})", x), Term::Parens(p) => write!(f, "Term::Parens({:?})", p), } } } +#[derive(Clone, Debug)] +pub struct Call(pub Vec<Term>); + +impl Call { + pub fn span(&self) -> Option<Span> { + let mut iter = self.0.iter(); + let first = iter.next()?.span()?; + iter.try_fold(first, |span, t| t.span().and_then(|s| span.join(s))) + } +} + +impl Parse for Call { + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let mut out = Vec::new(); + out.push(input.parse()?); + loop { + let lookahead = input.lookahead1(); + if lookahead.peek(Token![_]) + || lookahead.peek(LitStr) + || lookahead.peek(Token![!]) + || lookahead.peek(Paren) + || lookahead.peek(Ident::peek_any) + { + out.push(input.parse()?); + } else { + break; + } + } + + Ok(Self(out)) + } +} + #[derive(Clone)] -pub struct Cat(pub Punctuated<Term, Token![.]>); +pub struct Cat(pub Punctuated<Call, Token![.]>); impl Parse for Cat { fn parse(input: ParseStream<'_>) -> syn::Result<Self> { @@ -335,59 +325,60 @@ impl fmt::Debug for Alt { f.debug_list().entries(self.0.iter()).finish() } } - -pub type Expression = Alt; - #[derive(Clone)] -pub struct LetStatement { - let_token: Token![let], - name: Ident, - args: Option<ArgList<Ident>>, - eq_token: Token![=], - expr: Expression, - semi_token: Token![;], +pub struct Lambda { + slash_tok_left: Token![/], + pub args: ArgList<Ident>, + slash_tok_right: Token![/], + pub expr: Alt, } -impl LetStatement { +impl Lambda { pub fn span(&self) -> Option<Span> { - self.let_token.span.join(self.semi_token.span) + self.slash_tok_left.span.join(self.expr.span()?) } } -impl Parse for LetStatement { +impl Parse for Lambda { fn parse(input: ParseStream<'_>) -> syn::Result<Self> { - let let_token = input.parse()?; - let name = input.call(Ident::parse_any)?; - let args = if input.peek(Paren) { - Some(input.parse()?) - } else { - None - }; - let eq_token = input.parse()?; + let slash_tok_left = input.parse()?; + let args = input.parse()?; + let slash_tok_right = input.parse()?; let expr = input.parse()?; - let semi_token = input.parse()?; - Ok(Self { - let_token, - name, + slash_tok_left, args, - eq_token, + slash_tok_right, expr, - semi_token, }) } } -impl fmt::Debug for LetStatement { +impl fmt::Debug for Lambda { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("LetStatement") - .field("name", &self.name) + f.debug_struct("Lambda") .field("args", &self.args) .field("expr", &self.expr) .finish() } } +#[derive(Clone, Debug)] +pub enum Expression { + Alt(Alt), + Lambda(Lambda), +} + +impl Parse for Expression { + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + if input.peek(Token![/]) { + input.parse().map(Self::Lambda) + } else { + input.parse().map(Self::Alt) + } + } +} + #[derive(Clone)] pub struct GoalStatement { match_token: Token![match], @@ -417,58 +408,104 @@ impl fmt::Debug for GoalStatement { } } +#[derive(Clone)] +pub struct LetStatement { + let_token: Token![let], + name: Ident, + args: Option<ArgList<Ident>>, + eq_token: Token![=], + expr: Expression, + semi_token: Token![;], + next: Box<Statement>, +} + +impl Parse for LetStatement { + fn parse(input: ParseStream<'_>) -> syn::Result<Self> { + let let_token = input.parse()?; + let name = input.call(Ident::parse_any)?; + let args = if input.peek(Paren) { + Some(input.parse()?) + } else { + None + }; + let eq_token = input.parse()?; + let expr = input.parse()?; + let semi_token = input.parse()?; + let next = Box::new(input.parse()?); + + Ok(Self { + let_token, + name, + args, + eq_token, + expr, + semi_token, + next, + }) + } +} + +impl fmt::Debug for LetStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("LetStatement") + .field("name", &self.name) + .field("args", &self.args) + .field("expr", &self.expr) + .field("next", &self.next) + .finish() + } +} + #[derive(Clone, Debug)] -pub struct File { - lets: Vec<LetStatement>, - goal: GoalStatement, -} - -impl File { - pub fn convert(self) -> Result<(Vec<ast::Function>, ast::NamedExpression), ConvertError> { - let mut names = Vec::new(); - let mut map = Vec::new(); - for stmt in self.lets { - let span = stmt.span(); - let name: Name = stmt.name.into(); - let params = stmt - .args - .into_iter() - .flat_map(ArgList::into_iter) - .map(Name::from); - let mut context = Context::new(&names, params.clone()); - let mut expr = stmt.expr.convert(&mut context)?; - names.push(name.clone()); - expr.name = Some(name.clone()); - map.push(ast::Function { - name, - params: params.map(Some).collect(), - expr, - span, - }); +pub enum Statement { + Goal(GoalStatement), + Let(LetStatement), +} + +impl Statement { + pub fn convert(self) -> Result<TopLevel, ConvertError> { + let mut stmt = self; + let mut context = Context::new(); + let mut name_val = Vec::new(); + while let Self::Let(let_stmt) = stmt { + let mut val = match let_stmt.args { + Some(args) => { + todo!() + } + None => let_stmt.expr.convert(&mut context), + }?; + let name: Name = let_stmt.name.into(); + val.name = val.name.or_else(|| Some(name.clone())); + context.push_variable(name.clone()); + name_val.push((name, val)); + stmt = *let_stmt.next; } - let mut context = Context::new(&names, Vec::new()); - let goal = self.goal.expr.convert(&mut context)?; - Ok((map, goal)) + let goal = match stmt { + Statement::Goal(goal) => TopLevel::Goal(goal.expr.convert(&mut context)?), + Statement::Let(_) => unreachable!(), + }; + + Ok(name_val.into_iter().rfold(goal, |inner, (name, val)| { + TopLevel::Let(ast::Let { + name, + val, + inner: Box::new(inner), + }) + })) } } -impl Parse for File { +impl Parse for Statement { fn parse(input: ParseStream<'_>) -> syn::Result<Self> { - let mut lets = Vec::new(); let mut lookahead = input.lookahead1(); - while lookahead.peek(Let) { - lets.push(input.parse()?); - lookahead = input.lookahead1(); - } - - let goal = if lookahead.peek(Match) { - input.parse()? + if lookahead.peek(Let) { + input.parse().map(Self::Let) + } else if lookahead.peek(Match) { + input.parse().map(Self::Goal) } else { - return Err(lookahead.error()); - }; - - Ok(Self { lets, goal }) + Err(lookahead.error()) + } } } |