diff --git a/src/inference/mod.rs b/src/inference/mod.rs index 54812e5..9d9d394 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -16,4 +16,5 @@ pub enum InferError { ExpectedTypeWithTag, DoesNotFitTag(TypeTag, TaggedType), ConfilictingBind, + ConfilictingTags, } diff --git a/src/inference/recursive.rs b/src/inference/recursive.rs index 6669ed1..c6101ea 100644 --- a/src/inference/recursive.rs +++ b/src/inference/recursive.rs @@ -34,7 +34,7 @@ fn infer_type_debrujin_int( } DeBrujinAst::Application(lhs, rhs) => { let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?; - let Some((in_type, out_type)) = left_type.arrow() else { + let Some((in_type, out_type)) = left_type.split_arrow() else { return Err(InferError::NotAFunction); }; let Some(right_type) = diff --git a/src/inference/test.rs b/src/inference/test.rs index a091d06..f2a76a4 100644 --- a/src/inference/test.rs +++ b/src/inference/test.rs @@ -79,6 +79,29 @@ fn infer_add_nat_partial_rec() { ); } +#[test] +fn subtype_constraints() { + let typ = TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(Type::arrow("a", "a").into()), + ); + + let mut st_constraints = Rc::new(RefCell::new(MultiMap::new())); + let ctx = TypeVarCtx::new(); + + let res = typ.make_constraints(st_constraints.clone(), &ctx); + + assert_eq!( + st_constraints + .borrow() + .get(&Type::TypeVariable("?typ_0".to_string())), + vec![&TypeTag::Num] + ); + + assert_eq!(res, Type::arrow("?typ_0", "?typ_0")); +} + #[test] fn infer_add_nat_uni() { let ast: DeBrujinAst = Ast::Application( @@ -93,13 +116,17 @@ fn infer_add_nat_uni() { let mut gamma = HashMap::new(); gamma.insert( "add".to_string(), - TaggedType::Concrete(Type::arrow( - PrimitiveType::Nat, - Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), - )), + TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + "a", + Type::arrow("a", "a"), + ))), + ), ); - let infered = infer_type_rec(&gamma, ast).unwrap(); + let infered = infer_type_uni(&gamma, ast).unwrap(); assert_eq!( infered, TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat)) diff --git a/src/inference/unification.rs b/src/inference/unification.rs index 5a552d7..d79c1e4 100644 --- a/src/inference/unification.rs +++ b/src/inference/unification.rs @@ -1,4 +1,10 @@ -use std::{cell::RefCell, clone, collections::HashMap, rc::Rc}; +use std::{ + cell::RefCell, + clone, + collections::HashMap, + fmt::{Debug, DebugMap}, + rc::Rc, +}; use crate::{ Ast, DeBrujinAst, @@ -10,7 +16,7 @@ use crate::{ use super::InferError; -type TypeVar = TaggedType; +type TypeVar = Type; pub struct TypeVarCtx { counter: RefCell, @@ -23,45 +29,59 @@ impl TypeVarCtx { } } - pub fn get_var(&self) -> TaggedType { + pub fn get_var(&self) -> Type { let mut num = self.counter.borrow_mut(); let res = format!("?typ_{num}"); *num += 1; - Type::TypeVariable(res).into() + Type::TypeVariable(res) } } #[derive(Debug, Clone, PartialEq)] pub enum TypeVarAst { - Abstraction(TypeVar, Ident, TaggedType, Box), // \0:1.2 - Application(TypeVar, Box, Box), // 0 1 - FreeVariable(TypeVar, Ident), // x - BoundVariable(TypeVar, usize), // 1 - Constant(TypeVar, Constant), // true | false | n + Abstraction(TypeVar, Ident, Type, Box), // \0:1.2 + Application(TypeVar, Box, Box), // 0 1 + FreeVariable(TypeVar, Ident), // x + BoundVariable(TypeVar, usize), // 1 + Constant(TypeVar, Constant), // true | false | n } -pub type Constraints = MultiMap; +pub type Constraints = MultiMap; pub(super) fn step_1( ast: DeBrujinAst, gamma_free: &HashMap, - mut gamma_bound: Rc>, - constraints: Rc>, + mut gamma_bound: Rc>, + eq_constraints: Rc>>, + st_constraints: Rc>>, ctx: &TypeVarCtx, ) -> Result<(TypeVarAst, TypeVar), InferError> { match ast { DeBrujinAst::Abstraction(i, Some(typ), ast) => { - let var = ctx.get_var(); + // let var = ctx.get_var(); + + let typ = typ.make_constraints(st_constraints.clone(), ctx); let gamma_ref = Rc::make_mut(&mut gamma_bound); gamma_ref.map_keys(|i| *i += 1); gamma_ref.insert(1, typ.clone()); - let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?; - // RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var)); + let (ast, rhs_var) = step_1( + *ast, + gamma_free, + gamma_bound, + eq_constraints.clone(), + st_constraints.clone(), + ctx, + )?; Ok(( - TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)), - var, + TypeVarAst::Abstraction( + Type::arrow(typ.clone(), rhs_var.clone()), + i, + typ.clone(), + Box::new(ast), + ), + Type::arrow(typ, rhs_var.clone()), )) } DeBrujinAst::Abstraction(i, None, ast) => { @@ -72,11 +92,22 @@ pub(super) fn step_1( gamma_ref.map_keys(|i| *i += 1); gamma_ref.insert(1, typ.clone()); - let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?; - // RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var)); + let (ast, rhs_var) = step_1( + *ast, + gamma_free, + gamma_bound, + eq_constraints.clone(), + st_constraints.clone(), + ctx, + )?; Ok(( - TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)), - var, + TypeVarAst::Abstraction( + Type::arrow(typ.clone(), rhs_var.clone()), + i, + typ.clone(), + Box::new(ast), + ), + Type::arrow(typ, rhs_var.clone()), )) } DeBrujinAst::Application(lhs, rhs) => { @@ -85,17 +116,19 @@ pub(super) fn step_1( *lhs, gamma_free, gamma_bound.clone(), - constraints.clone(), + eq_constraints.clone(), + st_constraints.clone(), ctx, )?; let (rhs, rhs_var) = step_1( *rhs, gamma_free, gamma_bound.clone(), - constraints.clone(), + eq_constraints.clone(), + st_constraints.clone(), ctx, )?; - RefCell::borrow_mut(&constraints).insert(lhs_var, rhs_var.make_arrow(var.clone())); + RefCell::borrow_mut(&eq_constraints).insert(lhs_var, Type::arrow(rhs_var, var.clone())); Ok(( TypeVarAst::Application(var.clone(), Box::new(lhs), Box::new(rhs)), var, @@ -108,7 +141,9 @@ pub(super) fn step_1( .cloned() .ok_or(InferError::NotInContext)?; - RefCell::borrow_mut(&constraints).insert(var.clone(), typ); + let typ = typ.make_constraints(st_constraints.clone(), ctx); + + RefCell::borrow_mut(&eq_constraints).insert(var.clone(), typ); Ok((TypeVarAst::FreeVariable(var.clone(), v), var)) } @@ -119,7 +154,7 @@ pub(super) fn step_1( .cloned() .ok_or(InferError::NotInContext)?; - RefCell::borrow_mut(&constraints).insert(var.clone(), typ); + RefCell::borrow_mut(&eq_constraints).insert(var.clone(), typ); Ok((TypeVarAst::BoundVariable(var.clone(), i), var)) } @@ -132,48 +167,40 @@ pub(super) fn step_1( } .into(); - RefCell::borrow_mut(&constraints).insert(var.clone(), typ); + RefCell::borrow_mut(&eq_constraints).insert(var.clone(), typ); Ok((TypeVarAst::Constant(var.clone(), constant), var)) } } } -pub(super) fn step_2( - mut constraints: Constraints, -) -> Result Result>>, InferError> { - if let Some((s, t)) = constraints.pop() { +pub(super) fn step_2(mut eq_constraints: Constraints) -> Result, InferError> { + if let Some((s, t)) = eq_constraints.pop() { if s == t { - step_2(constraints) + step_2(eq_constraints) } else if s.type_var().is_some_and(|x| !t.name_used(&x)) { let Some(x) = s.type_var() else { unreachable!() }; - constraints.try_map(|(k, v)| { - Ok(( - k.subst_typevar(&x, 0, t.clone())?, - v.subst_typevar(&x, 0, t.clone())?, - )) - })?; - let subst = step_2(constraints)?; - Ok(Some(subst_comp(x, t, subst))) + eq_constraints + .try_map(|(k, v)| Ok((k.subst_typevar(&x, &t)?, v.subst_typevar(&x, &t)?)))?; + let subst = step_2(eq_constraints)?; + Ok(Some(SubstFn::new(x, t, subst))) } else if t.type_var().is_some_and(|x| !s.name_used(&x)) { let Some(x) = t.type_var() else { unreachable!() }; - constraints.try_map(|(k, v)| { - Ok(( - k.subst_typevar(&x, 0, s.clone())?, - v.subst_typevar(&x, 0, s.clone())?, - )) - })?; - - let subst = step_2(constraints)?; - Ok(Some(subst_comp(x, t, subst))) - } else if let (Some((s_lhs, s_rhs)), Some((t_lhs, t_rhs))) = (s.arrow(), t.arrow()) { - constraints.insert(s_lhs, t_lhs); - constraints.insert(s_rhs, t_rhs); - step_2(constraints) + eq_constraints + .try_map(|(k, v)| Ok((k.subst_typevar(&x, &s)?, v.subst_typevar(&x, &s)?)))?; + + let subst = step_2(eq_constraints)?; + Ok(Some(SubstFn::new(x, s, subst))) + } else if let (Some((s_lhs, s_rhs)), Some((t_lhs, t_rhs))) = + (s.split_arrow(), t.split_arrow()) + { + eq_constraints.insert(s_lhs, t_lhs); + eq_constraints.insert(s_rhs, t_rhs); + step_2(eq_constraints) } else { panic!() } @@ -182,50 +209,89 @@ pub(super) fn step_2( } } -impl TypeVarAst { - pub fn subst(self, var: &str, subst: TaggedType) -> Result { +pub trait Subst { + fn subst(self, var: &str, subst: &Type) -> Result + where + Self: Sized; +} + +impl Subst for TypeVarAst { + fn subst(self, var: &str, subst: &Type) -> Result { match self { - TypeVarAst::Abstraction(tagged_type1, ident, tagged_type2, ast) => { - Ok(TypeVarAst::Abstraction( - tagged_type1.subst_typevar(var, 0, subst.clone())?, - ident, - tagged_type2.subst_typevar(var, 0, subst.clone())?, - Box::new(ast.subst(var, subst)?), - )) - } - TypeVarAst::Application(tagged_type, lhs, rhs) => Ok(TypeVarAst::Application( - tagged_type.subst_typevar(var, 0, subst.clone())?, - Box::new(lhs.subst(var, subst.clone())?), - Box::new(rhs.subst(var, subst)?), + TypeVarAst::Abstraction(type1, ident, type2, ast) => Ok(TypeVarAst::Abstraction( + type1.subst_typevar(var, subst)?, + ident, + type2.subst_typevar(var, subst)?, + Box::new(ast.subst(var, subst)?), )), - TypeVarAst::FreeVariable(tagged_type, x) => Ok(TypeVarAst::FreeVariable( - tagged_type.subst_typevar(var, 0, subst)?, - x, + TypeVarAst::Application(typ, lhs, rhs) => Ok(TypeVarAst::Application( + typ.subst_typevar(var, subst)?, + Box::new(lhs.subst(var, subst)?), + Box::new(rhs.subst(var, subst)?), )), - TypeVarAst::BoundVariable(tagged_type, i) => Ok(TypeVarAst::BoundVariable( - tagged_type.subst_typevar(var, 0, subst)?, + TypeVarAst::FreeVariable(typ, x) => { + Ok(TypeVarAst::FreeVariable(typ.subst_typevar(var, &subst)?, x)) + } + TypeVarAst::BoundVariable(typ, i) => Ok(TypeVarAst::BoundVariable( + typ.subst_typevar(var, &subst)?, i, )), - TypeVarAst::Constant(tagged_type, constant) => Ok(TypeVarAst::Constant( - tagged_type.subst_typevar(var, 0, subst)?, + TypeVarAst::Constant(typ, constant) => Ok(TypeVarAst::Constant( + typ.subst_typevar(var, &subst)?, constant, )), } } } -pub fn subst_comp<'a, F>( +impl Subst for Constraints { + fn subst(mut self, var: &str, subst: &Type) -> Result, InferError> { + self.into_iter() + .map::, _>(|(k, v)| { + Ok((k.subst_typevar(var, &subst)?, v)) + }) + .collect() + } +} + +pub struct SubstFn { var: String, - subst: TaggedType, - then: Option, -) -> Box Result + 'a> -where - F: FnOnce(TypeVarAst) -> Result + 'a, -{ - Box::new(move |ast: TypeVarAst| -> Result { - let ast = ast.subst(&var, subst)?; - if let Some(f) = then { f(ast) } else { Ok(ast) } - }) + subst: Type, + then: Option>, +} + +impl Debug for SubstFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut f = f.debug_map(); + self.debug_map(&mut f); + f.finish() + } +} + +impl SubstFn { + pub fn new(var: String, subst: Type, then: Option) -> Self { + Self { + var, + subst, + then: then.map(Box::new), + } + } + + pub fn apply(&self, target: T) -> Result { + let target = target.subst(&self.var, &self.subst); + if let Some(s) = &self.then { + s.apply(target?) + } else { + target + } + } + + fn debug_map(&self, f: &mut DebugMap) { + f.entry(&self.var, &self.subst); + if let Some(s) = &self.then { + s.debug_map(f); + } + } } pub fn infer_type( @@ -233,37 +299,91 @@ pub fn infer_type( ast: DeBrujinAst, ) -> Result { let gamma_bound = Rc::new(VecMap::new()); - let constraints = Rc::new(RefCell::new(MultiMap::new())); + let eq_constraints = Rc::new(RefCell::new(MultiMap::new())); + let st_constraints = Rc::new(RefCell::new(MultiMap::new())); let ctx = TypeVarCtx::new(); - let (ast, _) = step_1(ast, gamma, gamma_bound, constraints.clone(), &ctx)?; - constraints.clone(); - let res = step_2(constraints.take())?.unwrap(); + let (ast, _) = step_1( + ast, + gamma, + gamma_bound, + eq_constraints.clone(), + st_constraints.clone(), + &ctx, + )?; + let res = step_2(eq_constraints.take())?.unwrap(); - let ast = res(ast)?; + let ast = res.apply(ast)?; - fn get_type(ast: TypeVarAst) -> TaggedType { + fn get_type(ast: TypeVarAst) -> Type { match ast { - TypeVarAst::Abstraction(_, _, typ, ast) => typ.make_arrow(get_type(*ast)), - TypeVarAst::Application(tagged_type, _, _) => tagged_type, - TypeVarAst::FreeVariable(tagged_type, _) => tagged_type, - TypeVarAst::BoundVariable(tagged_type, _) => tagged_type, - TypeVarAst::Constant(tagged_type, constant) => tagged_type, + TypeVarAst::Abstraction(_, _, typ, ast) => Type::arrow(typ, get_type(*ast)), + TypeVarAst::Application(typ, _, _) => typ, + TypeVarAst::FreeVariable(typ, _) => typ, + TypeVarAst::BoundVariable(typ, _) => typ, + TypeVarAst::Constant(typ, constant) => typ, } } let typ = get_type(ast); - let mut typ = typ; + let st_constraints = res.apply(st_constraints.take())?; + let st_constraints = st_constraints + .into_iter() + .map(|(typ, tag)| { + if typ.has_tag(&tag) { + Ok((typ, tag)) + } else { + Err(InferError::DoesNotFitTag(tag, typ.into())) + } + }) + .collect::, _>>()?; + + let mut typ: TaggedType = typ.into(); for free_var in typ.free_vars() { - typ = TaggedType::Tagged(TypeTag::Any, free_var, Box::new(typ)); + let tags = st_constraints.get(&Type::TypeVariable(free_var.clone())); + if tags.is_empty() { + typ = TaggedType::Tagged(TypeTag::Any, free_var, Box::new(typ)); + } else { + let tag = tags + .into_iter() + .fold(Some(TypeTag::Any), |acc, t| { + acc.map(|acc| acc.tightens(t.clone())).flatten() + }) + .ok_or(InferError::ConfilictingTags)?; + typ = TaggedType::Tagged(tag, free_var, Box::new(typ)); + } } - Ok(typ.normalise()) + Ok(typ) } impl TaggedType { + pub fn make_constraints( + self, + st_constraints: Rc>>, + ctx: &TypeVarCtx, + ) -> Type { + match self { + TaggedType::Tagged(type_tag, ident, mut tagged_type) => { + let var = ctx.get_var(); + let name = var.type_var().unwrap(); + + tagged_type.map_name(|n| { + if *n == ident { + *n = name.clone() + } + }); + + st_constraints.borrow_mut().insert(var, type_tag); + + tagged_type.make_constraints(st_constraints, ctx) + } + TaggedType::Concrete(c) => c, + } + } + pub fn subst_typevar( self, var: &str, diff --git a/src/multi_map.rs b/src/multi_map.rs index b0a095a..ba69769 100644 --- a/src/multi_map.rs +++ b/src/multi_map.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, mem}; +use std::{fmt::Debug, mem, vec}; pub struct MultiMap { map: Vec<(K, V)>, @@ -61,6 +61,24 @@ impl MultiMap { } } +impl IntoIterator for MultiMap { + type Item = (K, V); + + type IntoIter = vec::IntoIter<(K, V)>; + + fn into_iter(self) -> Self::IntoIter { + self.map.into_iter() + } +} + +impl FromIterator<(K, V)> for MultiMap { + fn from_iter>(iter: T) -> Self { + Self { + map: iter.into_iter().collect(), + } + } +} + impl Default for MultiMap { fn default() -> Self { Self { diff --git a/src/parse/test.rs b/src/parse/test.rs index cf131b8..c01ed64 100644 --- a/src/parse/test.rs +++ b/src/parse/test.rs @@ -251,4 +251,4 @@ fn parse_application() { ); } -// (\x:Any a => a -> a.\y:Num a => a. + 1 (x y)) (\x:Any a => a.x) 2 +// (\x:Any a => a -> a.\y:Num a => a. add 1 (x y)) (\x:Any a => a.x) 2 diff --git a/src/types/mod.rs b/src/types/mod.rs index 3d2379f..505ebab 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -42,7 +42,7 @@ impl From for TaggedType { impl TypeTag { // If one tag tightens the other return that tag, otherwise none, - fn tightens(self, other: Self) -> Option { + pub fn tightens(self, other: Self) -> Option { match (self, other) { (TypeTag::Num, TypeTag::Num) => Some(Self::Num), (TypeTag::Num, TypeTag::Any) => Some(Self::Num), @@ -87,10 +87,10 @@ impl TaggedType { } } - pub fn arrow(self) -> Option<(TaggedType, TaggedType)> { + pub fn split_arrow(self) -> Option<(TaggedType, TaggedType)> { match self { TaggedType::Tagged(type_tag, ident, tagged_type) => { - let (lhs, rhs) = tagged_type.arrow()?; + let (lhs, rhs) = tagged_type.split_arrow()?; Some(( TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs)) .clear_unused_names(), @@ -237,6 +237,13 @@ impl Type { Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) } + pub fn split_arrow(self) -> Option<(Type, Type)> { + match self { + Type::Arrow(lhs, rhs) => Some((*lhs, *rhs)), + _ => None, + } + } + pub fn is_concrete(&self) -> bool { match self { Type::TypeVariable(_) => false, @@ -247,7 +254,7 @@ impl Type { pub fn has_tag(&self, tag: &TypeTag) -> bool { match self { - Type::TypeVariable(_) => false, + Type::TypeVariable(_) => true, Type::Primitive(primitive_type) => match (primitive_type, tag) { (_, TypeTag::Any) => true, (PrimitiveType::Nat, TypeTag::Num) => true,