From e2cd078cb16834256439ac775cb8cf1e17679181 Mon Sep 17 00:00:00 2001 From: Greg Brown Date: Wed, 25 Nov 2020 15:45:32 +0000 Subject: Add substitution. --- src/nibble/mod.rs | 149 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 87 insertions(+), 62 deletions(-) (limited to 'src/nibble') diff --git a/src/nibble/mod.rs b/src/nibble/mod.rs index 3fad97b..bb14f7b 100644 --- a/src/nibble/mod.rs +++ b/src/nibble/mod.rs @@ -10,26 +10,37 @@ use syn::{ token, Result, Token, }; +const PREFER_SUBST_OVER_CALL: bool = true; + pub type Epsilon = Token![_]; impl Convert for Epsilon { - fn convert(self, _: &Context) -> ast::Term { - ast::Term::Epsilon(self) + fn convert(&self, _: &mut Context) -> ast::Term { + ast::Term::Epsilon(*self) } } type Ident = syn::Ident; impl Convert for Ident { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; let name = self.to_string(); - if let Some(variable) = context.get(&name) { - Term::Variable(ast::Variable::new(self, variable)) + if let Some(variable) = context.get_binding(&name) { + Term::Variable(ast::Variable::new(self.clone(), variable)) + } else if PREFER_SUBST_OVER_CALL { + if let Some(term) = context.get_variable(&name) { + term.clone() + } else if let Some(term) = context.call_function(&name, std::iter::empty()) { + term + } else { + let span = self.span(); + Term::Call(ast::Call::new(self.clone(), Vec::new(), span)) + } } else { let span = self.span(); - Term::Call(ast::Call::new(self, Vec::new(), span)) + Term::Call(ast::Call::new(self.clone(), Vec::new(), span)) } } } @@ -37,8 +48,8 @@ impl Convert for Ident { type Literal = syn::LitStr; impl Convert for Literal { - fn convert(self, _context: &Context) -> ast::Term { - ast::Term::Literal(self) + fn convert(&self, _: &mut Context) -> ast::Term { + ast::Term::Literal(self.clone()) } } @@ -51,7 +62,10 @@ pub struct Call { impl Call { fn span(&self) -> Span { - self.name.span().join(self.paren_tok.span).unwrap_or_else(Span::call_site) + self.name + .span() + .join(self.paren_tok.span) + .unwrap_or_else(Span::call_site) } } @@ -70,20 +84,25 @@ impl Parse for Call { } impl Convert for Call { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; - let span = self.span(); - Term::Call(ast::Call::new( - self.name, - self.args - .into_pairs() - .map(|pair| match pair { - Pair::Punctuated(t, _) => t.convert(context), - Pair::End(t) => t.convert(context), - }) - .collect(), - span, - )) + let args: Vec<_> = self + .args + .pairs() + .map(|pair| match pair { + Pair::Punctuated(t, _) => t.convert(context), + Pair::End(t) => t.convert(context), + }) + .collect(); + if PREFER_SUBST_OVER_CALL { + if let Some(term) = context.call_function(&self.name.to_string(), args.clone()) { + term + } else { + Term::Call(ast::Call::new(self.name.clone(), args, self.span())) + } + } else { + Term::Call(ast::Call::new(self.name.clone(), args, self.span())) + } } } @@ -97,7 +116,10 @@ pub struct Fix { impl Fix { fn span(&self) -> Span { - self.bracket_token.span.join(self.paren_token.span).unwrap_or_else(Span::call_site) + self.bracket_token + .span + .join(self.paren_token.span) + .unwrap_or_else(Span::call_site) } } @@ -119,14 +141,14 @@ impl Parse for Fix { } impl Convert for Fix { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; let span = self.span(); - let expr = self.expr; + let expr = &self.expr; let arg_name = self.arg.to_string(); Term::Fix(ast::Fix::new( - self.arg, - context.push(arg_name, |c| expr.convert(c)), + self.arg.clone(), + context.push_binding(arg_name, |c| expr.convert(c)), span, )) } @@ -139,7 +161,7 @@ pub struct ParenExpression { } impl ParenExpression { - fn span(&self) -> Span { + pub fn span(&self) -> Span { self.paren_tok.span } } @@ -154,7 +176,7 @@ impl Parse for ParenExpression { } impl Convert for ParenExpression { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { self.expr.convert(context) } } @@ -170,7 +192,7 @@ pub enum Term { } impl Term { - fn span(&self) -> Span { + pub fn span(&self) -> Span { match self { Self::Epsilon(eps) => eps.span, Self::Ident(ident) => ident.span(), @@ -209,7 +231,7 @@ impl Parse for Term { args, }) }) - }else { + } else { Ok(Self::Ident(name)) } } else { @@ -219,7 +241,7 @@ impl Parse for Term { } impl Convert for Term { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { match self { Self::Epsilon(e) => e.convert(context), Self::Ident(i) => i.convert(context), @@ -237,7 +259,7 @@ pub struct Cat { } impl Cat { - fn span(&self) -> Span { + pub fn span(&self) -> Span { let mut pairs = self.terms.pairs(); let mut span = pairs.next().and_then(|pair| match pair { Pair::Punctuated(term, punct) => term.span().join(punct.span), @@ -246,7 +268,9 @@ impl Cat { for pair in pairs { span = span.and_then(|span| match pair { - Pair::Punctuated(term, punct) => span.join(term.span()).and_then(|s| s.join(punct.span)), + Pair::Punctuated(term, punct) => { + span.join(term.span()).and_then(|s| s.join(punct.span)) + } Pair::End(term) => span.join(term.span()), }) } @@ -264,24 +288,23 @@ impl Parse for Cat { } impl Convert for Cat { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; - let mut iter = self.terms.into_pairs(); + let mut iter = self.terms.pairs(); let init = match iter.next().unwrap() { Pair::Punctuated(t, p) => Ok((t.convert(context), p)), Pair::End(t) => Err(t.convert(context)), }; - iter.fold(init, |term, pair|{ + iter.fold(init, |term, pair| { let (fst, punct) = term.unwrap(); match pair { Pair::Punctuated(t, p) => { - Ok((Term::Cat(ast::Cat::new(fst, punct, t.convert(context))), p)) - } - Pair::End(t) => { - Err(Term::Cat(ast::Cat::new(fst, punct, t.convert(context)))) + Ok((Term::Cat(ast::Cat::new(fst, *punct, t.convert(context))), p)) } + Pair::End(t) => Err(Term::Cat(ast::Cat::new(fst, *punct, t.convert(context)))), } - }).unwrap_err() + }) + .unwrap_err() } } @@ -291,7 +314,7 @@ pub struct Alt { } impl Alt { - fn span(&self) -> Span { + pub fn span(&self) -> Span { let mut pairs = self.cats.pairs(); let mut span = pairs.next().and_then(|pair| match pair { Pair::Punctuated(cat, punct) => cat.span().join(punct.span), @@ -300,7 +323,9 @@ impl Alt { for pair in pairs { span = span.and_then(|span| match pair { - Pair::Punctuated(cat, punct) => span.join(cat.span()).and_then(|s| s.join(punct.span)), + Pair::Punctuated(cat, punct) => { + span.join(cat.span()).and_then(|s| s.join(punct.span)) + } Pair::End(cat) => span.join(cat.span()), }) } @@ -318,24 +343,24 @@ impl Parse for Alt { } impl Convert for Alt { - fn convert(self, context: &Context) -> ast::Term { - use ast::Term; - let mut iter = self.cats.into_pairs(); - let init = match iter.next().unwrap() { - Pair::Punctuated(t, p) => Ok((t.convert(context), p)), - Pair::End(t) => Err(t.convert(context)), - }; - iter.fold(init, |cat, pair|{ - let (left, punct) = cat.unwrap(); - match pair { - Pair::Punctuated(t, p) => { - Ok((Term::Alt(ast::Alt::new(left, punct, t.convert(context))), p)) - } - Pair::End(t) => { - Err(Term::Alt(ast::Alt::new(left, punct, t.convert(context)))) - } - } - }).unwrap_err() + fn convert(&self, context: &mut Context) -> ast::Term { + use ast::Term; + let mut iter = self.cats.pairs(); + let init = match iter.next().unwrap() { + Pair::Punctuated(t, p) => Ok((t.convert(context), p)), + Pair::End(t) => Err(t.convert(context)), + }; + iter.fold(init, |cat, pair| { + let (left, punct) = cat.unwrap(); + match pair { + Pair::Punctuated(t, p) => Ok(( + Term::Alt(ast::Alt::new(left, *punct, t.convert(context))), + p, + )), + Pair::End(t) => Err(Term::Alt(ast::Alt::new(left, *punct, t.convert(context)))), + } + }) + .unwrap_err() } } -- cgit v1.2.3