summaryrefslogtreecommitdiff
path: root/src/nibble
diff options
context:
space:
mode:
Diffstat (limited to 'src/nibble')
-rw-r--r--src/nibble/mod.rs141
1 files changed, 57 insertions, 84 deletions
diff --git a/src/nibble/mod.rs b/src/nibble/mod.rs
index 427569d..59788e6 100644
--- a/src/nibble/mod.rs
+++ b/src/nibble/mod.rs
@@ -1,6 +1,6 @@
+use std::collections::HashMap;
use std::rc::Rc;
-use crate::ast::convert::ConvertMode;
use crate::ast::{
self,
convert::{Context, Convert},
@@ -16,7 +16,7 @@ use syn::{
pub type Epsilon = Token![_];
impl Convert for Epsilon {
- fn convert(&self, _: &mut Context, _: ConvertMode) -> ast::Term {
+ fn convert(&self, _: &mut Context) -> ast::Term {
ast::Term::Epsilon(*self)
}
}
@@ -24,31 +24,17 @@ impl Convert for Epsilon {
type Ident = syn::Ident;
impl Convert for Ident {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let name = self.to_string();
- if let Some(variable) = context.get_binding(&name) {
+ if let Some(binding) = context.get_binding(&name) {
+ Term::Binding(ast::Variable::new(self.clone(), binding))
+ } else if let Some(variable) = context.get_variable(&name) {
Term::Variable(ast::Variable::new(self.clone(), variable))
} else {
- match mode {
- ConvertMode::NoSubstitution => {
- let span = self.span();
- Term::Call(ast::Call::new(self.clone(), Vec::new(), span))
- }
- ConvertMode::WithSubstitution => {
- if let Some(term) = context.get_variable(&name) {
- term.clone()
- } else if let Some(term) =
- context.call_function(&name, std::iter::empty(), mode)
- {
- term
- } else {
- let span = self.span();
- Term::Call(ast::Call::new(self.clone(), Vec::new(), span))
- }
- }
- }
+ let span = self.span();
+ Term::Call(ast::Call::new(self.clone(), Vec::new(), span))
}
}
}
@@ -56,7 +42,7 @@ impl Convert for Ident {
type Literal = syn::LitStr;
impl Convert for Literal {
- fn convert(&self, _: &mut Context, _: ConvertMode) -> ast::Term {
+ fn convert(&self, _: &mut Context) -> ast::Term {
ast::Term::Literal(self.clone())
}
}
@@ -71,6 +57,10 @@ impl<T> ArgList<T> {
fn span(&self) -> Span {
self.paren_token.span
}
+
+ fn len(&self) -> usize {
+ self.args.len()
+ }
}
impl<T> IntoIterator for ArgList<T> {
@@ -115,28 +105,15 @@ impl Parse for Call {
}
impl Convert for Call {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let args = self
.args
.clone()
.into_iter()
- .map(|arg| arg.convert(context, mode))
+ .map(|arg| arg.convert(context))
.collect::<Vec<_>>();
- match mode {
- ConvertMode::NoSubstitution => {
- Term::Call(ast::Call::new(self.name.clone(), args, self.span()))
- }
- ConvertMode::WithSubstitution => {
- if let Some(term) =
- context.call_function(&self.name.to_string(), args.clone(), mode)
- {
- term
- } else {
- Term::Call(ast::Call::new(self.name.clone(), args, self.span()))
- }
- }
- }
+ Term::Call(ast::Call::new(self.name.clone(), args, self.span()))
}
}
@@ -175,14 +152,14 @@ impl Parse for Fix {
}
impl Convert for Fix {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let span = self.span();
let expr = &self.expr;
let arg_name = self.arg.to_string();
Term::Fix(ast::Fix::new(
self.arg.clone(),
- context.push_binding(arg_name, |ctx| expr.convert(ctx, mode)),
+ context.with_binding(arg_name, |ctx| expr.convert(ctx)),
span,
))
}
@@ -210,8 +187,8 @@ impl Parse for ParenExpression {
}
impl Convert for ParenExpression {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
- self.expr.convert(context, mode)
+ fn convert(&self, context: &mut Context) -> ast::Term {
+ self.expr.convert(context)
}
}
@@ -267,14 +244,14 @@ impl Parse for Term {
}
impl Convert for Term {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
match self {
- Self::Epsilon(e) => e.convert(context, mode),
- Self::Ident(i) => i.convert(context, mode),
- Self::Literal(l) => l.convert(context, mode),
- Self::Call(c) => c.convert(context, mode),
- Self::Fix(f) => f.convert(context, mode),
- Self::Parens(e) => e.convert(context, mode),
+ Self::Epsilon(e) => e.convert(context),
+ Self::Ident(i) => i.convert(context),
+ Self::Literal(l) => l.convert(context),
+ Self::Call(c) => c.convert(context),
+ Self::Fix(f) => f.convert(context),
+ Self::Parens(e) => e.convert(context),
}
}
}
@@ -314,25 +291,20 @@ impl Parse for Cat {
}
impl Convert for Cat {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let mut iter = self.terms.pairs();
let init = match iter.next().unwrap() {
- Pair::Punctuated(t, p) => Ok((t.convert(context, mode), p)),
- Pair::End(t) => Err(t.convert(context, mode)),
+ Pair::Punctuated(t, p) => Ok((t.convert(context), p)),
+ Pair::End(t) => Err(t.convert(context)),
};
iter.fold(init, |term, pair| {
let (fst, punct) = term.unwrap();
match pair {
- Pair::Punctuated(t, p) => Ok((
- Term::Cat(ast::Cat::new(fst, *punct, t.convert(context, mode))),
- p,
- )),
- Pair::End(t) => Err(Term::Cat(ast::Cat::new(
- fst,
- *punct,
- t.convert(context, mode),
- ))),
+ Pair::Punctuated(t, p) => {
+ Ok((Term::Cat(ast::Cat::new(fst, *punct, t.convert(context))), p))
+ }
+ Pair::End(t) => Err(Term::Cat(ast::Cat::new(fst, *punct, t.convert(context)))),
}
})
.unwrap_err()
@@ -374,25 +346,21 @@ impl Parse for Alt {
}
impl Convert for Alt {
- fn convert(&self, context: &mut Context, mode: ConvertMode) -> ast::Term {
+ fn convert(&self, context: &mut Context) -> ast::Term {
use ast::Term;
let mut iter = self.cats.pairs();
let init = match iter.next().unwrap() {
- Pair::Punctuated(t, p) => Ok((t.convert(context, mode), p)),
- Pair::End(t) => Err(t.convert(context, mode)),
+ Pair::Punctuated(t, p) => Ok((t.convert(context), p)),
+ Pair::End(t) => Err(t.convert(context)),
};
iter.fold(init, |cat, pair| {
let (left, punct) = cat.unwrap();
match pair {
Pair::Punctuated(t, p) => Ok((
- Term::Alt(ast::Alt::new(left, *punct, t.convert(context, mode))),
+ Term::Alt(ast::Alt::new(left, *punct, t.convert(context))),
p,
)),
- Pair::End(t) => Err(Term::Alt(ast::Alt::new(
- left,
- *punct,
- t.convert(context, mode),
- ))),
+ Pair::End(t) => Err(Term::Alt(ast::Alt::new(left, *punct, t.convert(context)))),
}
})
.unwrap_err()
@@ -457,20 +425,25 @@ pub struct File {
}
impl File {
- pub fn convert_with_substitution(self) -> ast::Term {
+ /// Returns function list and the goal. The function list consists of an
+ /// [`Ident`], the converted [`ast::Term`] and the number of arguments.
+ pub fn convert(self) -> (Vec<(Ident, ast::Term, usize)>, ast::Term) {
let mut context = Context::new();
- for statement in self.lets {
- context.add_function(
- statement.name.to_string(),
- Rc::new(statement.expr),
- statement
- .args
- .map(|args| args.into_iter().map(|arg| arg.to_string()).collect())
- .unwrap_or_default(),
- );
- }
-
- self.goal.expr.convert(&mut context, ConvertMode::WithSubstitution)
+ let map = self
+ .lets
+ .into_iter()
+ .map(|stmt| {
+ let count = stmt.args.as_ref().map(ArgList::len).unwrap_or_default();
+ context.set_variables(
+ stmt.args
+ .into_iter()
+ .flat_map(|args| args.into_iter().map(|arg| arg.to_string())),
+ );
+ (stmt.name, stmt.expr.convert(&mut context), count)
+ })
+ .collect();
+ let goal = self.goal.expr.convert(&mut context);
+ (map, goal)
}
}