From e2cd078cb16834256439ac775cb8cf1e17679181 Mon Sep 17 00:00:00 2001 From: Greg Brown Date: Wed, 25 Nov 2020 15:45:32 +0000 Subject: Add substitution. --- src/ast/convert.rs | 115 +++++++++++++++++++++++++++++++++++++---- src/main.rs | 2 +- src/nibble/mod.rs | 149 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 193 insertions(+), 73 deletions(-) diff --git a/src/ast/convert.rs b/src/ast/convert.rs index c27c75f..f828a16 100644 --- a/src/ast/convert.rs +++ b/src/ast/convert.rs @@ -1,8 +1,16 @@ +use std::borrow::Borrow; +use std::collections::HashMap; +use std::hash::Hash; +use std::mem; +use std::rc::Rc; + use super::Term; -#[derive(Clone, Debug, Default)] +#[derive(Debug, Default)] pub struct Context { - vars: Vec, + bindings: Vec, + variables: HashMap, + functions: HashMap, Vec)>, } impl Context { @@ -19,6 +27,9 @@ impl Context { Self::default() } + /// # Errors + /// Returns [`None`] if `name.is_empty()` or if `name` is unbound. + /// /// # Examples /// ``` /// use chomp::ast::convert::Context; @@ -36,10 +47,14 @@ impl Context { /// /// assert_eq!(context.get("x"), None); /// ``` - pub fn get>(&self, name: &T) -> Option { - let mut iter = self.vars.iter(); + pub fn get_binding>(&self, name: &T) -> Option { + let mut iter = self.bindings.iter(); let mut pos = 0; + if name == "" { + return None; + } + while let Some(var) = iter.next_back() { if T::eq(&name, &var) { return Some(pos); @@ -51,6 +66,9 @@ impl Context { None } + /// # Panics + /// If `name.is_empty()`. + /// /// # Examples /// ``` /// use chomp::ast::convert::Context; @@ -64,13 +82,90 @@ impl Context { /// /// assert_eq!(context.get("x"), None); /// ``` - pub fn push T, T>(&self, var: String, f: F) -> T { - let mut context = self.clone(); - context.vars.push(var); - f(&context) + pub fn push_binding T, T>(&mut self, name: String, f: F) -> T { + if name.is_empty() { + panic!() + } + + self.bindings.push(name); + let res = f(self); + self.bindings.pop(); + res + } + + /// # Errors + /// If `name == "".to_owned().borrow()` or `name` is unbound. + pub fn get_variable(&self, name: &T) -> Option<&Term> + where + String: Borrow, + { + if name == "".to_owned().borrow() { + return None + } + + self.variables.get(name) + } + + /// # Panics + /// If any variable name is empty. + pub fn add_function(&mut self, name: String, source: Rc, variables: Vec) { + if variables.iter().any(|s| s.is_empty()) { + panic!() + } + + self.functions.insert(name, (source, variables)); + } + + /// This uses dynamic scope for bindings. + /// # Errors + /// If `name` is unbound or has been called with the wrong number of arguments. + pub fn call_function, T: ?Sized + Hash + Eq>( + &mut self, + name: &T, + args: I, + ) -> Option + where + String: Borrow, + ::IntoIter: ExactSizeIterator, + { + let (term, vars) = self.functions.get(name)?; + let args_iter = args.into_iter(); + + if vars.len() != args_iter.len() { + None + } else { + let mut old = Vec::new(); + for (var, value) in vars.clone().into_iter().zip(args_iter) { + if let Some((old_name, old_value)) = self.variables.remove_entry(var.borrow()) { + let mut indices = Vec::new(); + + for (index, binding) in self.bindings.iter_mut().enumerate() { + if *binding == old_name { + indices.push((index, mem::take(binding))); + } + } + + old.push((old_name, old_value, indices)) + } + + self.variables.insert(var, value); + } + + let res = Some(term.clone().convert(self)); + + for (name, value, indices) in old { + for (index, binding) in indices { + self.bindings[index] = binding + } + + self.variables.insert(name, value); + } + + res + } } } -pub trait Convert { - fn convert(self, context: &Context) -> Term; +pub trait Convert: std::fmt::Debug { + fn convert(&self, context: &mut Context) -> Term; } diff --git a/src/main.rs b/src/main.rs index 0d4ab99..c180302 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,7 @@ fn main() { .and_then(|_| syn::parse_str(&input).map_err(|e| Box::new(e) as Box)) .and_then(|nibble: Expression| { nibble - .convert(&Context::new()) + .convert(&mut Context::new()) .well_typed(&mut FlastContext::new()) .map_err(|e| Box::new(e) as Box) }) diff --git a/src/nibble/mod.rs b/src/nibble/mod.rs index 3fad97b..bb14f7b 100644 --- a/src/nibble/mod.rs +++ b/src/nibble/mod.rs @@ -10,26 +10,37 @@ use syn::{ token, Result, Token, }; +const PREFER_SUBST_OVER_CALL: bool = true; + pub type Epsilon = Token![_]; impl Convert for Epsilon { - fn convert(self, _: &Context) -> ast::Term { - ast::Term::Epsilon(self) + fn convert(&self, _: &mut Context) -> ast::Term { + ast::Term::Epsilon(*self) } } type Ident = syn::Ident; impl Convert for Ident { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; let name = self.to_string(); - if let Some(variable) = context.get(&name) { - Term::Variable(ast::Variable::new(self, variable)) + if let Some(variable) = context.get_binding(&name) { + Term::Variable(ast::Variable::new(self.clone(), variable)) + } else if PREFER_SUBST_OVER_CALL { + if let Some(term) = context.get_variable(&name) { + term.clone() + } else if let Some(term) = context.call_function(&name, std::iter::empty()) { + term + } else { + let span = self.span(); + Term::Call(ast::Call::new(self.clone(), Vec::new(), span)) + } } else { let span = self.span(); - Term::Call(ast::Call::new(self, Vec::new(), span)) + Term::Call(ast::Call::new(self.clone(), Vec::new(), span)) } } } @@ -37,8 +48,8 @@ impl Convert for Ident { type Literal = syn::LitStr; impl Convert for Literal { - fn convert(self, _context: &Context) -> ast::Term { - ast::Term::Literal(self) + fn convert(&self, _: &mut Context) -> ast::Term { + ast::Term::Literal(self.clone()) } } @@ -51,7 +62,10 @@ pub struct Call { impl Call { fn span(&self) -> Span { - self.name.span().join(self.paren_tok.span).unwrap_or_else(Span::call_site) + self.name + .span() + .join(self.paren_tok.span) + .unwrap_or_else(Span::call_site) } } @@ -70,20 +84,25 @@ impl Parse for Call { } impl Convert for Call { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; - let span = self.span(); - Term::Call(ast::Call::new( - self.name, - self.args - .into_pairs() - .map(|pair| match pair { - Pair::Punctuated(t, _) => t.convert(context), - Pair::End(t) => t.convert(context), - }) - .collect(), - span, - )) + let args: Vec<_> = self + .args + .pairs() + .map(|pair| match pair { + Pair::Punctuated(t, _) => t.convert(context), + Pair::End(t) => t.convert(context), + }) + .collect(); + if PREFER_SUBST_OVER_CALL { + if let Some(term) = context.call_function(&self.name.to_string(), args.clone()) { + term + } else { + Term::Call(ast::Call::new(self.name.clone(), args, self.span())) + } + } else { + Term::Call(ast::Call::new(self.name.clone(), args, self.span())) + } } } @@ -97,7 +116,10 @@ pub struct Fix { impl Fix { fn span(&self) -> Span { - self.bracket_token.span.join(self.paren_token.span).unwrap_or_else(Span::call_site) + self.bracket_token + .span + .join(self.paren_token.span) + .unwrap_or_else(Span::call_site) } } @@ -119,14 +141,14 @@ impl Parse for Fix { } impl Convert for Fix { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; let span = self.span(); - let expr = self.expr; + let expr = &self.expr; let arg_name = self.arg.to_string(); Term::Fix(ast::Fix::new( - self.arg, - context.push(arg_name, |c| expr.convert(c)), + self.arg.clone(), + context.push_binding(arg_name, |c| expr.convert(c)), span, )) } @@ -139,7 +161,7 @@ pub struct ParenExpression { } impl ParenExpression { - fn span(&self) -> Span { + pub fn span(&self) -> Span { self.paren_tok.span } } @@ -154,7 +176,7 @@ impl Parse for ParenExpression { } impl Convert for ParenExpression { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { self.expr.convert(context) } } @@ -170,7 +192,7 @@ pub enum Term { } impl Term { - fn span(&self) -> Span { + pub fn span(&self) -> Span { match self { Self::Epsilon(eps) => eps.span, Self::Ident(ident) => ident.span(), @@ -209,7 +231,7 @@ impl Parse for Term { args, }) }) - }else { + } else { Ok(Self::Ident(name)) } } else { @@ -219,7 +241,7 @@ impl Parse for Term { } impl Convert for Term { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { match self { Self::Epsilon(e) => e.convert(context), Self::Ident(i) => i.convert(context), @@ -237,7 +259,7 @@ pub struct Cat { } impl Cat { - fn span(&self) -> Span { + pub fn span(&self) -> Span { let mut pairs = self.terms.pairs(); let mut span = pairs.next().and_then(|pair| match pair { Pair::Punctuated(term, punct) => term.span().join(punct.span), @@ -246,7 +268,9 @@ impl Cat { for pair in pairs { span = span.and_then(|span| match pair { - Pair::Punctuated(term, punct) => span.join(term.span()).and_then(|s| s.join(punct.span)), + Pair::Punctuated(term, punct) => { + span.join(term.span()).and_then(|s| s.join(punct.span)) + } Pair::End(term) => span.join(term.span()), }) } @@ -264,24 +288,23 @@ impl Parse for Cat { } impl Convert for Cat { - fn convert(self, context: &Context) -> ast::Term { + fn convert(&self, context: &mut Context) -> ast::Term { use ast::Term; - let mut iter = self.terms.into_pairs(); + let mut iter = self.terms.pairs(); let init = match iter.next().unwrap() { Pair::Punctuated(t, p) => Ok((t.convert(context), p)), Pair::End(t) => Err(t.convert(context)), }; - iter.fold(init, |term, pair|{ + iter.fold(init, |term, pair| { let (fst, punct) = term.unwrap(); match pair { Pair::Punctuated(t, p) => { - Ok((Term::Cat(ast::Cat::new(fst, punct, t.convert(context))), p)) - } - Pair::End(t) => { - Err(Term::Cat(ast::Cat::new(fst, punct, t.convert(context)))) + Ok((Term::Cat(ast::Cat::new(fst, *punct, t.convert(context))), p)) } + Pair::End(t) => Err(Term::Cat(ast::Cat::new(fst, *punct, t.convert(context)))), } - }).unwrap_err() + }) + .unwrap_err() } } @@ -291,7 +314,7 @@ pub struct Alt { } impl Alt { - fn span(&self) -> Span { + pub fn span(&self) -> Span { let mut pairs = self.cats.pairs(); let mut span = pairs.next().and_then(|pair| match pair { Pair::Punctuated(cat, punct) => cat.span().join(punct.span), @@ -300,7 +323,9 @@ impl Alt { for pair in pairs { span = span.and_then(|span| match pair { - Pair::Punctuated(cat, punct) => span.join(cat.span()).and_then(|s| s.join(punct.span)), + Pair::Punctuated(cat, punct) => { + span.join(cat.span()).and_then(|s| s.join(punct.span)) + } Pair::End(cat) => span.join(cat.span()), }) } @@ -318,24 +343,24 @@ impl Parse for Alt { } impl Convert for Alt { - fn convert(self, context: &Context) -> ast::Term { - use ast::Term; - let mut iter = self.cats.into_pairs(); - let init = match iter.next().unwrap() { - Pair::Punctuated(t, p) => Ok((t.convert(context), p)), - Pair::End(t) => Err(t.convert(context)), - }; - iter.fold(init, |cat, pair|{ - let (left, punct) = cat.unwrap(); - match pair { - Pair::Punctuated(t, p) => { - Ok((Term::Alt(ast::Alt::new(left, punct, t.convert(context))), p)) - } - Pair::End(t) => { - Err(Term::Alt(ast::Alt::new(left, punct, t.convert(context)))) - } - } - }).unwrap_err() + fn convert(&self, context: &mut Context) -> ast::Term { + use ast::Term; + let mut iter = self.cats.pairs(); + let init = match iter.next().unwrap() { + Pair::Punctuated(t, p) => Ok((t.convert(context), p)), + Pair::End(t) => Err(t.convert(context)), + }; + iter.fold(init, |cat, pair| { + let (left, punct) = cat.unwrap(); + match pair { + Pair::Punctuated(t, p) => Ok(( + Term::Alt(ast::Alt::new(left, *punct, t.convert(context))), + p, + )), + Pair::End(t) => Err(Term::Alt(ast::Alt::new(left, *punct, t.convert(context)))), + } + }) + .unwrap_err() } } -- cgit v1.2.3