summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast/convert.rs115
-rw-r--r--src/main.rs2
-rw-r--r--src/nibble/mod.rs149
3 files changed, 193 insertions, 73 deletions
diff --git a/src/ast/convert.rs b/src/ast/convert.rs
index c27c75f..f828a16 100644
--- a/src/ast/convert.rs
+++ b/src/ast/convert.rs
@@ -1,8 +1,16 @@
+use std::borrow::Borrow;
+use std::collections::HashMap;
+use std::hash::Hash;
+use std::mem;
+use std::rc::Rc;
+
use super::Term;
-#[derive(Clone, Debug, Default)]
+#[derive(Debug, Default)]
pub struct Context {
- vars: Vec<String>,
+ bindings: Vec<String>,
+ variables: HashMap<String, Term>,
+ functions: HashMap<String, (Rc<dyn Convert>, Vec<String>)>,
}
impl Context {
@@ -19,6 +27,9 @@ impl Context {
Self::default()
}
+ /// # Errors
+ /// Returns [`None`] if `name.is_empty()` or if `name` is unbound.
+ ///
/// # Examples
/// ```
/// use chomp::ast::convert::Context;
@@ -36,10 +47,14 @@ impl Context {
///
/// assert_eq!(context.get("x"), None);
/// ```
- pub fn get<T: ?Sized + PartialEq<str>>(&self, name: &T) -> Option<usize> {
- let mut iter = self.vars.iter();
+ pub fn get_binding<T: ?Sized + PartialEq<str>>(&self, name: &T) -> Option<usize> {
+ let mut iter = self.bindings.iter();
let mut pos = 0;
+ if name == "" {
+ return None;
+ }
+
while let Some(var) = iter.next_back() {
if T::eq(&name, &var) {
return Some(pos);
@@ -51,6 +66,9 @@ impl Context {
None
}
+ /// # Panics
+ /// If `name.is_empty()`.
+ ///
/// # Examples
/// ```
/// use chomp::ast::convert::Context;
@@ -64,13 +82,90 @@ impl Context {
///
/// assert_eq!(context.get("x"), None);
/// ```
- pub fn push<F: FnOnce(&Self) -> T, T>(&self, var: String, f: F) -> T {
- let mut context = self.clone();
- context.vars.push(var);
- f(&context)
+ pub fn push_binding<F: FnOnce(&mut Self) -> T, T>(&mut self, name: String, f: F) -> T {
+ if name.is_empty() {
+ panic!()
+ }
+
+ self.bindings.push(name);
+ let res = f(self);
+ self.bindings.pop();
+ res
+ }
+
+ /// # Errors
+ /// If `name == "".to_owned().borrow()` or `name` is unbound.
+ pub fn get_variable<T: ?Sized + Hash + Eq>(&self, name: &T) -> Option<&Term>
+ where
+ String: Borrow<T>,
+ {
+ if name == "".to_owned().borrow() {
+ return None
+ }
+
+ self.variables.get(name)
+ }
+
+ /// # Panics
+ /// If any variable name is empty.
+ pub fn add_function(&mut self, name: String, source: Rc<dyn Convert>, variables: Vec<String>) {
+ if variables.iter().any(|s| s.is_empty()) {
+ panic!()
+ }
+
+ self.functions.insert(name, (source, variables));
+ }
+
+ /// This uses dynamic scope for bindings.
+ /// # Errors
+ /// If `name` is unbound or has been called with the wrong number of arguments.
+ pub fn call_function<I: IntoIterator<Item = Term>, T: ?Sized + Hash + Eq>(
+ &mut self,
+ name: &T,
+ args: I,
+ ) -> Option<Term>
+ where
+ String: Borrow<T>,
+ <I as IntoIterator>::IntoIter: ExactSizeIterator,
+ {
+ let (term, vars) = self.functions.get(name)?;
+ let args_iter = args.into_iter();
+
+ if vars.len() != args_iter.len() {
+ None
+ } else {
+ let mut old = Vec::new();
+ for (var, value) in vars.clone().into_iter().zip(args_iter) {
+ if let Some((old_name, old_value)) = self.variables.remove_entry(var.borrow()) {
+ let mut indices = Vec::new();
+
+ for (index, binding) in self.bindings.iter_mut().enumerate() {
+ if *binding == old_name {
+ indices.push((index, mem::take(binding)));
+ }
+ }
+
+ old.push((old_name, old_value, indices))
+ }
+
+ self.variables.insert(var, value);
+ }
+
+ let res = Some(term.clone().convert(self));
+
+ for (name, value, indices) in old {
+ for (index, binding) in indices {
+ self.bindings[index] = binding
+ }
+
+ self.variables.insert(name, value);
+ }
+
+ res
+ }
}
}
-pub trait Convert {
- fn convert(self, context: &Context) -> Term;
+pub trait Convert: std::fmt::Debug {
+ fn convert(&self, context: &mut Context) -> Term;
}
diff --git a/src/main.rs b/src/main.rs
index 0d4ab99..c180302 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -21,7 +21,7 @@ fn main() {
.and_then(|_| syn::parse_str(&input).map_err(|e| Box::new(e) as Box<dyn Error>))
.and_then(|nibble: Expression| {
nibble
- .convert(&Context::new())
+ .convert(&mut Context::new())
.well_typed(&mut FlastContext::new())
.map_err(|e| Box::new(e) as Box<dyn Error>)
})
diff --git a/src/nibble/mod.rs b/src/nibble/mod.rs
index 3fad97b..bb14f7b 100644
--- a/src/nibble/mod.rs
+++ b/src/nibble/mod.rs
@@ -10,26 +10,37 @@ use syn::{
token, Result, Token,
};
+const PREFER_SUBST_OVER_CALL: bool = true;
+
pub type Epsilon = Token![_];
impl Convert for Epsilon {
- fn convert(self, _: &Context) -> ast::Term {
- ast::Term::Epsilon(self)
+ fn convert(&self, _: &mut Context) -> ast::Term {
+ ast::Term::Epsilon(*self)
}
}
type Ident = syn::Ident;
impl Convert for Ident {
- fn convert(self, context: &Context) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let name = self.to_string();
- if let Some(variable) = context.get(&name) {
- Term::Variable(ast::Variable::new(self, variable))
+ if let Some(variable) = context.get_binding(&name) {
+ Term::Variable(ast::Variable::new(self.clone(), variable))
+ } else if PREFER_SUBST_OVER_CALL {
+ if let Some(term) = context.get_variable(&name) {
+ term.clone()
+ } else if let Some(term) = context.call_function(&name, std::iter::empty()) {
+ term
+ } else {
+ let span = self.span();
+ Term::Call(ast::Call::new(self.clone(), Vec::new(), span))
+ }
} else {
let span = self.span();
- Term::Call(ast::Call::new(self, Vec::new(), span))
+ Term::Call(ast::Call::new(self.clone(), Vec::new(), span))
}
}
}
@@ -37,8 +48,8 @@ impl Convert for Ident {
type Literal = syn::LitStr;
impl Convert for Literal {
- fn convert(self, _context: &Context) -> ast::Term {
- ast::Term::Literal(self)
+ fn convert(&self, _: &mut Context) -> ast::Term {
+ ast::Term::Literal(self.clone())
}
}
@@ -51,7 +62,10 @@ pub struct Call {
impl Call {
fn span(&self) -> Span {
- self.name.span().join(self.paren_tok.span).unwrap_or_else(Span::call_site)
+ self.name
+ .span()
+ .join(self.paren_tok.span)
+ .unwrap_or_else(Span::call_site)
}
}
@@ -70,20 +84,25 @@ impl Parse for Call {
}
impl Convert for Call {
- fn convert(self, context: &Context) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
- let span = self.span();
- Term::Call(ast::Call::new(
- self.name,
- self.args
- .into_pairs()
- .map(|pair| match pair {
- Pair::Punctuated(t, _) => t.convert(context),
- Pair::End(t) => t.convert(context),
- })
- .collect(),
- span,
- ))
+ let args: Vec<_> = self
+ .args
+ .pairs()
+ .map(|pair| match pair {
+ Pair::Punctuated(t, _) => t.convert(context),
+ Pair::End(t) => t.convert(context),
+ })
+ .collect();
+ if PREFER_SUBST_OVER_CALL {
+ if let Some(term) = context.call_function(&self.name.to_string(), args.clone()) {
+ term
+ } else {
+ Term::Call(ast::Call::new(self.name.clone(), args, self.span()))
+ }
+ } else {
+ Term::Call(ast::Call::new(self.name.clone(), args, self.span()))
+ }
}
}
@@ -97,7 +116,10 @@ pub struct Fix {
impl Fix {
fn span(&self) -> Span {
- self.bracket_token.span.join(self.paren_token.span).unwrap_or_else(Span::call_site)
+ self.bracket_token
+ .span
+ .join(self.paren_token.span)
+ .unwrap_or_else(Span::call_site)
}
}
@@ -119,14 +141,14 @@ impl Parse for Fix {
}
impl Convert for Fix {
- fn convert(self, context: &Context) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let span = self.span();
- let expr = self.expr;
+ let expr = &self.expr;
let arg_name = self.arg.to_string();
Term::Fix(ast::Fix::new(
- self.arg,
- context.push(arg_name, |c| expr.convert(c)),
+ self.arg.clone(),
+ context.push_binding(arg_name, |c| expr.convert(c)),
span,
))
}
@@ -139,7 +161,7 @@ pub struct ParenExpression {
}
impl ParenExpression {
- fn span(&self) -> Span {
+ pub fn span(&self) -> Span {
self.paren_tok.span
}
}
@@ -154,7 +176,7 @@ impl Parse for ParenExpression {
}
impl Convert for ParenExpression {
- fn convert(self, context: &Context) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
self.expr.convert(context)
}
}
@@ -170,7 +192,7 @@ pub enum Term {
}
impl Term {
- fn span(&self) -> Span {
+ pub fn span(&self) -> Span {
match self {
Self::Epsilon(eps) => eps.span,
Self::Ident(ident) => ident.span(),
@@ -209,7 +231,7 @@ impl Parse for Term {
args,
})
})
- }else {
+ } else {
Ok(Self::Ident(name))
}
} else {
@@ -219,7 +241,7 @@ impl Parse for Term {
}
impl Convert for Term {
- fn convert(self, context: &Context) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
match self {
Self::Epsilon(e) => e.convert(context),
Self::Ident(i) => i.convert(context),
@@ -237,7 +259,7 @@ pub struct Cat {
}
impl Cat {
- fn span(&self) -> Span {
+ pub fn span(&self) -> Span {
let mut pairs = self.terms.pairs();
let mut span = pairs.next().and_then(|pair| match pair {
Pair::Punctuated(term, punct) => term.span().join(punct.span),
@@ -246,7 +268,9 @@ impl Cat {
for pair in pairs {
span = span.and_then(|span| match pair {
- Pair::Punctuated(term, punct) => span.join(term.span()).and_then(|s| s.join(punct.span)),
+ Pair::Punctuated(term, punct) => {
+ span.join(term.span()).and_then(|s| s.join(punct.span))
+ }
Pair::End(term) => span.join(term.span()),
})
}
@@ -264,24 +288,23 @@ impl Parse for Cat {
}
impl Convert for Cat {
- fn convert(self, context: &Context) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
- let mut iter = self.terms.into_pairs();
+ let mut iter = self.terms.pairs();
let init = match iter.next().unwrap() {
Pair::Punctuated(t, p) => Ok((t.convert(context), p)),
Pair::End(t) => Err(t.convert(context)),
};
- iter.fold(init, |term, pair|{
+ iter.fold(init, |term, pair| {
let (fst, punct) = term.unwrap();
match pair {
Pair::Punctuated(t, p) => {
- Ok((Term::Cat(ast::Cat::new(fst, punct, t.convert(context))), p))
- }
- Pair::End(t) => {
- Err(Term::Cat(ast::Cat::new(fst, punct, t.convert(context))))
+ Ok((Term::Cat(ast::Cat::new(fst, *punct, t.convert(context))), p))
}
+ Pair::End(t) => Err(Term::Cat(ast::Cat::new(fst, *punct, t.convert(context)))),
}
- }).unwrap_err()
+ })
+ .unwrap_err()
}
}
@@ -291,7 +314,7 @@ pub struct Alt {
}
impl Alt {
- fn span(&self) -> Span {
+ pub fn span(&self) -> Span {
let mut pairs = self.cats.pairs();
let mut span = pairs.next().and_then(|pair| match pair {
Pair::Punctuated(cat, punct) => cat.span().join(punct.span),
@@ -300,7 +323,9 @@ impl Alt {
for pair in pairs {
span = span.and_then(|span| match pair {
- Pair::Punctuated(cat, punct) => span.join(cat.span()).and_then(|s| s.join(punct.span)),
+ Pair::Punctuated(cat, punct) => {
+ span.join(cat.span()).and_then(|s| s.join(punct.span))
+ }
Pair::End(cat) => span.join(cat.span()),
})
}
@@ -318,24 +343,24 @@ impl Parse for Alt {
}
impl Convert for Alt {
- fn convert(self, context: &Context) -> ast::Term {
- use ast::Term;
- let mut iter = self.cats.into_pairs();
- let init = match iter.next().unwrap() {
- Pair::Punctuated(t, p) => Ok((t.convert(context), p)),
- Pair::End(t) => Err(t.convert(context)),
- };
- iter.fold(init, |cat, pair|{
- let (left, punct) = cat.unwrap();
- match pair {
- Pair::Punctuated(t, p) => {
- Ok((Term::Alt(ast::Alt::new(left, punct, t.convert(context))), p))
- }
- Pair::End(t) => {
- Err(Term::Alt(ast::Alt::new(left, punct, t.convert(context))))
- }
- }
- }).unwrap_err()
+ fn convert(&self, context: &mut Context) -> ast::Term {
+ use ast::Term;
+ let mut iter = self.cats.pairs();
+ let init = match iter.next().unwrap() {
+ Pair::Punctuated(t, p) => Ok((t.convert(context), p)),
+ Pair::End(t) => Err(t.convert(context)),
+ };
+ iter.fold(init, |cat, pair| {
+ let (left, punct) = cat.unwrap();
+ match pair {
+ Pair::Punctuated(t, p) => Ok((
+ Term::Alt(ast::Alt::new(left, *punct, t.convert(context))),
+ p,
+ )),
+ Pair::End(t) => Err(Term::Alt(ast::Alt::new(left, *punct, t.convert(context)))),
+ }
+ })
+ .unwrap_err()
}
}