commit 36aeb10da93502233b99586f54fbf75ec2919b40
Author: Avery
Date:   Thu Mar 13 21:15:42 2025 +0100

    Initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ea8c4bf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/target
diff --git a/Cargo.lock b/Cargo.lock
new file mode 100644
index 0000000..f4b76a5
--- /dev/null
+++ b/Cargo.lock
@@ -0,0 +1,7 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4
+
+[[package]]
+name = "stlc_type_inference"
+version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..6e03258
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,6 @@
+[package]
+name = "stlc_type_inference"
+version = "0.1.0"
+edition = "2024"
+
+[dependencies]
diff --git a/src/exec/mod.rs b/src/exec/mod.rs
new file mode 100644
index 0000000..bb7ea8f
--- /dev/null
+++ b/src/exec/mod.rs
@@ -0,0 +1,130 @@
+#[cfg(test)]
+mod test;
+
+use std::{collections::HashMap, rc::Rc};
+
+use crate::{Ast, Constant, Ident, Type};
+
+impl Ast {
+    /// Performs a single beta-reduction step at the root of the term.
+    ///
+    /// If `self` is an abstraction applied to an argument, the argument is
+    /// substituted for the bound variable in the body. Any other term is
+    /// returned unchanged; redexes nested deeper are not reduced.
+    pub fn beta_reduce(self) -> Ast {
+        match self {
+            Ast::Application(lhs, rhs) => match *lhs {
+                Ast::Abstraction(var, _, ast) => ast.subst(var, *rhs),
+                lhs => Ast::Application(Box::new(lhs), rhs),
+            },
+            t => t,
+        }
+    }
+
+    /// Replaces every free occurrence of `var` in `self` with `subst`.
+    ///
+    /// An abstraction that rebinds `var` shadows it, so its body is untouched.
+    /// NOTE: binders are not renamed, so free variables of `subst` can be
+    /// captured by abstractions inside `self`.
+    fn subst(self, var: Ident, subst: Ast) -> Ast {
+        match self {
+            Ast::Abstraction(var1, typ, ast) => {
+                if var != var1 {
+                    Ast::Abstraction(var1, typ, Box::new(ast.subst(var, subst)))
+                } else {
+                    Ast::Abstraction(var1, typ, ast)
+                }
+            }
+            Ast::Application(lhs, rhs) => Ast::Application(
+                Box::new(lhs.subst(var.clone(), subst.clone())),
+                Box::new(rhs.subst(var, subst)),
+            ),
+            Ast::Variable(v) if v == var => subst,
+            Ast::Variable(v) => Ast::Variable(v),
+            Ast::Constant(constant) => Ast::Constant(constant),
+        }
+    }
+}
+
+impl Ast {
+    /// Converts the named term into its De Bruijn-indexed form.
+    ///
+    /// Bound variables become 1-based indices counting enclosing binders
+    /// (1 is the innermost); variables without a binder keep their name.
+    pub fn to_de_brujin(self) -> DeBrujinAst {
+        self.to_de_brujin_inter(Rc::new(HashMap::new()))
+    }
+
+    /// Recursive worker; `gamma` maps each name in scope to its current index.
+    fn to_de_brujin_inter(self, mut gamma: Rc<HashMap<Ident, usize>>) -> DeBrujinAst {
+        match self {
+            Ast::Abstraction(var, _, ast) => {
+                // Entering a binder moves every outer binding one level out;
+                // the new variable takes index 1, replacing a same-named one.
+                let gamma_ref = Rc::make_mut(&mut gamma);
+                gamma_ref.values_mut().for_each(|v| *v += 1);
+                gamma_ref.insert(var, 1);
+
+                DeBrujinAst::Abstraction(Box::new(ast.to_de_brujin_inter(gamma)))
+            }
+            Ast::Application(lhs, rhs) => DeBrujinAst::Application(
+                Box::new(lhs.to_de_brujin_inter(gamma.clone())),
+                Box::new(rhs.to_de_brujin_inter(gamma)),
+            ),
+            Ast::Variable(v) => {
+                if let Some(c) = gamma.get(&v) {
+                    DeBrujinAst::BoundVariable(*c)
+                } else {
+                    DeBrujinAst::FreeVariable(v)
+                }
+            }
+            Ast::Constant(constant) => DeBrujinAst::Constant(constant),
+        }
+    }
+}
+
+/// Lambda term whose bound variables are De Bruijn indices.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum DeBrujinAst {
+    Abstraction(Box<DeBrujinAst>),                   // \.0
+    Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1
+    FreeVariable(String),                            // x
+    BoundVariable(usize),                            // 1
+    Constant(Constant),                              // true | false | n
+}
+
+impl DeBrujinAst {
+    /// Performs a single beta-reduction step at the root of the term.
+    ///
+    /// Terms that are not an abstraction applied to an argument are returned
+    /// unchanged.
+    pub fn beta_reduce(self) -> DeBrujinAst {
+        match self {
+            DeBrujinAst::Application(lhs, rhs) => match *lhs {
+                DeBrujinAst::Abstraction(ast) => ast.subst_bound(1, *rhs),
+                lhs => DeBrujinAst::Application(Box::new(lhs), rhs),
+            },
+            a => a,
+        }
+    }
+
+    /// Replaces the bound variable with index `depth` by `subst`.
+    ///
+    /// `depth` grows by one under each abstraction so that it keeps referring
+    /// to the binder being eliminated. No other index is shifted, which
+    /// assumes neither operand refers to a binder outside the reduced term.
+    fn subst_bound(self, depth: usize, subst: DeBrujinAst) -> DeBrujinAst {
+        match self {
+            // Rebuild the binder around the rewritten body instead of dropping it.
+            DeBrujinAst::Abstraction(ast) => {
+                DeBrujinAst::Abstraction(Box::new(ast.subst_bound(depth + 1, subst)))
+            }
+            DeBrujinAst::Application(lhs, rhs) => DeBrujinAst::Application(
+                Box::new(lhs.subst_bound(depth, subst.clone())),
+                Box::new(rhs.subst_bound(depth, subst)),
+            ),
+            DeBrujinAst::BoundVariable(n) if n == depth => subst,
+            a => a,
+        }
+    }
+}
diff --git a/src/exec/test.rs b/src/exec/test.rs
new file mode 100644
index 0000000..bfd8256
--- /dev/null
+++ b/src/exec/test.rs
@@ -0,0 +1,90 @@
+use crate::{Ast, Constant, PrimitiveType, Type, exec::DeBrujinAst as DBAst};
+
+#[test]
+fn beta_reduce() {
+    let input = Ast::Application(
+        Box::new(Ast::Abstraction(
+            "x".to_string(),
+            Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
+            Box::new(Ast::Application(
+                Box::new(Ast::Variable("x".to_string())),
+                Box::new(Ast::Constant(Constant::Nat(5))),
+            )),
+        )),
+        Box::new(Ast::Variable("y".to_string())),
+    );
+    let reduced = input.beta_reduce();
+    assert_eq!(
+        reduced,
+        Ast::Application(
+            Box::new(Ast::Variable("y".to_string())),
+            Box::new(Ast::Constant(Constant::Nat(5))),
+        ),
+    )
+}
+
+#[test]
+fn to_de_brujin_ast_simple() {
+    let input = Ast::Abstraction(
+        "x".to_string(),
+        PrimitiveType::Nat.into(),
+        Box::new(Ast::Abstraction(
+            "x".to_string(),
+            PrimitiveType::Nat.into(),
+            Box::new(Ast::Variable("x".to_string())),
+        )),
+    );
+    let de_brujin = input.to_de_brujin();
+    assert_eq!(
+        de_brujin,
+        DBAst::Abstraction(Box::new(DBAst::Abstraction(Box::new(
+            DBAst::BoundVariable(1)
+        ))))
+    )
+}
+
+#[test]
+fn de_brujin_beta_reduce() {
+    let input = Ast::Application(
+        Box::new(Ast::Abstraction(
+            "x".to_string(),
+            Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
+            Box::new(Ast::Application(
+                Box::new(Ast::Variable("x".to_string())),
+                Box::new(Ast::Constant(Constant::Nat(5))),
+            )),
+        )),
+        Box::new(Ast::Variable("y".to_string())),
+    );
+    let dbast = input.to_de_brujin();
+    let reduced = dbast.beta_reduce();
+    assert_eq!(
+        reduced,
+        DBAst::Application(
+            Box::new(DBAst::FreeVariable("y".to_string())),
+            Box::new(DBAst::Constant(Constant::Nat(5))),
+        ),
+    )
+}
+
+// Regression: substituting under a nested binder must keep that binder.
+#[test]
+fn de_brujin_beta_reduce_keeps_inner_abstraction() {
+    let input = Ast::Application(
+        Box::new(Ast::Abstraction(
+            "x".to_string(),
+            PrimitiveType::Nat.into(),
+            Box::new(Ast::Abstraction(
+                "y".to_string(),
+                PrimitiveType::Nat.into(),
+                Box::new(Ast::Variable("x".to_string())),
+            )),
+        )),
+        Box::new(Ast::Constant(Constant::Nat(5))),
+    );
+    let reduced = input.to_de_brujin().beta_reduce();
+    assert_eq!(
+        reduced,
+        DBAst::Abstraction(Box::new(DBAst::Constant(Constant::Nat(5))))
+    )
+}
diff --git a/src/inference/mod.rs b/src/inference/mod.rs
new file mode 100644
index 0000000..37c4fdf
--- /dev/null
+++ b/src/inference/mod.rs
@@ -0,0 +1,49 @@
+use std::{collections::HashMap, error::Error, rc::Rc};
+
+use crate::{Ast, Constant, Ident, PrimitiveType, Type};
+
+/// Reasons why a term cannot be given a type.
+#[derive(Debug)]
+pub enum InferError {
+    /// The left-hand side of an application does not have an arrow type.
+    NotAFunction,
+    /// The argument type differs from the function's parameter type.
+    MismatchedType,
+    /// A variable is not bound in the typing context.
+    NotInContext,
+}
+
+#[cfg(test)]
+mod test;
+
+/// Infers the type of `ast` under the typing context `gamma`.
+///
+/// # Errors
+///
+/// Returns [`InferError`] when a non-function is applied, an argument has the
+/// wrong type, or a variable is missing from `gamma`.
+pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> {
+    match ast {
+        Ast::Abstraction(arg, arg_type, ast) => {
+            // Bind the parameter for the body; `make_mut` copies the map only
+            // if it is still shared with another branch.
+            Rc::make_mut(&mut gamma).insert(arg, arg_type.clone());
+            let out_type = infer_type(gamma, *ast)?;
+            Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type)))
+        }
+        Ast::Application(left, right) => {
+            let left_type = infer_type(gamma.clone(), *left)?;
+            let Type::Arrow(in_type, out_type) = left_type else {
+                return Err(InferError::NotAFunction);
+            };
+            let right_type = infer_type(gamma, *right)?;
+            if *in_type != right_type {
+                return Err(InferError::MismatchedType);
+            }
+            Ok(*out_type)
+        }
+        Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext),
+        Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)),
+        Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)),
+    }
+}
diff --git a/src/inference/test.rs b/src/inference/test.rs
new file mode 100644
index 0000000..37fe4eb
--- /dev/null
+++ b/src/inference/test.rs
@@ -0,0 +1,49 @@
+use std::{collections::HashMap, rc::Rc};
+
+use crate::{Ast, Constant, PrimitiveType, Type};
+
+use super::infer_type;
+
+#[test]
+fn infer_id_type() {
+    let ast = Ast::Abstraction(
+        "x".to_string(),
+        Type::Primitive(PrimitiveType::Nat),
+        Box::new(Ast::Variable("x".to_string())),
+    );
+
+    let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap();
+    assert_eq!(
+        infered,
+        Type::Arrow(
+            Box::new(Type::Primitive(PrimitiveType::Nat)),
+            Box::new(Type::Primitive(PrimitiveType::Nat))
+        )
+    )
+}
+
+#[test]
+fn infer_addition_result_type() {
+    let ast = Ast::Application(
+        Box::new(Ast::Application(
+            Box::new(Ast::Variable("add".to_string())),
+            Box::new(Ast::Constant(Constant::Nat(5))),
+        )),
+        Box::new(Ast::Constant(Constant::Nat(7))),
+    );
+
+    let mut gamma = HashMap::new();
+    gamma.insert(
+        "add".to_string(),
+        Type::Arrow(
+            Box::new(Type::Primitive(PrimitiveType::Nat)),
+            Box::new(Type::Arrow(
+                Box::new(Type::Primitive(PrimitiveType::Nat)),
+                Box::new(Type::Primitive(PrimitiveType::Nat)),
+            )),
+        ),
+    );
+
+    let infered = infer_type(Rc::new(gamma), ast).unwrap();
+    assert_eq!(infered, Type::Primitive(PrimitiveType::Nat));
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..3ae9143
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,92 @@
+#![allow(unused)]
+
+use std::fmt::Display;
+
+mod exec;
+mod inference;
+mod parse;
+
+/// Name of a variable.
+type Ident = String;
+
+/// Base types of the calculus.
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum PrimitiveType {
+    Nat,
+    Bool,
+}
+
+/// A simple type: either a primitive or a function (arrow) type.
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum Type {
+    Primitive(PrimitiveType),    // Bool | Nat
+    Arrow(Box<Type>, Box<Type>), // 0 -> 1
+}
+
+impl From<PrimitiveType> for Type {
+    fn from(value: PrimitiveType) -> Self {
+        Type::Primitive(value)
+    }
+}
+
+impl Type {
+    /// Builds the arrow type `t1 -> t2` from anything convertible to [`Type`].
+    fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
+        Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
+    }
+}
+
+/// Literal values.
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum Constant {
+    Nat(usize),
+    Bool(bool),
+}
+
+/// Term of the simply typed lambda calculus with named variables.
+#[derive(Debug, Clone, PartialEq, Eq)]
+enum Ast {
+    Abstraction(Ident, Type, Box<Ast>), // \0:1.2
+    Application(Box<Ast>, Box<Ast>),    // 0 1
+    Variable(Ident),                    // x
+    Constant(Constant),                 // true | false | n
+}
+
+impl Display for Ast {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"),
+            // Interpolate both sub-terms rather than printing their names.
+            Ast::Application(lhs, rhs) => write!(f, "({lhs} {rhs})"),
+            Ast::Variable(v) => write!(f, "{v}"),
+            Ast::Constant(constant) => write!(f, "{constant}"),
+        }
+    }
+}
+
+impl Display for Type {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
+            Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
+        }
+    }
+}
+
+impl Display for PrimitiveType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            PrimitiveType::Nat => write!(f, "Nat"),
+            PrimitiveType::Bool => write!(f, "Bool"),
+        }
+    }
+}
+
+impl Display for Constant {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Constant::Nat(n) => write!(f, "{n}"),
+            Constant::Bool(b) => write!(f, "{b}"),
+        }
+    }
+}
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..f328e4d
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1 @@
+fn main() {}
diff --git a/src/parse/mod.rs b/src/parse/mod.rs
new file mode 100644
index 0000000..dec9d64
--- /dev/null
+++ b/src/parse/mod.rs
@@ -0,0 +1,196 @@
+use sexpr::{Sexpr, parse_string};
+
+use crate::{Ast, Constant, PrimitiveType, Type};
+
+mod sexpr;
+#[cfg(test)]
+mod test;
+
+/// Errors produced while tokenizing or parsing source text.
+#[derive(Debug)]
+pub enum ParseError {
+    UnexpectedParenClose,
+    UnexpectedEof,
+    TrailingTokens,
+    TrailingExpr,
+    ToplevelSymbol,
+    InvalidSymbol,
+    UnexpectedEndOfList,
+    UnknownType,
+    ExpectedList,
+    ExpectedSymbol,
+    ExpectedLambda,
+    ExpectedIdent,
+    ExpectedType,
+    NotAType,
+    ExpectedBody,
+    ExpectedArrow,
+    ExpectedOneOf(Vec<String>, String),
+}
+
+/// Unwraps a symbol, failing with `ExpectedSymbol` for a list or a missing item.
+fn expect_symbol(ast: Option<Sexpr>) -> Result<String, ParseError> {
+    match ast {
+        Some(Sexpr::Symbol(s)) => Ok(s),
+        _ => Err(ParseError::ExpectedSymbol),
+    }
+}
+
+/// Unwraps a symbol that is also a valid identifier (see [`is_ident`]).
+fn expect_ident(ast: Option<Sexpr>) -> Result<String, ParseError> {
+    let sym = expect_symbol(ast)?;
+    if is_ident(&sym) {
+        Ok(sym)
+    } else {
+        Err(ParseError::ExpectedIdent)
+    }
+}
+
+/// Unwraps a list, failing with `ExpectedList` for a symbol or a missing item.
+fn expect_list(ast: Option<Sexpr>) -> Result<Vec<Sexpr>, ParseError> {
+    match ast {
+        Some(Sexpr::List(l)) => Ok(l),
+        _ => Err(ParseError::ExpectedList),
+    }
+}
+
+/// Returns `item` if it equals one of `options`, otherwise `ExpectedOneOf`.
+fn expect_one_of<T>(options: &[T], item: String) -> Result<String, ParseError>
+where
+    T: PartialEq<String> + Into<String> + Clone,
+{
+    if options.iter().any(|e| *e == item) {
+        Ok(item)
+    } else {
+        Err(ParseError::ExpectedOneOf(
+            options.iter().cloned().map(Into::<String>::into).collect(),
+            item,
+        ))
+    }
+}
+
+/// Succeeds only if `iter` has no items left.
+fn expect_empty<T, I: Iterator<Item = T>>(mut iter: I) -> Result<(), ParseError> {
+    match iter.next() {
+        Some(_) => Err(ParseError::TrailingTokens),
+        None => Ok(()),
+    }
+}
+
+/// Parses an s-expression source string into an [`Ast`].
+///
+/// Accepted forms: a bare symbol (number, `true`/`false` or identifier), a
+/// lambda `(\ (name type) body)`, and an application `(f arg)`.
+pub fn parse(input: &str) -> Result<Ast, ParseError> {
+    parse_intern(parse_string(input)?)
+}
+
+/// Converts one s-expression into a term.
+///
+/// A list headed by `\` is a lambda; a list headed by any other symbol is
+/// that symbol alone or applied to one argument; a list headed by a list must
+/// have exactly two elements and is an application.
+fn parse_intern(ast: Sexpr) -> Result<Ast, ParseError> {
+    match ast {
+        Sexpr::Symbol(s) => parse_symbol(s),
+        Sexpr::List(sexprs) => {
+            let mut iter = sexprs.into_iter();
+            match iter.next() {
+                Some(Sexpr::Symbol(sym)) => {
+                    if sym == "\\" {
+                        let bind = expect_list(iter.next())?;
+                        let mut bind = bind.into_iter();
+                        let ident = expect_ident(bind.next())?;
+                        let typ = parse_type(&bind.next().ok_or(ParseError::ExpectedType)?)?;
+                        expect_empty(bind)?;
+                        let ast = Ast::Abstraction(
+                            ident,
+                            typ,
+                            Box::new(parse_intern(iter.next().ok_or(ParseError::ExpectedBody)?)?),
+                        );
+                        expect_empty(iter)?;
+                        Ok(ast)
+                    } else {
+                        let ast = parse_symbol(sym)?;
+                        if let Some(e) = iter.next() {
+                            let rhs = parse_intern(e)?;
+                            expect_empty(iter)?;
+                            Ok(Ast::Application(Box::new(ast), Box::new(rhs)))
+                        } else {
+                            Ok(ast)
+                        }
+                    }
+                }
+                Some(app_left) => {
+                    if let Some(app_right) = iter.next() {
+                        expect_empty(iter)?;
+                        Ok(Ast::Application(
+                            Box::new(parse_intern(app_left)?),
+                            Box::new(parse_intern(app_right)?),
+                        ))
+                    } else {
+                        Err(ParseError::UnexpectedEndOfList)
+                    }
+                }
+                None => Err(ParseError::UnexpectedEndOfList),
+            }
+        }
+    }
+}
+
+/// Parses a lone symbol as a natural number, a boolean or a variable.
+fn parse_symbol(s: String) -> Result<Ast, ParseError> {
+    if let Ok(n) = s.parse::<usize>() {
+        Ok(Ast::Constant(Constant::Nat(n)))
+    } else if let Ok(b) = s.parse::<bool>() {
+        Ok(Ast::Constant(Constant::Bool(b)))
+    } else if is_ident(&s) {
+        Ok(Ast::Variable(s))
+    } else {
+        Err(ParseError::InvalidSymbol)
+    }
+}
+
+/// True if `s` starts with a letter and consists only of alphanumerics.
+fn is_ident(s: &str) -> bool {
+    s.starts_with(|c: char| c.is_alphabetic()) && s.chars().all(|c| c.is_alphanumeric())
+}
+
+/// Parses a type: a primitive name or a parenthesised arrow chain.
+fn parse_type(ast: &Sexpr) -> Result<Type, ParseError> {
+    match ast {
+        Sexpr::Symbol(s) => parse_prim_type(s),
+        Sexpr::List(sexprs) => parse_type_list(sexprs),
+    }
+}
+
+/// Parses `t1 -> t2 -> ...` as a right-associative arrow type.
+///
+/// An empty slice (from `()` or a dangling `->`) yields
+/// `ParseError::ExpectedType` rather than panicking.
+fn parse_type_list(typ: &[Sexpr]) -> Result<Type, ParseError> {
+    let Some((head, rest)) = typ.split_first() else {
+        return Err(ParseError::ExpectedType);
+    };
+    let Some((arrow, tail)) = rest.split_first() else {
+        return parse_type(head);
+    };
+
+    match arrow {
+        Sexpr::Symbol(s) if s == "->" => Ok(Type::Arrow(
+            Box::new(parse_type(head)?),
+            Box::new(parse_type_list(tail)?),
+        )),
+        Sexpr::Symbol(_) => Err(ParseError::ExpectedArrow),
+        Sexpr::List(_) => Err(ParseError::ExpectedSymbol),
+    }
+}
+
+/// Maps a primitive type name to its [`Type`].
+fn parse_prim_type(typ: &str) -> Result<Type, ParseError> {
+    match typ {
+        "Bool" => Ok(Type::Primitive(PrimitiveType::Bool)),
+        "Nat" => Ok(Type::Primitive(PrimitiveType::Nat)),
+        _ => Err(ParseError::UnknownType),
+    }
+}
diff --git a/src/parse/sexpr.rs b/src/parse/sexpr.rs
new file mode 100644
index 0000000..933cd10
--- /dev/null
+++ b/src/parse/sexpr.rs
@@ -0,0 +1,104 @@
+use std::iter::Peekable;
+use std::ops::{Deref, RangeInclusive};
+use std::usize;
+use std::vec::IntoIter;
+
+use super::ParseError;
+
+/// Lexical token of the s-expression syntax.
+#[derive(Debug, PartialEq, Clone)]
+pub enum Token {
+    LeftParen,
+    RightParen,
+    Symbol(String),
+}
+
+/// An s-expression: an atom or a parenthesised list of s-expressions.
+#[derive(Debug, PartialEq, Clone)]
+pub enum Sexpr {
+    Symbol(String),
+    List(Vec<Sexpr>),
+}
+
+impl Sexpr {
+    /// Returns the symbol text, or `None` if this is a list.
+    pub fn symbol(self) -> Option<String> {
+        match self {
+            Sexpr::Symbol(item) => Some(item),
+            _ => None,
+        }
+    }
+
+    /// Returns the list elements, or `None` if this is a symbol.
+    pub fn list(self) -> Option<Vec<Sexpr>> {
+        match self {
+            Sexpr::List(item) => Some(item),
+            _ => None,
+        }
+    }
+}
+
+/// Splits `input` into parentheses and symbols, discarding whitespace.
+///
+/// A symbol is any maximal run of characters that are neither whitespace nor
+/// parentheses.
+pub fn tokenize(input: &str) -> Vec<Token> {
+    let mut tokens = Vec::new();
+    let mut chars = input.chars().peekable();
+    while let Some(c) = chars.next() {
+        match c {
+            '(' => tokens.push(Token::LeftParen),
+            ')' => tokens.push(Token::RightParen),
+            _ if c.is_whitespace() => (),
+            _ => {
+                let mut symbol = c.to_string();
+                while let Some(c) = chars.peek() {
+                    if c.is_whitespace() || *c == '(' || *c == ')' {
+                        break;
+                    }
+                    symbol.push(*c);
+                    chars.next();
+                }
+                tokens.push(Token::Symbol(symbol));
+            }
+        }
+    }
+    tokens
+}
+
+/// Parses one s-expression from the front of `tokens`.
+fn parse_expr(tokens: &mut Peekable<IntoIter<Token>>) -> Result<Sexpr, ParseError> {
+    match tokens.next() {
+        Some(Token::LeftParen) => {
+            let mut list = Vec::new();
+            while !matches!(tokens.peek(), Some(Token::RightParen)) {
+                // Running out of tokens here makes the recursive call fail
+                // with `UnexpectedEof`, so the loop only exits on `)`.
+                list.push(parse_expr(tokens)?);
+            }
+            let Some(Token::RightParen) = tokens.next() else {
+                unreachable!()
+            };
+            Ok(Sexpr::List(list))
+        }
+        Some(Token::RightParen) => Err(ParseError::UnexpectedParenClose),
+        Some(Token::Symbol(s)) => Ok(Sexpr::Symbol(s)),
+        None => Err(ParseError::UnexpectedEof),
+    }
+}
+
+/// Parses a token stream that must contain exactly one s-expression.
+pub fn parse(tokens: Vec<Token>) -> Result<Sexpr, ParseError> {
+    let mut tokens = tokens.into_iter().peekable();
+    let ast = parse_expr(&mut tokens)?;
+    if tokens.peek().is_some() {
+        return Err(ParseError::TrailingTokens);
+    }
+    Ok(ast)
+}
+
+/// Tokenizes and parses `src` in one step.
+pub fn parse_string(src: &str) -> Result<Sexpr, ParseError> {
+    let tokens = tokenize(src);
+    parse(tokens)
+}
diff --git a/src/parse/test.rs b/src/parse/test.rs
new file mode 100644
index 0000000..61f2ef7
--- /dev/null
+++ b/src/parse/test.rs
@@ -0,0 +1,109 @@
+use crate::{
+    Ast, Constant, PrimitiveType, Type,
+    parse::sexpr::{Sexpr, parse_string},
+};
+
+use super::{ParseError, parse, parse_type};
+
+#[test]
+fn parse_to_sexpr() {
+    let input = "((\\x:Nat.x) (5))";
+    let parsed = parse_string(input).unwrap();
+    assert_eq!(
+        parsed,
+        Sexpr::List(vec![
+            Sexpr::List(vec![Sexpr::Symbol("\\x:Nat.x".to_string())]),
+            Sexpr::List(vec![Sexpr::Symbol("5".to_string())])
+        ])
+    );
+}
+
+#[test]
+fn parse_prim_type() {
+    let input = Sexpr::Symbol("Nat".to_string());
+    let parsed = parse_type(&input).unwrap();
+    assert_eq!(parsed, Type::Primitive(PrimitiveType::Nat))
+}
+
+#[test]
+fn parse_simpl_arr_type() {
+    let input = Sexpr::List(vec![
+        Sexpr::Symbol("Nat".to_string()),
+        Sexpr::Symbol("->".to_string()),
+        Sexpr::Symbol("Nat".to_string()),
+    ]);
+    let parsed = parse_type(&input).unwrap();
+    assert_eq!(parsed, Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
+}
+
+#[test]
+fn parse_apply_arr_type() {
+    let input = Sexpr::List(vec![
+        Sexpr::List(vec![
+            Sexpr::Symbol("Nat".to_string()),
+            Sexpr::Symbol("->".to_string()),
+            Sexpr::Symbol("Nat".to_string()),
+        ]),
+        Sexpr::Symbol("->".to_string()),
+        Sexpr::Symbol("Nat".to_string()),
+        Sexpr::Symbol("->".to_string()),
+        Sexpr::Symbol("Nat".to_string()),
+    ]);
+    let parsed = parse_type(&input).unwrap();
+    assert_eq!(
+        parsed,
+        Type::arrow(
+            Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
+            Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)
+        )
+    )
+}
+
+#[test]
+fn parse_abstraction() {
+    let input = "(\\ (x (Nat -> Nat)) (x 5))";
+    let parsed = parse(input).unwrap();
+    assert_eq!(
+        parsed,
+        Ast::Abstraction(
+            "x".to_string(),
+            Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
+            Box::new(Ast::Application(
+                Box::new(Ast::Variable("x".to_string())),
+                Box::new(Ast::Constant(Constant::Nat(5)))
+            ))
+        )
+    )
+}
+
+#[test]
+fn parse_application() {
+    let input = "((add 5) 6)";
+    let parsed = parse(input).unwrap();
+    assert_eq!(
+        parsed,
+        Ast::Application(
+            Box::new(Ast::Application(
+                Box::new(Ast::Variable("add".to_string())),
+                Box::new(Ast::Constant(Constant::Nat(5)))
+            )),
+            Box::new(Ast::Constant(Constant::Nat(6)))
+        )
+    )
+}
+
+// Regression: malformed type lists must be reported as errors, not panic.
+#[test]
+fn parse_empty_type_list_is_error() {
+    let input = Sexpr::List(vec![]);
+    assert!(matches!(parse_type(&input), Err(ParseError::ExpectedType)));
+}
+
+#[test]
+fn parse_dangling_arrow_is_error() {
+    let input = Sexpr::List(vec![
+        Sexpr::Symbol("Nat".to_string()),
+        Sexpr::Symbol("->".to_string()),
+    ]);
+    assert!(matches!(parse_type(&input), Err(ParseError::ExpectedType)));
+}