Initial commit

main
Avery 1 month ago
commit 36aeb10da9
Signed by: Avery
GPG Key ID: 4E53F4CB69B2CC8D

1
.gitignore vendored

@ -0,0 +1 @@
/target

7
Cargo.lock generated

@ -0,0 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "stlc_type_inference"
version = "0.1.0"

@ -0,0 +1,6 @@
[package]
name = "stlc_type_inference"
version = "0.1.0"
edition = "2024"
[dependencies]

@ -0,0 +1,100 @@
#[cfg(test)]
mod test;
use std::{collections::HashMap, rc::Rc};
use crate::{Ast, Constant, Ident, Type};
impl Ast {
pub fn beta_reduce(self) -> Ast {
match self {
Ast::Application(lhs, rhs) => match *lhs {
Ast::Abstraction(var, _, ast) => ast.subst(var, *rhs),
lhs => Ast::Application(Box::new(lhs), rhs),
},
t => t,
}
}
fn subst(self, var: Ident, subst: Ast) -> Ast {
match self {
Ast::Abstraction(var1, typ, ast) => {
if var != var1 {
Ast::Abstraction(var1, typ, Box::new(ast.subst(var, subst)))
} else {
Ast::Abstraction(var1, typ, ast)
}
}
Ast::Application(lhs, rhs) => Ast::Application(
Box::new(lhs.subst(var.clone(), subst.clone())),
Box::new(rhs.subst(var, subst)),
),
Ast::Variable(v) if v == var => subst,
Ast::Variable(v) => Ast::Variable(v),
Ast::Constant(constant) => Ast::Constant(constant),
}
}
}
impl Ast {
pub fn to_de_brujin(self) -> DeBrujinAst {
self.to_de_brujin_inter(Rc::new(HashMap::new()))
}
fn to_de_brujin_inter(self, mut gamma: Rc<HashMap<String, usize>>) -> DeBrujinAst {
match self {
Ast::Abstraction(var, _, ast) => {
let gamma_ref = Rc::make_mut(&mut gamma);
gamma_ref.values_mut().for_each(|v| *v += 1);
gamma_ref.insert(var, 1);
DeBrujinAst::Abstraction(Box::new(ast.to_de_brujin_inter(gamma)))
}
Ast::Application(lhs, rhs) => DeBrujinAst::Application(
Box::new(lhs.to_de_brujin_inter(gamma.clone())),
Box::new(rhs.to_de_brujin_inter(gamma)),
),
Ast::Variable(v) => {
if let Some(c) = gamma.get(&v) {
DeBrujinAst::BoundVariable(*c)
} else {
DeBrujinAst::FreeVariable(v)
}
}
Ast::Constant(constant) => DeBrujinAst::Constant(constant),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeBrujinAst {
Abstraction(Box<DeBrujinAst>), // \:1.2
Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1
FreeVariable(String), // x
BoundVariable(usize), // 1
Constant(Constant), // true | false | n
}
impl DeBrujinAst {
pub fn beta_reduce(self) -> DeBrujinAst {
match self {
DeBrujinAst::Application(lhs, rhs) => match *lhs {
DeBrujinAst::Abstraction(ast) => ast.subst_bound(1, *rhs),
lhs => DeBrujinAst::Application(Box::new(lhs), rhs),
},
a => a,
}
}
fn subst_bound(self, depth: usize, subst: DeBrujinAst) -> DeBrujinAst {
match self {
DeBrujinAst::Abstraction(ast) => ast.subst_bound(depth + 1, subst),
DeBrujinAst::Application(lhs, rhs) => DeBrujinAst::Application(
Box::new(lhs.subst_bound(depth, subst.clone())),
Box::new(rhs.subst_bound(depth, subst)),
),
DeBrujinAst::BoundVariable(n) if n == depth => subst,
a => a,
}
}
}

@ -0,0 +1,68 @@
use crate::{Ast, Constant, PrimitiveType, Type, exec::DeBrujinAst as DBAst};
#[test]
fn beta_reduce() {
let input = Ast::Application(
Box::new(Ast::Abstraction(
"x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
)),
Box::new(Ast::Variable("y".to_string())),
);
let reduced = input.beta_reduce();
assert_eq!(
reduced,
Ast::Application(
Box::new(Ast::Variable("y".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
),
)
}
#[test]
fn to_de_brujin_ast_simple() {
let input = Ast::Abstraction(
"x".to_string(),
PrimitiveType::Nat.into(),
Box::new(Ast::Abstraction(
"x".to_string(),
PrimitiveType::Nat.into(),
Box::new(Ast::Variable("x".to_string())),
)),
);
let de_brujin = input.to_de_brujin();
assert_eq!(
de_brujin,
DBAst::Abstraction(Box::new(DBAst::Abstraction(Box::new(
DBAst::BoundVariable(1)
))))
)
}
#[test]
fn de_brujin_beta_reduce() {
let input = Ast::Application(
Box::new(Ast::Abstraction(
"x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
)),
Box::new(Ast::Variable("y".to_string())),
);
let dbast = input.to_de_brujin();
let reduced = dbast.beta_reduce();
assert_eq!(
reduced,
DBAst::Application(
Box::new(DBAst::FreeVariable("y".to_string())),
Box::new(DBAst::Constant(Constant::Nat(5))),
),
)
}

@ -0,0 +1,37 @@
use std::{collections::HashMap, error::Error, rc::Rc};
use crate::{Ast, Constant, Ident, PrimitiveType, Type};
#[derive(Debug)]
pub enum InferError {
NotAFunction,
MismatchedType,
NotInContext,
}
#[cfg(test)]
mod test;
pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> {
match ast {
Ast::Abstraction(arg, arg_type, ast) => {
Rc::make_mut(&mut gamma).insert(arg, arg_type.clone());
let out_type = infer_type(gamma, *ast)?;
Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type)))
}
Ast::Application(left, right) => {
let left_type = infer_type(gamma.clone(), *left)?;
let Type::Arrow(in_type, out_type) = left_type else {
return Err(InferError::NotAFunction);
};
let right_type = infer_type(gamma, *right)?;
if *in_type != right_type {
return Err(InferError::MismatchedType);
}
Ok(*out_type)
}
Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext),
Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)),
Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)),
}
}

@ -0,0 +1,49 @@
use std::{collections::HashMap, rc::Rc};
use crate::{Ast, Constant, PrimitiveType, Type};
use super::infer_type;
#[test]
fn infer_id_type() {
let ast = Ast::Abstraction(
"x".to_string(),
Type::Primitive(PrimitiveType::Nat),
Box::new(Ast::Variable("x".to_string())),
);
let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap();
assert_eq!(
infered,
Type::Arrow(
Box::new(Type::Primitive(PrimitiveType::Nat)),
Box::new(Type::Primitive(PrimitiveType::Nat))
)
)
}
#[test]
fn infer_addition_result_type() {
let ast = Ast::Application(
Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
Box::new(Ast::Constant(Constant::Nat(7))),
);
let mut gamma = HashMap::new();
gamma.insert(
"add".to_string(),
Type::Arrow(
Box::new(Type::Primitive(PrimitiveType::Nat)),
Box::new(Type::Arrow(
Box::new(Type::Primitive(PrimitiveType::Nat)),
Box::new(Type::Primitive(PrimitiveType::Nat)),
)),
),
);
let infered = infer_type(Rc::new(gamma), ast).unwrap();
assert_eq!(infered, Type::Primitive(PrimitiveType::Nat));
}

@ -0,0 +1,85 @@
#![allow(unused)]
use std::fmt::Display;
mod exec;
mod inference;
mod parse;
type Ident = String;
#[derive(Debug, Clone, PartialEq, Eq)]
enum PrimitiveType {
Nat,
Bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
Primitive(PrimitiveType), // Bool | Nat
Arrow(Box<Type>, Box<Type>), // 0 -> 1
}
impl From<PrimitiveType> for Type {
fn from(value: PrimitiveType) -> Self {
Type::Primitive(value)
}
}
impl Type {
fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum Constant {
Nat(usize),
Bool(bool),
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum Ast {
Abstraction(Ident, Type, Box<Ast>), // \0:1.2
Application(Box<Ast>, Box<Ast>), // 0 1
Variable(Ident), // x
Constant(Constant), // true | false | n
}
impl Display for Ast {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"),
Ast::Application(lhs, rhs) => write!(f, "(lhs rhs)"),
Ast::Variable(v) => write!(f, "{v}"),
Ast::Constant(constant) => write!(f, "{constant}"),
}
}
}
impl Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
}
}
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::Nat => write!(f, "Nat"),
PrimitiveType::Bool => write!(f, "Bool"),
}
}
}
impl Display for Constant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Constant::Nat(n) => write!(f, "{n}"),
Constant::Bool(b) => write!(f, "{b}"),
}
}
}

@ -0,0 +1 @@
fn main() {}

@ -0,0 +1,182 @@
use sexpr::{Sexpr, parse_string};
use crate::{Ast, Constant, PrimitiveType, Type};
mod sexpr;
#[cfg(test)]
mod test;
#[derive(Debug)]
pub enum ParseError {
UnexpectedParenClose,
UnexpectedEof,
TrailingTokens,
TrailingExpr,
ToplevelSymbol,
InvalidSymbol,
UnexpectedEndOfList,
UnknownType,
ExpectedList,
ExpectedSymbol,
ExpectedLambda,
ExpectedIdent,
ExpectedType,
NotAType,
ExpectedBody,
ExpectedArrow,
ExpectedOneOf(Vec<String>, String),
}
fn expect_symbol(ast: Option<Sexpr>) -> Result<String, ParseError> {
match ast {
Some(Sexpr::Symbol(s)) => Ok(s),
Some(l) => Err(ParseError::ExpectedSymbol),
None => Err(ParseError::ExpectedSymbol),
}
}
fn expect_ident(ast: Option<Sexpr>) -> Result<String, ParseError> {
let sym = expect_symbol(ast)?;
if is_ident(&sym) {
Ok(sym)
} else {
Err(ParseError::ExpectedIdent)
}
}
fn expect_list(ast: Option<Sexpr>) -> Result<Vec<Sexpr>, ParseError> {
match ast {
Some(Sexpr::List(l)) => Ok(l),
Some(l) => Err(ParseError::ExpectedList),
None => Err(ParseError::ExpectedList),
}
}
fn expect_one_of<T>(options: &[T], item: String) -> Result<String, ParseError>
where
T: PartialEq<String> + Into<String> + Clone,
{
if options.iter().find(|e| **e == item).is_some() {
Ok(item)
} else {
Err(ParseError::ExpectedOneOf(
options
.iter()
.map(|t| Into::<String>::into(t.clone()))
.collect(),
item,
))
}
}
fn expect_empty<T, I: Iterator<Item = T>>(mut iter: I) -> Result<(), ParseError> {
match iter.next() {
Some(_) => Err(ParseError::TrailingTokens),
None => Ok(()),
}
}
pub fn parse(input: &str) -> Result<Ast, ParseError> {
let ast = parse_string(input)?;
match ast {
Sexpr::Symbol(s) => parse_symbol(s),
list => parse_intern(list),
}
}
fn parse_intern(ast: Sexpr) -> Result<Ast, ParseError> {
match ast {
Sexpr::Symbol(s) => parse_symbol(s),
Sexpr::List(sexprs) => {
let mut iter = sexprs.into_iter();
match iter.next() {
Some(Sexpr::Symbol(sym)) => {
if sym == "\\" {
let bind = expect_list(iter.next())?;
let mut bind = bind.into_iter();
let ident = expect_ident(bind.next())?;
let typ = parse_type(&bind.next().ok_or(ParseError::ExpectedType)?)?;
expect_empty(bind)?;
let ast = Ast::Abstraction(
ident,
typ,
Box::new(parse_intern(iter.next().ok_or(ParseError::ExpectedBody)?)?),
);
expect_empty(iter)?;
Ok(ast)
} else {
let ast = parse_symbol(sym)?;
if let Some(e) = iter.next() {
let rhs = parse_intern(e)?;
expect_empty(iter)?;
Ok(Ast::Application(Box::new(ast), Box::new(rhs)))
} else {
Ok(ast)
}
}
}
Some(app_left) => {
if let Some(app_right) = iter.next() {
expect_empty(iter)?;
Ok(Ast::Application(
Box::new(parse_intern(app_left)?),
// Make it back into an Sexpr so we can feed it to parse intern
Box::new(parse_intern(app_right)?),
))
} else {
Err(ParseError::UnexpectedEndOfList)
}
}
None => Err(ParseError::UnexpectedEndOfList),
}
}
}
}
fn parse_symbol(s: String) -> Result<Ast, ParseError> {
if let Ok(n) = s.parse::<usize>() {
Ok(Ast::Constant(Constant::Nat(n)))
} else if let Ok(b) = s.parse::<bool>() {
Ok(Ast::Constant(Constant::Bool(b)))
} else if is_ident(&s) {
Ok(Ast::Variable(s))
} else {
Err(ParseError::InvalidSymbol)
}
}
fn is_ident(s: &str) -> bool {
s.starts_with(|c: char| c.is_alphabetic()) && s.chars().all(|c| c.is_alphanumeric())
}
fn parse_type(ast: &Sexpr) -> Result<Type, ParseError> {
match ast {
Sexpr::Symbol(s) => parse_prim_type(s),
Sexpr::List(sexprs) => parse_type_list(sexprs),
}
}
fn parse_type_list(typ: &[Sexpr]) -> Result<Type, ParseError> {
let Some(t) = typ.get(0) else { todo!() };
if typ.get(1).is_some() {
let arr = expect_symbol(typ.get(1).cloned())?;
if arr != "->" {
return Err(ParseError::ExpectedArrow);
}
Ok(Type::Arrow(
Box::new(parse_type(t)?),
Box::new(parse_type_list(&typ[2..])?),
))
} else {
parse_type(t)
}
}
fn parse_prim_type(typ: &str) -> Result<Type, ParseError> {
match typ {
"Bool" => Ok(Type::Primitive(PrimitiveType::Bool)),
"Nat" => Ok(Type::Primitive(PrimitiveType::Nat)),
_ => Err(ParseError::UnknownType),
}
}

@ -0,0 +1,92 @@
use std::iter::Peekable;
use std::ops::{Deref, RangeInclusive};
use std::usize;
use std::vec::IntoIter;
use super::ParseError;
#[derive(Debug, PartialEq, Clone)]
pub enum Token {
LeftParen,
RightParen,
Symbol(String),
}
#[derive(Debug, PartialEq, Clone)]
pub enum Sexpr {
Symbol(String),
List(Vec<Sexpr>),
}
impl Sexpr {
pub fn symbol(self) -> Option<String> {
match self {
Sexpr::Symbol(item) => Some(item),
_ => None,
}
}
pub fn list(self) -> Option<Vec<Sexpr>> {
match self {
Sexpr::List(item) => Some(item),
_ => None,
}
}
}
pub fn tokenize(input: &str) -> Vec<Token> {
let mut tokens = Vec::new();
// let mut chars = input.chars().peekable();
let mut chars = input.chars().peekable();
while let Some(c) = chars.next() {
match c {
'(' => tokens.push(Token::LeftParen),
')' => tokens.push(Token::RightParen),
_ if c.is_whitespace() => (),
_ => {
let mut symbol = c.to_string();
while let Some(c) = chars.peek() {
if c.is_whitespace() || *c == '(' || *c == ')' {
break;
}
symbol.push(*c);
chars.next();
}
tokens.push(Token::Symbol(symbol));
}
}
}
tokens
}
fn parse_expr(tokens: &mut Peekable<IntoIter<Token>>) -> Result<Sexpr, ParseError> {
match tokens.next() {
Some(Token::LeftParen) => {
let mut list = Vec::new();
while !matches!(tokens.peek(), Some(Token::RightParen,)) {
list.push(parse_expr(tokens)?);
}
let Some(Token::RightParen) = tokens.next() else {
unreachable!()
};
Ok(Sexpr::List(list))
}
Some(Token::RightParen) => Err(ParseError::UnexpectedParenClose),
Some(Token::Symbol(s)) => Ok(Sexpr::Symbol(s)),
None => Err(ParseError::UnexpectedEof),
}
}
pub fn parse(tokens: Vec<Token>) -> Result<Sexpr, ParseError> {
let mut tokens = tokens.into_iter().peekable();
let ast = parse_expr(&mut tokens)?;
if tokens.peek().is_some() {
return Err(ParseError::TrailingTokens);
};
Ok(ast)
}
pub fn parse_string(src: &str) -> Result<Sexpr, ParseError> {
let tokens = tokenize(src);
parse(tokens)
}

@ -0,0 +1,93 @@
use crate::{
Ast, Constant, PrimitiveType, Type,
parse::sexpr::{Sexpr, parse_string},
};
use super::{parse, parse_type};
#[test]
fn parse_to_sexpr() {
let input = "((\\x:Nat.x) (5))";
let parsed = parse_string(input).unwrap();
assert_eq!(
parsed,
Sexpr::List(vec![
Sexpr::List(vec![Sexpr::Symbol("\\x:Nat.x".to_string())]),
Sexpr::List(vec![Sexpr::Symbol("5".to_string())])
])
);
}
#[test]
fn parse_prim_type() {
let input = Sexpr::Symbol("Nat".to_string());
let parsed = parse_type(&input).unwrap();
assert_eq!(parsed, Type::Primitive(PrimitiveType::Nat))
}
#[test]
fn parse_simpl_arr_type() {
let input = Sexpr::List(vec![
Sexpr::Symbol("Nat".to_string()),
Sexpr::Symbol("->".to_string()),
Sexpr::Symbol("Nat".to_string()),
]);
let parsed = parse_type(&input).unwrap();
assert_eq!(parsed, Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
}
#[test]
fn parse_apply_arr_type() {
let input = Sexpr::List(vec![
Sexpr::List(vec![
Sexpr::Symbol("Nat".to_string()),
Sexpr::Symbol("->".to_string()),
Sexpr::Symbol("Nat".to_string()),
]),
Sexpr::Symbol("->".to_string()),
Sexpr::Symbol("Nat".to_string()),
Sexpr::Symbol("->".to_string()),
Sexpr::Symbol("Nat".to_string()),
]);
let parsed = parse_type(&input).unwrap();
assert_eq!(
parsed,
Type::arrow(
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)
)
)
}
#[test]
fn parse_abstraction() {
let input = "(\\ (x (Nat -> Nat)) (x 5))";
let parsed = parse(input).unwrap();
assert_eq!(
parsed,
Ast::Abstraction(
"x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5)))
))
)
)
}
#[test]
fn parse_application() {
let input = "((add 5) 6)";
let parsed = parse(input).unwrap();
assert_eq!(
parsed,
Ast::Application(
Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5)))
)),
Box::new(Ast::Constant(Constant::Nat(6)))
)
)
}
Loading…
Cancel
Save