From 766da2593e86957833eb4f701f77ca0b21ce5e16 Mon Sep 17 00:00:00 2001 From: Avery Date: Mon, 17 Mar 2025 21:51:27 +0100 Subject: [PATCH] Inference by unification --- src/exec/builtins.rs | 10 +- src/exec/mod.rs | 10 +- src/exec/test.rs | 12 +- src/inference/mod.rs | 70 +------- src/inference/recursive.rs | 59 +++++++ src/inference/test.rs | 176 +++++++++++++++++-- src/inference/unification.rs | 329 +++++++++++++++++++++++++++++++++++ src/lib.rs | 4 +- src/main.rs | 6 +- src/multi_map.rs | 86 +++++++++ src/parse/mod.rs | 11 +- src/parse/test.rs | 28 +-- src/parser.lalrpop | 3 +- src/types/mod.rs | 173 +++++++++++++++--- src/vec_map.rs | 8 + 15 files changed, 846 insertions(+), 139 deletions(-) create mode 100644 src/inference/recursive.rs create mode 100644 src/inference/unification.rs create mode 100644 src/multi_map.rs diff --git a/src/exec/builtins.rs b/src/exec/builtins.rs index 419d842..3547702 100644 --- a/src/exec/builtins.rs +++ b/src/exec/builtins.rs @@ -9,11 +9,11 @@ use super::DeBrujinAst; #[derive(Clone)] pub enum DeBrujinBuiltInAst { - Abstraction(Ident, TaggedType, Box), // \:1.2 - Application(Box, Box), // 0 1 - FreeVariable(String), // x - BoundVariable(usize), // 1 - Constant(Constant), // true | false | n + Abstraction(Ident, Option, Box), // \:1.2 + Application(Box, Box), // 0 1 + FreeVariable(String), // x + BoundVariable(usize), // 1 + Constant(Constant), // true | false | n Builtin(Rc), } diff --git a/src/exec/mod.rs b/src/exec/mod.rs index c8a746e..ef400ed 100644 --- a/src/exec/mod.rs +++ b/src/exec/mod.rs @@ -48,11 +48,11 @@ impl Ast { #[derive(Debug, Clone, PartialEq)] pub enum DeBrujinAst { - Abstraction(Ident, TaggedType, Box), // \:1.2 - Application(Box, Box), // 0 1 - FreeVariable(String), // x - BoundVariable(usize), // 1 - Constant(Constant), // true | false | n + Abstraction(Ident, Option, Box), // \:1.2 + Application(Box, Box), // 0 1 + FreeVariable(String), // x + BoundVariable(usize), // 1 + Constant(Constant), // true | false | n } impl Into for DeBrujinAst { diff --git a/src/exec/test.rs b/src/exec/test.rs index 7552893..341f1b2 100644 --- a/src/exec/test.rs +++ b/src/exec/test.rs @@ -12,10 +12,10 @@ use super::builtins::Builtin; fn to_de_brujin_ast_simple() { let input = Ast::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(Ast::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(Ast::Variable("x".to_string())), )), ); @@ -24,10 +24,10 @@ fn to_de_brujin_ast_simple() { de_brujin, DBAst::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(DBAst::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(DBAst::BoundVariable(1)) )) ) @@ -39,7 +39,7 @@ fn de_brujin_beta_reduce() { let input = Ast::Application( Box::new(Ast::Abstraction( "x".to_string(), - Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into(), + Some(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()), Box::new(Ast::Application( Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Constant(Constant::Nat(5))), @@ -63,7 +63,7 @@ fn to_and_from_de_brujin_is_id() { let input = Ast::Application( Box::new(Ast::Abstraction( "x".to_string(), - Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into(), + Some(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()), Box::new(Ast::Application( Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Constant(Constant::Nat(5))), diff --git a/src/inference/mod.rs b/src/inference/mod.rs index a005a5e..54812e5 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -1,11 +1,11 @@ use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc}; -use crate::{ - DeBrujinAst, - parse::{Ast, Constant}, - types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, - vec_map::VecMap, -}; +use crate::types::{TaggedType, TypeTag}; + +pub mod recursive; +#[cfg(test)] +mod test; +pub mod unification; #[derive(Debug)] pub enum InferError { @@ -14,60 +14,6 @@ pub enum InferError { NotInContext, ExpectedConreteType, ExpectedTypeWithTag, -} - -#[cfg(test)] -mod test; - -pub fn infer_type( - gamma: &HashMap, - ast: DeBrujinAst, -) -> Result { - infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast) -} - -fn infer_type_debrujin_int( - gamma_free: &HashMap, - mut gamma_bound: Rc>, - ast: DeBrujinAst, -) -> Result { - match ast { - DeBrujinAst::Abstraction(_, arg_type, ast) => { - let gamma_ref = Rc::make_mut(&mut gamma_bound); - gamma_ref.map_keys(|i| *i += 1); - gamma_ref.insert(1, arg_type.clone().into()); - - let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?; - // TODO: Fix this hack - let typ = arg_type.make_arrow( - out_type - .to_concrete() - .ok_or(InferError::ExpectedConreteType)?, - ); - Ok(typ) - } - DeBrujinAst::Application(lhs, rhs) => { - let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?; - let Some((in_type, out_type)) = left_type.arrow() else { - return Err(InferError::NotAFunction); - }; - let Some(right_type) = - infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete() - else { - return Err(InferError::ExpectedConreteType); - }; - let out_type = match in_type.matches(&right_type) { - (true, None) => out_type, - (true, Some(a)) => out_type.specialize(&a, &right_type), - (false, _) => return Err(InferError::MismatchedType), - }; - Ok(out_type) - } - DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext), - // compiler bug if not present - DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()), - DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()), - DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()), - DeBrujinAst::Constant(_) => unreachable!(), - } + DoesNotFitTag(TypeTag, TaggedType), + ConfilictingBind, } diff --git a/src/inference/recursive.rs b/src/inference/recursive.rs new file mode 100644 index 0000000..6669ed1 --- /dev/null +++ b/src/inference/recursive.rs @@ -0,0 +1,59 @@ +use std::{collections::HashMap, rc::Rc}; + +use crate::{ + DeBrujinAst, + parse::Constant, + types::{Ident, PrimitiveType, TaggedType, Type}, + vec_map::VecMap, +}; + +use super::InferError; + +pub fn infer_type( + gamma: &HashMap, + ast: DeBrujinAst, +) -> Result { + infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast) +} + +fn infer_type_debrujin_int( + gamma_free: &HashMap, + mut gamma_bound: Rc>, + ast: DeBrujinAst, +) -> Result { + match ast { + DeBrujinAst::Abstraction(_, arg_type, ast) => { + let gamma_ref = Rc::make_mut(&mut gamma_bound); + gamma_ref.map_keys(|i| *i += 1); + gamma_ref.insert(1, arg_type.clone().unwrap().into()); + + let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?; + // TODO: Fix this hack + let typ = arg_type.unwrap().make_arrow(out_type); + Ok(typ) + } + DeBrujinAst::Application(lhs, rhs) => { + let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?; + let Some((in_type, out_type)) = left_type.arrow() else { + return Err(InferError::NotAFunction); + }; + let Some(right_type) = + infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete() + else { + return Err(InferError::ExpectedConreteType); + }; + let out_type = match in_type.matches(&right_type) { + (true, None) => out_type, + (true, Some(a)) => out_type.specialize(&a, &right_type), + (false, _) => return Err(InferError::MismatchedType), + }; + Ok(out_type) + } + DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext), + // compiler bug if not present + DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()), + DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()), + DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()), + DeBrujinAst::Constant(_) => unreachable!(), + } +} diff --git a/src/inference/test.rs b/src/inference/test.rs index 130ad96..a091d06 100644 --- a/src/inference/test.rs +++ b/src/inference/test.rs @@ -1,14 +1,21 @@ -use std::{collections::HashMap, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, rc::Rc}; use crate::{ DeBrujinAst, - inference::infer_type, + inference::{ + recursive::infer_type as infer_type_rec, unification::Constraints, + unification::infer_type as infer_type_uni, + }, + multi_map::MultiMap, parse::{Ast, Constant}, types::{PrimitiveType, TaggedType, Type, TypeTag}, + vec_map::VecMap, }; +use super::unification::{TypeVarAst, TypeVarCtx}; + #[test] -fn infer_add_nat() { +fn infer_add_nat_rec() { let ast: DeBrujinAst = Ast::Application( Box::new(Ast::Application( Box::new(Ast::Variable("add".to_string())), @@ -25,16 +32,16 @@ fn infer_add_nat() { TypeTag::Num, "a".to_string(), Box::new(TaggedType::Concrete(Type::Arrow( - Box::new(Type::Generic("a".to_string())), + Box::new(Type::TypeVariable("a".to_string())), Box::new(Type::Arrow( - Box::new(Type::Generic("a".to_string())), - Box::new(Type::Generic("a".to_string())), + Box::new(Type::TypeVariable("a".to_string())), + Box::new(Type::TypeVariable("a".to_string())), )), ))), ), ); - let infered = infer_type(&gamma, ast).unwrap(); + let infered = infer_type_rec(&gamma, ast).unwrap(); assert_eq!( infered, TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat)) @@ -42,7 +49,7 @@ fn infer_add_nat() { } #[test] -fn infer_add_nat_partial() { +fn infer_add_nat_partial_rec() { let ast: DeBrujinAst = Ast::Application( Box::new(Ast::Variable("add".to_string())), Box::new(Ast::Constant(Constant::Nat(5))), @@ -56,18 +63,163 @@ fn infer_add_nat_partial() { TypeTag::Num, "a".to_string(), Box::new(TaggedType::Concrete(Type::Arrow( - Box::new(Type::Generic("a".to_string())), + Box::new(Type::TypeVariable("a".to_string())), Box::new(Type::Arrow( - Box::new(Type::Generic("a".to_string())), - Box::new(Type::Generic("a".to_string())), + Box::new(Type::TypeVariable("a".to_string())), + Box::new(Type::TypeVariable("a".to_string())), )), ))), ), ); - let infered = infer_type(&gamma, ast).unwrap(); + let infered = infer_type_rec(&gamma, ast).unwrap(); assert_eq!( infered, TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)) ); } + +#[test] +fn infer_add_nat_uni() { + let ast: DeBrujinAst = Ast::Application( + Box::new(Ast::Application( + Box::new(Ast::Variable("add".to_string())), + Box::new(Ast::Constant(Constant::Nat(5))), + )), + Box::new(Ast::Constant(Constant::Nat(7))), + ) + .into(); + + let mut gamma = HashMap::new(); + gamma.insert( + "add".to_string(), + TaggedType::Concrete(Type::arrow( + PrimitiveType::Nat, + Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), + )), + ); + + let infered = infer_type_rec(&gamma, ast).unwrap(); + assert_eq!( + infered, + TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat)) + ); +} + +#[test] +fn infer_add_nat_partial_uni() { + let ast: DeBrujinAst = Ast::Application( + Box::new(Ast::Variable("add".to_string())), + Box::new(Ast::Constant(Constant::Nat(5))), + ) + .into(); + + let mut gamma = HashMap::new(); + gamma.insert( + "add".to_string(), + TaggedType::Concrete(Type::arrow( + PrimitiveType::Nat, + Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), + )), + ); + + let typ = infer_type_uni(&gamma, ast).unwrap(); + + assert_eq!( + typ, + TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)) + ); +} + +#[test] +fn infer_id_uni() { + let ast: DeBrujinAst = Ast::Abstraction( + "x".to_string(), + None, + Box::new(Ast::Variable("x".to_string())), + ) + .into(); + + let mut gamma = HashMap::new(); + let typ = infer_type_uni(&gamma, ast).unwrap(); + + assert_eq!( + typ, + TaggedType::Tagged( + TypeTag::Any, + "?typ_1".to_string(), + Box::new(TaggedType::Concrete(Type::arrow("?typ_1", "?typ_1"))) + ) + ); +} + +#[test] +fn subst_type_var() { + let typ = TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow("a", "?typ_4"))), + ); + + let subst = TaggedType::Tagged( + TypeTag::Any, + "b".to_string(), + Box::new(TaggedType::Concrete(Type::arrow("b", "b"))), + ); + + let typ = typ.subst_typevar("?typ_4", 0, subst).unwrap(); + + assert_eq!( + typ, + TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Tagged( + TypeTag::Any, + "b".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + "a", + Type::arrow("b", "b") + ))) + )) + ) + ); + + let typ = TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow("a", "?typ_4"))), + ); + + let subst = TaggedType::Tagged( + TypeTag::Any, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow("a", "a"))), + ); + + let typ = typ.subst_typevar("?typ_4", 0, subst).unwrap(); + + assert!(match typ { + TaggedType::Tagged(TypeTag::Num, a, tagged_type) => match *tagged_type { + TaggedType::Tagged(TypeTag::Any, b, tagged_type) => match *tagged_type { + TaggedType::Concrete(typ) => match typ { + Type::Arrow(lhs, rhs) => match (*lhs, *rhs) { + (Type::TypeVariable(a1), Type::Arrow(b1, b2)) if a1 == a => { + match (*b1, *b2) { + (Type::TypeVariable(b1), Type::TypeVariable(b2)) => { + (b1 == b) && (b2 == b) + } + _ => false, + } + } + _ => false, + }, + _ => false, + }, + _ => false, + }, + _ => false, + }, + _ => false, + }); +} diff --git a/src/inference/unification.rs b/src/inference/unification.rs new file mode 100644 index 0000000..5a552d7 --- /dev/null +++ b/src/inference/unification.rs @@ -0,0 +1,329 @@ +use std::{cell::RefCell, clone, collections::HashMap, rc::Rc}; + +use crate::{ + Ast, DeBrujinAst, + multi_map::MultiMap, + parse::Constant, + types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, + vec_map::VecMap, +}; + +use super::InferError; + +type TypeVar = TaggedType; + +pub struct TypeVarCtx { + counter: RefCell, +} + +impl TypeVarCtx { + pub fn new() -> Self { + Self { + counter: RefCell::new(0), + } + } + + pub fn get_var(&self) -> TaggedType { + let mut num = self.counter.borrow_mut(); + let res = format!("?typ_{num}"); + *num += 1; + Type::TypeVariable(res).into() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum TypeVarAst { + Abstraction(TypeVar, Ident, TaggedType, Box), // \0:1.2 + Application(TypeVar, Box, Box), // 0 1 + FreeVariable(TypeVar, Ident), // x + BoundVariable(TypeVar, usize), // 1 + Constant(TypeVar, Constant), // true | false | n +} + +pub type Constraints = MultiMap; + +pub(super) fn step_1( + ast: DeBrujinAst, + gamma_free: &HashMap, + mut gamma_bound: Rc>, + constraints: Rc>, + ctx: &TypeVarCtx, +) -> Result<(TypeVarAst, TypeVar), InferError> { + match ast { + DeBrujinAst::Abstraction(i, Some(typ), ast) => { + let var = ctx.get_var(); + + let gamma_ref = Rc::make_mut(&mut gamma_bound); + gamma_ref.map_keys(|i| *i += 1); + gamma_ref.insert(1, typ.clone()); + + let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?; + // RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var)); + Ok(( + TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)), + var, + )) + } + DeBrujinAst::Abstraction(i, None, ast) => { + let var = ctx.get_var(); + + let typ = ctx.get_var(); + let gamma_ref = Rc::make_mut(&mut gamma_bound); + gamma_ref.map_keys(|i| *i += 1); + gamma_ref.insert(1, typ.clone()); + + let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?; + // RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var)); + Ok(( + TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)), + var, + )) + } + DeBrujinAst::Application(lhs, rhs) => { + let var = ctx.get_var(); + let (lhs, lhs_var) = step_1( + *lhs, + gamma_free, + gamma_bound.clone(), + constraints.clone(), + ctx, + )?; + let (rhs, rhs_var) = step_1( + *rhs, + gamma_free, + gamma_bound.clone(), + constraints.clone(), + ctx, + )?; + RefCell::borrow_mut(&constraints).insert(lhs_var, rhs_var.make_arrow(var.clone())); + Ok(( + TypeVarAst::Application(var.clone(), Box::new(lhs), Box::new(rhs)), + var, + )) + } + DeBrujinAst::FreeVariable(v) => { + let var = ctx.get_var(); + let typ = gamma_free + .get(&v) + .cloned() + .ok_or(InferError::NotInContext)?; + + RefCell::borrow_mut(&constraints).insert(var.clone(), typ); + + Ok((TypeVarAst::FreeVariable(var.clone(), v), var)) + } + DeBrujinAst::BoundVariable(i) => { + let var = ctx.get_var(); + let typ = gamma_bound + .get(&i) + .cloned() + .ok_or(InferError::NotInContext)?; + + RefCell::borrow_mut(&constraints).insert(var.clone(), typ); + + Ok((TypeVarAst::BoundVariable(var.clone(), i), var)) + } + DeBrujinAst::Constant(constant) => { + let var = ctx.get_var(); + let typ = match constant { + Constant::Nat(_) => Type::Primitive(PrimitiveType::Nat), + Constant::Float(_) => Type::Primitive(PrimitiveType::Float), + Constant::Bool(_) => Type::Primitive(PrimitiveType::Bool), + } + .into(); + + RefCell::borrow_mut(&constraints).insert(var.clone(), typ); + + Ok((TypeVarAst::Constant(var.clone(), constant), var)) + } + } +} + +pub(super) fn step_2( + mut constraints: Constraints, +) -> Result Result>>, InferError> { + if let Some((s, t)) = constraints.pop() { + if s == t { + step_2(constraints) + } else if s.type_var().is_some_and(|x| !t.name_used(&x)) { + let Some(x) = s.type_var() else { + unreachable!() + }; + constraints.try_map(|(k, v)| { + Ok(( + k.subst_typevar(&x, 0, t.clone())?, + v.subst_typevar(&x, 0, t.clone())?, + )) + })?; + let subst = step_2(constraints)?; + Ok(Some(subst_comp(x, t, subst))) + } else if t.type_var().is_some_and(|x| !s.name_used(&x)) { + let Some(x) = t.type_var() else { + unreachable!() + }; + constraints.try_map(|(k, v)| { + Ok(( + k.subst_typevar(&x, 0, s.clone())?, + v.subst_typevar(&x, 0, s.clone())?, + )) + })?; + + let subst = step_2(constraints)?; + Ok(Some(subst_comp(x, t, subst))) + } else if let (Some((s_lhs, s_rhs)), Some((t_lhs, t_rhs))) = (s.arrow(), t.arrow()) { + constraints.insert(s_lhs, t_lhs); + constraints.insert(s_rhs, t_rhs); + step_2(constraints) + } else { + panic!() + } + } else { + Ok(None) + } +} + +impl TypeVarAst { + pub fn subst(self, var: &str, subst: TaggedType) -> Result { + match self { + TypeVarAst::Abstraction(tagged_type1, ident, tagged_type2, ast) => { + Ok(TypeVarAst::Abstraction( + tagged_type1.subst_typevar(var, 0, subst.clone())?, + ident, + tagged_type2.subst_typevar(var, 0, subst.clone())?, + Box::new(ast.subst(var, subst)?), + )) + } + TypeVarAst::Application(tagged_type, lhs, rhs) => Ok(TypeVarAst::Application( + tagged_type.subst_typevar(var, 0, subst.clone())?, + Box::new(lhs.subst(var, subst.clone())?), + Box::new(rhs.subst(var, subst)?), + )), + TypeVarAst::FreeVariable(tagged_type, x) => Ok(TypeVarAst::FreeVariable( + tagged_type.subst_typevar(var, 0, subst)?, + x, + )), + TypeVarAst::BoundVariable(tagged_type, i) => Ok(TypeVarAst::BoundVariable( + tagged_type.subst_typevar(var, 0, subst)?, + i, + )), + TypeVarAst::Constant(tagged_type, constant) => Ok(TypeVarAst::Constant( + tagged_type.subst_typevar(var, 0, subst)?, + constant, + )), + } + } +} + +pub fn subst_comp<'a, F>( + var: String, + subst: TaggedType, + then: Option, +) -> Box Result + 'a> +where + F: FnOnce(TypeVarAst) -> Result + 'a, +{ + Box::new(move |ast: TypeVarAst| -> Result { + let ast = ast.subst(&var, subst)?; + if let Some(f) = then { f(ast) } else { Ok(ast) } + }) +} + +pub fn infer_type( + gamma: &HashMap, + ast: DeBrujinAst, +) -> Result { + let gamma_bound = Rc::new(VecMap::new()); + let constraints = Rc::new(RefCell::new(MultiMap::new())); + let ctx = TypeVarCtx::new(); + + let (ast, _) = step_1(ast, gamma, gamma_bound, constraints.clone(), &ctx)?; + constraints.clone(); + let res = step_2(constraints.take())?.unwrap(); + + let ast = res(ast)?; + + fn get_type(ast: TypeVarAst) -> TaggedType { + match ast { + TypeVarAst::Abstraction(_, _, typ, ast) => typ.make_arrow(get_type(*ast)), + TypeVarAst::Application(tagged_type, _, _) => tagged_type, + TypeVarAst::FreeVariable(tagged_type, _) => tagged_type, + TypeVarAst::BoundVariable(tagged_type, _) => tagged_type, + TypeVarAst::Constant(tagged_type, constant) => tagged_type, + } + } + + let typ = get_type(ast); + + let mut typ = typ; + + for free_var in typ.free_vars() { + typ = TaggedType::Tagged(TypeTag::Any, free_var, Box::new(typ)); + } + + Ok(typ.normalise()) +} + +impl TaggedType { + pub fn subst_typevar( + self, + var: &str, + diversifier: usize, + mut subst: TaggedType, + ) -> Result { + match self { + TaggedType::Tagged(t, i, tagged_type) => { + if subst.name_used(&i) { + subst.map_name(|f| { + if *f == i { + *f = format!("{i}+{diversifier}") + } + }); + } + Ok(TaggedType::Tagged( + t, + i, + Box::new(tagged_type.subst_typevar(var, diversifier + 1, subst)?), + )) + } + TaggedType::Concrete(t) => subst.subst_into(var, diversifier, t), + } + } + + fn subst_into( + self, + var: &str, + diversifier: usize, + mut target: Type, + ) -> Result { + match self { + TaggedType::Tagged(t, i, tagged_type) => { + if target.name_used(&i) { + target.map_name(&|f| { + if *f == i { + *f = format!("{i}+{diversifier}") + } + }); + } + Ok(TaggedType::Tagged( + t, + i, + Box::new(tagged_type.subst_into(var, diversifier + 1, target)?), + )) + } + TaggedType::Concrete(c) => Ok(TaggedType::Concrete(target.subst_typevar(var, &c)?)), + } + } +} + +impl Type { + pub fn subst_typevar(self, var: &str, subst: &Type) -> Result { + match self { + Type::TypeVariable(v) if v == var => Ok(subst.clone()), + Type::Arrow(lhs, rhs) => Ok(Type::Arrow( + Box::new(lhs.subst_typevar(var, subst)?), + Box::new(rhs.subst_typevar(var, subst)?), + )), + t => Ok(t), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 6b18129..ef41667 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,11 +4,13 @@ use std::fmt::Display; mod exec; mod inference; +mod multi_map; mod parse; mod types; mod vec_map; pub use exec::{Builtin, DeBrujinAst, builtin}; -pub use inference::infer_type; +pub use inference::recursive::infer_type as infer_type_rec; +pub use inference::unification::infer_type as infer_type_uni; use lalrpop_util::lalrpop_mod; pub use parse::{Ast, parse_ast_str, parse_type_str}; diff --git a/src/main.rs b/src/main.rs index fe364d7..929cfd5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use std::{ }; use stlc_type_inference::{ - Ast, Builtin, DeBrujinAst, builtin, infer_type, parse_ast_str, parse_type_str, + Ast, Builtin, DeBrujinAst, builtin, infer_type_uni, parse_ast_str, parse_type_str, }; macro_rules! repl_err { @@ -35,7 +35,7 @@ fn main() { if let Some((cmd, expr)) = tl.split_once(' ') { match cmd { "t" => match parse_ast_str(expr) { - Ok(a) => match infer_type(&gamma, a.into()) { + Ok(a) => match infer_type_uni(&gamma, a.into()) { Ok(t) => println!("{t}"), Err(e) => repl_err!("Could not infer type {e:?}"), }, @@ -77,7 +77,7 @@ fn main() { }; let ast: DeBrujinAst = ast.into(); - let typ = match infer_type(&gamma, ast.clone()) { + let typ = match infer_type_uni(&gamma, ast.clone()) { Ok(t) => t, Err(e) => repl_err!("Could not infer type {e:?}"), }; diff --git a/src/multi_map.rs b/src/multi_map.rs new file mode 100644 index 0000000..b0a095a --- /dev/null +++ b/src/multi_map.rs @@ -0,0 +1,86 @@ +use std::{fmt::Debug, mem}; + +pub struct MultiMap { + map: Vec<(K, V)>, +} + +impl MultiMap { + pub fn new() -> Self { + Self { map: Vec::new() } + } + + pub fn insert(&mut self, key: K, val: V) { + self.map.push((key, val)); + } + + pub fn get(&self, key: &K) -> Vec<&V> { + self.map + .iter() + .filter_map(|(k, v)| if k == key { Some(v) } else { None }) + .collect() + } + + pub fn map_keys(&mut self, f: F) { + self.map.iter_mut().for_each(|(k, _)| f(k)) + } + + pub fn map(&mut self, f: F) { + self.map.iter_mut().for_each(f); + } + + pub fn try_map Result<(K, V), E>>(&mut self, f: F) -> Result<(), E> { + let vec = mem::take(&mut self.map); + let vec = vec.into_iter().map(f).collect::, _>>()?; + self.map = vec; + Ok(()) + } + + pub fn find bool>(&self, f: F) -> Option<&(K, V)> { + self.map.iter().find(f) + } + + pub fn find_remove bool>(&mut self, mut f: F) -> Option<(K, V)> { + let idx = self + .map + .iter() + .enumerate() + .find_map(|(n, ref kv)| if f(kv) { Some(n) } else { None })?; + Some(self.map.swap_remove(idx)) + } + + pub fn len(&self) -> usize { + self.map.len() + } + + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } + + pub fn pop(&mut self) -> Option<(K, V)> { + self.map.pop() + } +} + +impl Default for MultiMap { + fn default() -> Self { + Self { + map: Vec::default(), + } + } +} + +impl Clone for MultiMap { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + } + } +} + +impl Debug for MultiMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_map() + .entries(self.map.iter().map(|&(ref k, ref v)| (k, v))) + .finish() + } +} diff --git a/src/parse/mod.rs b/src/parse/mod.rs index 9c3eee5..0e93bed 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -40,10 +40,10 @@ pub enum Constant { #[derive(Debug, Clone, PartialEq)] pub enum Ast { - Abstraction(Ident, TaggedType, Box), // \0:1.2 - Application(Box, Box), // 0 1 - Variable(Ident), // x - Constant(Constant), // true | false | n + Abstraction(Ident, Option, Box), // \0:1.2 + Application(Box, Box), // 0 1 + Variable(Ident), // x + Constant(Constant), // true | false | n } pub fn parse_ast_str(src: &str) -> Result> { @@ -84,7 +84,8 @@ impl Type { impl Display for Ast { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Ast::Abstraction(var, typ, ast) => write!(f, "(\\{var}:{typ}.{ast})"), + Ast::Abstraction(var, Some(typ), ast) => write!(f, "(\\{var}:{typ}.{ast})"), + Ast::Abstraction(var, None, ast) => write!(f, "(\\{var}.{ast})"), Ast::Application(lhs, rhs) => write!(f, "{lhs} {rhs}"), Ast::Variable(v) => write!(f, "{v}"), Ast::Constant(constant) => write!(f, "{constant}"), diff --git a/src/parse/test.rs b/src/parse/test.rs index a096ec2..cf131b8 100644 --- a/src/parse/test.rs +++ b/src/parse/test.rs @@ -93,7 +93,7 @@ fn parse_abstraction() { ast, Ast::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(Ast::Variable("x".to_string())) ) ); @@ -103,10 +103,10 @@ fn parse_abstraction() { ast, Ast::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(Ast::Abstraction( "y".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(Ast::Variable("x".to_string())) )) ) @@ -119,18 +119,18 @@ fn parse_abstraction() { ast, Ast::Abstraction( "x".to_string(), - TaggedType::Tagged( + Some(TaggedType::Tagged( TypeTag::Any, "a".to_string(), Box::new(TaggedType::Concrete("a".into())) - ), + )), Box::new(Ast::Abstraction( "y".to_string(), - TaggedType::Tagged( + Some(TaggedType::Tagged( TypeTag::Any, "b".to_string(), Box::new(TaggedType::Concrete("b".into())) - ), + )), Box::new(Ast::Variable("x".to_string())) )) ) @@ -146,7 +146,7 @@ fn parse_application() { Ast::Application( Box::new(Ast::Abstraction( "x".to_string(), - Type::Primitive(PrimitiveType::Nat).into(), + Some(Type::Primitive(PrimitiveType::Nat).into()), Box::new(Ast::Variable("x".to_string())) )), Box::new(Ast::Constant(Constant::Nat(5))) @@ -212,18 +212,18 @@ fn parse_application() { Box::new(Ast::Application( Box::new(Ast::Abstraction( "x".to_string(), - TaggedType::Tagged( + Some(TaggedType::Tagged( TypeTag::Any, "a".to_string(), Box::new(TaggedType::Concrete(Type::arrow("a", "a"))) - ), + )), Box::new(Ast::Abstraction( "y".to_string(), - TaggedType::Tagged( + Some(TaggedType::Tagged( TypeTag::Num, "a".to_string(), Box::new(TaggedType::Concrete("a".into())) - ), + )), Box::new(Ast::Application( Box::new(Ast::Application( Box::new(Ast::Variable("add".to_string())), @@ -238,11 +238,11 @@ fn parse_application() { )), Box::new(Ast::Abstraction( "x".to_string(), - TaggedType::Tagged( + Some(TaggedType::Tagged( TypeTag::Any, "a".to_string(), Box::new(TaggedType::Concrete("a".into())) - ), + )), Box::new(Ast::Variable("x".to_string())) )) )), diff --git a/src/parser.lalrpop b/src/parser.lalrpop index e66b50c..0645e2b 100644 --- a/src/parser.lalrpop +++ b/src/parser.lalrpop @@ -45,7 +45,8 @@ extern { pub Ast: Ast = { Term => <>, - r"\" ":" "." => Ast::Abstraction(x, t, Box::new(ast)), + r"\" ":" "." => Ast::Abstraction(x, Some(t), Box::new(ast)), + r"\" "." => Ast::Abstraction(x, None, Box::new(ast)), }; diff --git a/src/types/mod.rs b/src/types/mod.rs index 4a5766f..3d2379f 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,29 +1,35 @@ -use std::{fmt::Display, str}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Display, + mem, + rc::Rc, + str, +}; pub type Ident = String; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum PrimitiveType { Nat, Bool, Float, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TypeTag { Num, Any, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TaggedType { Tagged(TypeTag, Ident, Box), Concrete(Type), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Type { - Generic(Ident), // a + TypeVariable(Ident), // a Primitive(PrimitiveType), // Bool | Nat Arrow(Box, Box), // 0 -> 1 } @@ -34,6 +40,18 @@ impl From for TaggedType { } } +impl TypeTag { + // If one tag tightens the other return that tag, otherwise none, + fn tightens(self, other: Self) -> Option { + match (self, other) { + (TypeTag::Num, TypeTag::Num) => Some(Self::Num), + (TypeTag::Num, TypeTag::Any) => Some(Self::Num), + (TypeTag::Any, TypeTag::Num) => Some(Self::Num), + (TypeTag::Any, TypeTag::Any) => Some(Self::Any), + } + } +} + impl TaggedType { pub fn r#type(&self) -> &Type { match self { @@ -52,13 +70,7 @@ impl TaggedType { pub fn to_concrete(self) -> Option { match self { TaggedType::Tagged(type_tag, _, tagged_type) => None, - TaggedType::Concrete(t) => { - if t.is_concrete() { - Some(t) - } else { - None - } - } + TaggedType::Concrete(t) => Some(t), } } @@ -92,12 +104,21 @@ impl TaggedType { } } - pub fn make_arrow(self, rhs: Type) -> Self { + pub fn make_arrow(self, rhs: TaggedType) -> Self { match self { TaggedType::Tagged(type_tag, ident, tagged_type) => { TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow(rhs))) } - TaggedType::Concrete(t) => Type::arrow(t, rhs).into(), + TaggedType::Concrete(t) => rhs.make_arrow_lhs(t), + } + } + + fn make_arrow_lhs(self, lhs: Type) -> Self { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow_lhs(lhs))) + } + TaggedType::Concrete(t) => Type::arrow(lhs, t).into(), } } @@ -113,20 +134,85 @@ impl TaggedType { } } - fn name_used(&self, ident: &str) -> bool { + pub fn map_name(&mut self, f: F) { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + f(ident); + tagged_type.map_name(f); + } + TaggedType::Concrete(t) => t.map_name(&f), + } + } + + pub fn type_var(&self) -> Option { + match self { + TaggedType::Tagged(type_tag, _, tagged_type) => tagged_type.type_var(), + TaggedType::Concrete(c) => c.type_var(), + } + } + + pub fn free_vars(&self) -> HashSet { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + let mut vars = tagged_type.free_vars(); + vars.retain(|v| v != ident); + vars + } + TaggedType::Concrete(c) => c.type_vars(), + } + } + + pub fn name_used(&self, ident: &str) -> bool { match self { TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident), TaggedType::Concrete(c) => c.name_used(ident), } } + pub fn is_concrete(&self) -> bool { + match self { + TaggedType::Tagged(type_tag, _, tagged_type) => tagged_type.is_concrete(), + TaggedType::Concrete(c) => c.is_concrete(), + } + } + + pub fn normalise(self) -> TaggedType { + let only_used = self.clear_unused_names(); + fn dedup(this: TaggedType, mut used: Rc>) -> TaggedType { + match this { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + if !used.contains_key(&ident) { + Rc::make_mut(&mut used).insert(ident, type_tag); + dedup(*tagged_type, used) + } else { + let Some(tag) = TypeTag::tightens( + Rc::make_mut(&mut used).remove(&ident).unwrap(), + type_tag, + ) else { + todo!() + }; + Rc::make_mut(&mut used).insert(ident.clone(), tag.clone()); + dedup(*tagged_type, used) + } + } + TaggedType::Concrete(c) => mem::take(Rc::make_mut(&mut used)) + .into_iter() + .fold(TaggedType::Concrete(c), |acc, (i, t)| { + TaggedType::Tagged(t, i, Box::new(acc)) + }), + } + } + + dedup(only_used, Rc::new(HashMap::new())) + } + fn clear_unused_names(self) -> TaggedType { match self { TaggedType::Tagged(type_tag, ident, tagged_type) => { if tagged_type.name_used(&ident) { - TaggedType::Tagged(type_tag, ident, tagged_type) + TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.clear_unused_names())) } else { - *tagged_type + tagged_type.clear_unused_names() } } t => t, @@ -142,7 +228,7 @@ impl From for Type { impl From<&str> for Type { fn from(value: &str) -> Self { - Self::Generic(value.to_string()) + Self::TypeVariable(value.to_string()) } } @@ -153,7 +239,7 @@ impl Type { pub fn is_concrete(&self) -> bool { match self { - Type::Generic(_) => false, + Type::TypeVariable(_) => false, Type::Primitive(primitive_type) => true, Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(), } @@ -161,7 +247,7 @@ impl Type { pub fn has_tag(&self, tag: &TypeTag) -> bool { match self { - Type::Generic(_) => false, + Type::TypeVariable(_) => false, Type::Primitive(primitive_type) => match (primitive_type, tag) { (_, TypeTag::Any) => true, (PrimitiveType::Nat, TypeTag::Num) => true, @@ -171,17 +257,54 @@ impl Type { } } - fn name_used(&self, ident: &str) -> bool { + pub fn map_name(&mut self, f: &F) { + match self { + Type::TypeVariable(v) => f(v), + Type::Arrow(lhs, rhs) => { + lhs.map_name(f); + rhs.map_name(f); + } + _ => {} + } + } + + pub fn name_used(&self, ident: &str) -> bool { match self { - Type::Generic(i) if *i == ident => true, + Type::TypeVariable(i) if *i == ident => true, Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident), _ => false, } } + pub fn type_var(&self) -> Option { + match self { + Type::TypeVariable(v) => Some(v.clone()), + Type::Primitive(primitive_type) => None, + Type::Arrow(_, _) => None, + } + } + + pub fn type_vars(&self) -> HashSet { + match self { + Type::TypeVariable(v) => { + let mut set = HashSet::new(); + set.insert(v.to_string()); + set + } + Type::Primitive(primitive_type) => HashSet::new(), + Type::Arrow(lhs, rhs) => { + let mut vars = lhs.type_vars(); + for var in rhs.type_vars().into_iter() { + vars.insert(var); + } + vars + } + } + } + fn specialize(self, ident: &str, typ: &Type) -> Type { match self { - Type::Generic(i) if i == ident => typ.clone(), + Type::TypeVariable(i) if i == ident => typ.clone(), Type::Arrow(lhs, rhs) => Type::Arrow( Box::new(lhs.specialize(ident, typ)), Box::new(rhs.specialize(ident, typ)), @@ -214,7 +337,7 @@ impl Display for TaggedType { impl Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Type::Generic(i) => write!(f, "{i}"), + Type::TypeVariable(i) => write!(f, "{i}"), Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), } diff --git a/src/vec_map.rs b/src/vec_map.rs index dde0f68..646b24c 100644 --- a/src/vec_map.rs +++ b/src/vec_map.rs @@ -29,6 +29,14 @@ impl VecMap { } } +impl Default for VecMap { + fn default() -> Self { + Self { + map: Vec::default(), + } + } +} + impl Clone for VecMap { fn clone(&self) -> Self { Self {