commit
						36aeb10da9
					
				| @ -0,0 +1 @@ | ||||
| /target | ||||
| @ -0,0 +1,7 @@ | ||||
| # This file is automatically @generated by Cargo. | ||||
| # It is not intended for manual editing. | ||||
| version = 4 | ||||
| 
 | ||||
| [[package]] | ||||
| name = "stlc_type_inference" | ||||
| version = "0.1.0" | ||||
| @ -0,0 +1,6 @@ | ||||
| [package] | ||||
| name = "stlc_type_inference" | ||||
| version = "0.1.0" | ||||
| edition = "2024" | ||||
| 
 | ||||
| [dependencies] | ||||
| @ -0,0 +1,100 @@ | ||||
| #[cfg(test)] | ||||
| mod test; | ||||
| 
 | ||||
| use std::{collections::HashMap, rc::Rc}; | ||||
| 
 | ||||
| use crate::{Ast, Constant, Ident, Type}; | ||||
| 
 | ||||
| impl Ast { | ||||
|     pub fn beta_reduce(self) -> Ast { | ||||
|         match self { | ||||
|             Ast::Application(lhs, rhs) => match *lhs { | ||||
|                 Ast::Abstraction(var, _, ast) => ast.subst(var, *rhs), | ||||
|                 lhs => Ast::Application(Box::new(lhs), rhs), | ||||
|             }, | ||||
|             t => t, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn subst(self, var: Ident, subst: Ast) -> Ast { | ||||
|         match self { | ||||
|             Ast::Abstraction(var1, typ, ast) => { | ||||
|                 if var != var1 { | ||||
|                     Ast::Abstraction(var1, typ, Box::new(ast.subst(var, subst))) | ||||
|                 } else { | ||||
|                     Ast::Abstraction(var1, typ, ast) | ||||
|                 } | ||||
|             } | ||||
|             Ast::Application(lhs, rhs) => Ast::Application( | ||||
|                 Box::new(lhs.subst(var.clone(), subst.clone())), | ||||
|                 Box::new(rhs.subst(var, subst)), | ||||
|             ), | ||||
|             Ast::Variable(v) if v == var => subst, | ||||
|             Ast::Variable(v) => Ast::Variable(v), | ||||
|             Ast::Constant(constant) => Ast::Constant(constant), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Ast { | ||||
|     pub fn to_de_brujin(self) -> DeBrujinAst { | ||||
|         self.to_de_brujin_inter(Rc::new(HashMap::new())) | ||||
|     } | ||||
| 
 | ||||
|     fn to_de_brujin_inter(self, mut gamma: Rc<HashMap<String, usize>>) -> DeBrujinAst { | ||||
|         match self { | ||||
|             Ast::Abstraction(var, _, ast) => { | ||||
|                 let gamma_ref = Rc::make_mut(&mut gamma); | ||||
|                 gamma_ref.values_mut().for_each(|v| *v += 1); | ||||
|                 gamma_ref.insert(var, 1); | ||||
| 
 | ||||
|                 DeBrujinAst::Abstraction(Box::new(ast.to_de_brujin_inter(gamma))) | ||||
|             } | ||||
|             Ast::Application(lhs, rhs) => DeBrujinAst::Application( | ||||
|                 Box::new(lhs.to_de_brujin_inter(gamma.clone())), | ||||
|                 Box::new(rhs.to_de_brujin_inter(gamma)), | ||||
|             ), | ||||
|             Ast::Variable(v) => { | ||||
|                 if let Some(c) = gamma.get(&v) { | ||||
|                     DeBrujinAst::BoundVariable(*c) | ||||
|                 } else { | ||||
|                     DeBrujinAst::FreeVariable(v) | ||||
|                 } | ||||
|             } | ||||
|             Ast::Constant(constant) => DeBrujinAst::Constant(constant), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| pub enum DeBrujinAst { | ||||
|     Abstraction(Box<DeBrujinAst>),                   // \:1.2
 | ||||
|     Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1
 | ||||
|     FreeVariable(String),                            // x
 | ||||
|     BoundVariable(usize),                            // 1
 | ||||
|     Constant(Constant),                              // true | false | n
 | ||||
| } | ||||
| 
 | ||||
| impl DeBrujinAst { | ||||
|     pub fn beta_reduce(self) -> DeBrujinAst { | ||||
|         match self { | ||||
|             DeBrujinAst::Application(lhs, rhs) => match *lhs { | ||||
|                 DeBrujinAst::Abstraction(ast) => ast.subst_bound(1, *rhs), | ||||
|                 lhs => DeBrujinAst::Application(Box::new(lhs), rhs), | ||||
|             }, | ||||
|             a => a, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn subst_bound(self, depth: usize, subst: DeBrujinAst) -> DeBrujinAst { | ||||
|         match self { | ||||
|             DeBrujinAst::Abstraction(ast) => ast.subst_bound(depth + 1, subst), | ||||
|             DeBrujinAst::Application(lhs, rhs) => DeBrujinAst::Application( | ||||
|                 Box::new(lhs.subst_bound(depth, subst.clone())), | ||||
|                 Box::new(rhs.subst_bound(depth, subst)), | ||||
|             ), | ||||
|             DeBrujinAst::BoundVariable(n) if n == depth => subst, | ||||
|             a => a, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,68 @@ | ||||
| use crate::{Ast, Constant, PrimitiveType, Type, exec::DeBrujinAst as DBAst}; | ||||
| 
 | ||||
| #[test] | ||||
| fn beta_reduce() { | ||||
|     let input = Ast::Application( | ||||
|         Box::new(Ast::Abstraction( | ||||
|             "x".to_string(), | ||||
|             Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), | ||||
|             Box::new(Ast::Application( | ||||
|                 Box::new(Ast::Variable("x".to_string())), | ||||
|                 Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|             )), | ||||
|         )), | ||||
|         Box::new(Ast::Variable("y".to_string())), | ||||
|     ); | ||||
|     let reduced = input.beta_reduce(); | ||||
|     assert_eq!( | ||||
|         reduced, | ||||
|         Ast::Application( | ||||
|             Box::new(Ast::Variable("y".to_string())), | ||||
|             Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|         ), | ||||
|     ) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn to_de_brujin_ast_simple() { | ||||
|     let input = Ast::Abstraction( | ||||
|         "x".to_string(), | ||||
|         PrimitiveType::Nat.into(), | ||||
|         Box::new(Ast::Abstraction( | ||||
|             "x".to_string(), | ||||
|             PrimitiveType::Nat.into(), | ||||
|             Box::new(Ast::Variable("x".to_string())), | ||||
|         )), | ||||
|     ); | ||||
|     let de_brujin = input.to_de_brujin(); | ||||
|     assert_eq!( | ||||
|         de_brujin, | ||||
|         DBAst::Abstraction(Box::new(DBAst::Abstraction(Box::new( | ||||
|             DBAst::BoundVariable(1) | ||||
|         )))) | ||||
|     ) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn de_brujin_beta_reduce() { | ||||
|     let input = Ast::Application( | ||||
|         Box::new(Ast::Abstraction( | ||||
|             "x".to_string(), | ||||
|             Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), | ||||
|             Box::new(Ast::Application( | ||||
|                 Box::new(Ast::Variable("x".to_string())), | ||||
|                 Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|             )), | ||||
|         )), | ||||
|         Box::new(Ast::Variable("y".to_string())), | ||||
|     ); | ||||
|     let dbast = input.to_de_brujin(); | ||||
|     let reduced = dbast.beta_reduce(); | ||||
|     assert_eq!( | ||||
|         reduced, | ||||
|         DBAst::Application( | ||||
|             Box::new(DBAst::FreeVariable("y".to_string())), | ||||
|             Box::new(DBAst::Constant(Constant::Nat(5))), | ||||
|         ), | ||||
|     ) | ||||
| } | ||||
| @ -0,0 +1,37 @@ | ||||
| use std::{collections::HashMap, error::Error, rc::Rc}; | ||||
| 
 | ||||
| use crate::{Ast, Constant, Ident, PrimitiveType, Type}; | ||||
| 
 | ||||
| #[derive(Debug)] | ||||
| pub enum InferError { | ||||
|     NotAFunction, | ||||
|     MismatchedType, | ||||
|     NotInContext, | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod test; | ||||
| 
 | ||||
| pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> { | ||||
|     match ast { | ||||
|         Ast::Abstraction(arg, arg_type, ast) => { | ||||
|             Rc::make_mut(&mut gamma).insert(arg, arg_type.clone()); | ||||
|             let out_type = infer_type(gamma, *ast)?; | ||||
|             Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type))) | ||||
|         } | ||||
|         Ast::Application(left, right) => { | ||||
|             let left_type = infer_type(gamma.clone(), *left)?; | ||||
|             let Type::Arrow(in_type, out_type) = left_type else { | ||||
|                 return Err(InferError::NotAFunction); | ||||
|             }; | ||||
|             let right_type = infer_type(gamma, *right)?; | ||||
|             if *in_type != right_type { | ||||
|                 return Err(InferError::MismatchedType); | ||||
|             } | ||||
|             Ok(*out_type) | ||||
|         } | ||||
|         Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext), | ||||
|         Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)), | ||||
|         Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)), | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,49 @@ | ||||
| use std::{collections::HashMap, rc::Rc}; | ||||
| 
 | ||||
| use crate::{Ast, Constant, PrimitiveType, Type}; | ||||
| 
 | ||||
| use super::infer_type; | ||||
| 
 | ||||
| #[test] | ||||
| fn infer_id_type() { | ||||
|     let ast = Ast::Abstraction( | ||||
|         "x".to_string(), | ||||
|         Type::Primitive(PrimitiveType::Nat), | ||||
|         Box::new(Ast::Variable("x".to_string())), | ||||
|     ); | ||||
| 
 | ||||
|     let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap(); | ||||
|     assert_eq!( | ||||
|         infered, | ||||
|         Type::Arrow( | ||||
|             Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|             Box::new(Type::Primitive(PrimitiveType::Nat)) | ||||
|         ) | ||||
|     ) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn infer_addition_result_type() { | ||||
|     let ast = Ast::Application( | ||||
|         Box::new(Ast::Application( | ||||
|             Box::new(Ast::Variable("add".to_string())), | ||||
|             Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|         )), | ||||
|         Box::new(Ast::Constant(Constant::Nat(7))), | ||||
|     ); | ||||
| 
 | ||||
|     let mut gamma = HashMap::new(); | ||||
|     gamma.insert( | ||||
|         "add".to_string(), | ||||
|         Type::Arrow( | ||||
|             Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|             Box::new(Type::Arrow( | ||||
|                 Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|                 Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|             )), | ||||
|         ), | ||||
|     ); | ||||
| 
 | ||||
|     let infered = infer_type(Rc::new(gamma), ast).unwrap(); | ||||
|     assert_eq!(infered, Type::Primitive(PrimitiveType::Nat)); | ||||
| } | ||||
| @ -0,0 +1,85 @@ | ||||
| #![allow(unused)] | ||||
| 
 | ||||
| use std::fmt::Display; | ||||
| 
 | ||||
| mod exec; | ||||
| mod inference; | ||||
| mod parse; | ||||
| 
 | ||||
| type Ident = String; | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| enum PrimitiveType { | ||||
|     Nat, | ||||
|     Bool, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| enum Type { | ||||
|     Primitive(PrimitiveType),    // Bool | Nat
 | ||||
|     Arrow(Box<Type>, Box<Type>), // 0 -> 1
 | ||||
| } | ||||
| 
 | ||||
| impl From<PrimitiveType> for Type { | ||||
|     fn from(value: PrimitiveType) -> Self { | ||||
|         Type::Primitive(value) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Type { | ||||
|     fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self { | ||||
|         Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| enum Constant { | ||||
|     Nat(usize), | ||||
|     Bool(bool), | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| enum Ast { | ||||
|     Abstraction(Ident, Type, Box<Ast>), // \0:1.2
 | ||||
|     Application(Box<Ast>, Box<Ast>),    // 0 1
 | ||||
|     Variable(Ident),                    // x
 | ||||
|     Constant(Constant),                 // true | false | n
 | ||||
| } | ||||
| 
 | ||||
| impl Display for Ast { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"), | ||||
|             Ast::Application(lhs, rhs) => write!(f, "(lhs rhs)"), | ||||
|             Ast::Variable(v) => write!(f, "{v}"), | ||||
|             Ast::Constant(constant) => write!(f, "{constant}"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for Type { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), | ||||
|             Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for PrimitiveType { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             PrimitiveType::Nat => write!(f, "Nat"), | ||||
|             PrimitiveType::Bool => write!(f, "Bool"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for Constant { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             Constant::Nat(n) => write!(f, "{n}"), | ||||
|             Constant::Bool(b) => write!(f, "{b}"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1 @@ | ||||
| fn main() {} | ||||
| @ -0,0 +1,182 @@ | ||||
| use sexpr::{Sexpr, parse_string}; | ||||
| 
 | ||||
| use crate::{Ast, Constant, PrimitiveType, Type}; | ||||
| 
 | ||||
| mod sexpr; | ||||
| #[cfg(test)] | ||||
| mod test; | ||||
| 
 | ||||
| #[derive(Debug)] | ||||
| pub enum ParseError { | ||||
|     UnexpectedParenClose, | ||||
|     UnexpectedEof, | ||||
|     TrailingTokens, | ||||
|     TrailingExpr, | ||||
|     ToplevelSymbol, | ||||
|     InvalidSymbol, | ||||
|     UnexpectedEndOfList, | ||||
|     UnknownType, | ||||
|     ExpectedList, | ||||
|     ExpectedSymbol, | ||||
|     ExpectedLambda, | ||||
|     ExpectedIdent, | ||||
|     ExpectedType, | ||||
|     NotAType, | ||||
|     ExpectedBody, | ||||
|     ExpectedArrow, | ||||
|     ExpectedOneOf(Vec<String>, String), | ||||
| } | ||||
| 
 | ||||
| fn expect_symbol(ast: Option<Sexpr>) -> Result<String, ParseError> { | ||||
|     match ast { | ||||
|         Some(Sexpr::Symbol(s)) => Ok(s), | ||||
|         Some(l) => Err(ParseError::ExpectedSymbol), | ||||
|         None => Err(ParseError::ExpectedSymbol), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn expect_ident(ast: Option<Sexpr>) -> Result<String, ParseError> { | ||||
|     let sym = expect_symbol(ast)?; | ||||
|     if is_ident(&sym) { | ||||
|         Ok(sym) | ||||
|     } else { | ||||
|         Err(ParseError::ExpectedIdent) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn expect_list(ast: Option<Sexpr>) -> Result<Vec<Sexpr>, ParseError> { | ||||
|     match ast { | ||||
|         Some(Sexpr::List(l)) => Ok(l), | ||||
|         Some(l) => Err(ParseError::ExpectedList), | ||||
|         None => Err(ParseError::ExpectedList), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn expect_one_of<T>(options: &[T], item: String) -> Result<String, ParseError> | ||||
| where | ||||
|     T: PartialEq<String> + Into<String> + Clone, | ||||
| { | ||||
|     if options.iter().find(|e| **e == item).is_some() { | ||||
|         Ok(item) | ||||
|     } else { | ||||
|         Err(ParseError::ExpectedOneOf( | ||||
|             options | ||||
|                 .iter() | ||||
|                 .map(|t| Into::<String>::into(t.clone())) | ||||
|                 .collect(), | ||||
|             item, | ||||
|         )) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn expect_empty<T, I: Iterator<Item = T>>(mut iter: I) -> Result<(), ParseError> { | ||||
|     match iter.next() { | ||||
|         Some(_) => Err(ParseError::TrailingTokens), | ||||
|         None => Ok(()), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub fn parse(input: &str) -> Result<Ast, ParseError> { | ||||
|     let ast = parse_string(input)?; | ||||
|     match ast { | ||||
|         Sexpr::Symbol(s) => parse_symbol(s), | ||||
|         list => parse_intern(list), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn parse_intern(ast: Sexpr) -> Result<Ast, ParseError> { | ||||
|     match ast { | ||||
|         Sexpr::Symbol(s) => parse_symbol(s), | ||||
|         Sexpr::List(sexprs) => { | ||||
|             let mut iter = sexprs.into_iter(); | ||||
|             match iter.next() { | ||||
|                 Some(Sexpr::Symbol(sym)) => { | ||||
|                     if sym == "\\" { | ||||
|                         let bind = expect_list(iter.next())?; | ||||
|                         let mut bind = bind.into_iter(); | ||||
|                         let ident = expect_ident(bind.next())?; | ||||
|                         let typ = parse_type(&bind.next().ok_or(ParseError::ExpectedType)?)?; | ||||
|                         expect_empty(bind)?; | ||||
|                         let ast = Ast::Abstraction( | ||||
|                             ident, | ||||
|                             typ, | ||||
|                             Box::new(parse_intern(iter.next().ok_or(ParseError::ExpectedBody)?)?), | ||||
|                         ); | ||||
|                         expect_empty(iter)?; | ||||
|                         Ok(ast) | ||||
|                     } else { | ||||
|                         let ast = parse_symbol(sym)?; | ||||
|                         if let Some(e) = iter.next() { | ||||
|                             let rhs = parse_intern(e)?; | ||||
|                             expect_empty(iter)?; | ||||
|                             Ok(Ast::Application(Box::new(ast), Box::new(rhs))) | ||||
|                         } else { | ||||
|                             Ok(ast) | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 Some(app_left) => { | ||||
|                     if let Some(app_right) = iter.next() { | ||||
|                         expect_empty(iter)?; | ||||
|                         Ok(Ast::Application( | ||||
|                             Box::new(parse_intern(app_left)?), | ||||
|                             // Make it back into an Sexpr so we can feed it to parse intern
 | ||||
|                             Box::new(parse_intern(app_right)?), | ||||
|                         )) | ||||
|                     } else { | ||||
|                         Err(ParseError::UnexpectedEndOfList) | ||||
|                     } | ||||
|                 } | ||||
|                 None => Err(ParseError::UnexpectedEndOfList), | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn parse_symbol(s: String) -> Result<Ast, ParseError> { | ||||
|     if let Ok(n) = s.parse::<usize>() { | ||||
|         Ok(Ast::Constant(Constant::Nat(n))) | ||||
|     } else if let Ok(b) = s.parse::<bool>() { | ||||
|         Ok(Ast::Constant(Constant::Bool(b))) | ||||
|     } else if is_ident(&s) { | ||||
|         Ok(Ast::Variable(s)) | ||||
|     } else { | ||||
|         Err(ParseError::InvalidSymbol) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn is_ident(s: &str) -> bool { | ||||
|     s.starts_with(|c: char| c.is_alphabetic()) && s.chars().all(|c| c.is_alphanumeric()) | ||||
| } | ||||
| 
 | ||||
| fn parse_type(ast: &Sexpr) -> Result<Type, ParseError> { | ||||
|     match ast { | ||||
|         Sexpr::Symbol(s) => parse_prim_type(s), | ||||
|         Sexpr::List(sexprs) => parse_type_list(sexprs), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn parse_type_list(typ: &[Sexpr]) -> Result<Type, ParseError> { | ||||
|     let Some(t) = typ.get(0) else { todo!() }; | ||||
| 
 | ||||
|     if typ.get(1).is_some() { | ||||
|         let arr = expect_symbol(typ.get(1).cloned())?; | ||||
|         if arr != "->" { | ||||
|             return Err(ParseError::ExpectedArrow); | ||||
|         } | ||||
|         Ok(Type::Arrow( | ||||
|             Box::new(parse_type(t)?), | ||||
|             Box::new(parse_type_list(&typ[2..])?), | ||||
|         )) | ||||
|     } else { | ||||
|         parse_type(t) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| fn parse_prim_type(typ: &str) -> Result<Type, ParseError> { | ||||
|     match typ { | ||||
|         "Bool" => Ok(Type::Primitive(PrimitiveType::Bool)), | ||||
|         "Nat" => Ok(Type::Primitive(PrimitiveType::Nat)), | ||||
|         _ => Err(ParseError::UnknownType), | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,92 @@ | ||||
| use std::iter::Peekable; | ||||
| use std::ops::{Deref, RangeInclusive}; | ||||
| use std::usize; | ||||
| use std::vec::IntoIter; | ||||
| 
 | ||||
| use super::ParseError; | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Clone)] | ||||
| pub enum Token { | ||||
|     LeftParen, | ||||
|     RightParen, | ||||
|     Symbol(String), | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, PartialEq, Clone)] | ||||
| pub enum Sexpr { | ||||
|     Symbol(String), | ||||
|     List(Vec<Sexpr>), | ||||
| } | ||||
| 
 | ||||
| impl Sexpr { | ||||
|     pub fn symbol(self) -> Option<String> { | ||||
|         match self { | ||||
|             Sexpr::Symbol(item) => Some(item), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn list(self) -> Option<Vec<Sexpr>> { | ||||
|         match self { | ||||
|             Sexpr::List(item) => Some(item), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub fn tokenize(input: &str) -> Vec<Token> { | ||||
|     let mut tokens = Vec::new(); | ||||
|     // let mut chars = input.chars().peekable();
 | ||||
|     let mut chars = input.chars().peekable(); | ||||
|     while let Some(c) = chars.next() { | ||||
|         match c { | ||||
|             '(' => tokens.push(Token::LeftParen), | ||||
|             ')' => tokens.push(Token::RightParen), | ||||
|             _ if c.is_whitespace() => (), | ||||
|             _ => { | ||||
|                 let mut symbol = c.to_string(); | ||||
|                 while let Some(c) = chars.peek() { | ||||
|                     if c.is_whitespace() || *c == '(' || *c == ')' { | ||||
|                         break; | ||||
|                     } | ||||
|                     symbol.push(*c); | ||||
|                     chars.next(); | ||||
|                 } | ||||
|                 tokens.push(Token::Symbol(symbol)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     tokens | ||||
| } | ||||
| 
 | ||||
| fn parse_expr(tokens: &mut Peekable<IntoIter<Token>>) -> Result<Sexpr, ParseError> { | ||||
|     match tokens.next() { | ||||
|         Some(Token::LeftParen) => { | ||||
|             let mut list = Vec::new(); | ||||
|             while !matches!(tokens.peek(), Some(Token::RightParen,)) { | ||||
|                 list.push(parse_expr(tokens)?); | ||||
|             } | ||||
|             let Some(Token::RightParen) = tokens.next() else { | ||||
|                 unreachable!() | ||||
|             }; | ||||
|             Ok(Sexpr::List(list)) | ||||
|         } | ||||
|         Some(Token::RightParen) => Err(ParseError::UnexpectedParenClose), | ||||
|         Some(Token::Symbol(s)) => Ok(Sexpr::Symbol(s)), | ||||
|         None => Err(ParseError::UnexpectedEof), | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub fn parse(tokens: Vec<Token>) -> Result<Sexpr, ParseError> { | ||||
|     let mut tokens = tokens.into_iter().peekable(); | ||||
|     let ast = parse_expr(&mut tokens)?; | ||||
|     if tokens.peek().is_some() { | ||||
|         return Err(ParseError::TrailingTokens); | ||||
|     }; | ||||
|     Ok(ast) | ||||
| } | ||||
| 
 | ||||
| pub fn parse_string(src: &str) -> Result<Sexpr, ParseError> { | ||||
|     let tokens = tokenize(src); | ||||
|     parse(tokens) | ||||
| } | ||||
| @ -0,0 +1,93 @@ | ||||
| use crate::{ | ||||
|     Ast, Constant, PrimitiveType, Type, | ||||
|     parse::sexpr::{Sexpr, parse_string}, | ||||
| }; | ||||
| 
 | ||||
| use super::{parse, parse_type}; | ||||
| 
 | ||||
| #[test] | ||||
| fn parse_to_sexpr() { | ||||
|     let input = "((\\x:Nat.x) (5))"; | ||||
|     let parsed = parse_string(input).unwrap(); | ||||
|     assert_eq!( | ||||
|         parsed, | ||||
|         Sexpr::List(vec![ | ||||
|             Sexpr::List(vec![Sexpr::Symbol("\\x:Nat.x".to_string())]), | ||||
|             Sexpr::List(vec![Sexpr::Symbol("5".to_string())]) | ||||
|         ]) | ||||
|     ); | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn parse_prim_type() { | ||||
|     let input = Sexpr::Symbol("Nat".to_string()); | ||||
|     let parsed = parse_type(&input).unwrap(); | ||||
|     assert_eq!(parsed, Type::Primitive(PrimitiveType::Nat)) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn parse_simpl_arr_type() { | ||||
|     let input = Sexpr::List(vec![ | ||||
|         Sexpr::Symbol("Nat".to_string()), | ||||
|         Sexpr::Symbol("->".to_string()), | ||||
|         Sexpr::Symbol("Nat".to_string()), | ||||
|     ]); | ||||
|     let parsed = parse_type(&input).unwrap(); | ||||
|     assert_eq!(parsed, Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn parse_apply_arr_type() { | ||||
|     let input = Sexpr::List(vec![ | ||||
|         Sexpr::List(vec![ | ||||
|             Sexpr::Symbol("Nat".to_string()), | ||||
|             Sexpr::Symbol("->".to_string()), | ||||
|             Sexpr::Symbol("Nat".to_string()), | ||||
|         ]), | ||||
|         Sexpr::Symbol("->".to_string()), | ||||
|         Sexpr::Symbol("Nat".to_string()), | ||||
|         Sexpr::Symbol("->".to_string()), | ||||
|         Sexpr::Symbol("Nat".to_string()), | ||||
|     ]); | ||||
|     let parsed = parse_type(&input).unwrap(); | ||||
|     assert_eq!( | ||||
|         parsed, | ||||
|         Type::arrow( | ||||
|             Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), | ||||
|             Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat) | ||||
|         ) | ||||
|     ) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn parse_abstraction() { | ||||
|     let input = "(\\ (x (Nat -> Nat)) (x 5))"; | ||||
|     let parsed = parse(input).unwrap(); | ||||
|     assert_eq!( | ||||
|         parsed, | ||||
|         Ast::Abstraction( | ||||
|             "x".to_string(), | ||||
|             Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), | ||||
|             Box::new(Ast::Application( | ||||
|                 Box::new(Ast::Variable("x".to_string())), | ||||
|                 Box::new(Ast::Constant(Constant::Nat(5))) | ||||
|             )) | ||||
|         ) | ||||
|     ) | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn parse_application() { | ||||
|     let input = "((add 5) 6)"; | ||||
|     let parsed = parse(input).unwrap(); | ||||
|     assert_eq!( | ||||
|         parsed, | ||||
|         Ast::Application( | ||||
|             Box::new(Ast::Application( | ||||
|                 Box::new(Ast::Variable("add".to_string())), | ||||
|                 Box::new(Ast::Constant(Constant::Nat(5))) | ||||
|             )), | ||||
|             Box::new(Ast::Constant(Constant::Nat(6))) | ||||
|         ) | ||||
|     ) | ||||
| } | ||||
					Loading…
					
					
				
		Reference in new issue