parent
							
								
									305e95846d
								
							
						
					
					
						commit
						766da2593e
					
				| @ -0,0 +1,59 @@ | ||||
| use std::{collections::HashMap, rc::Rc}; | ||||
| 
 | ||||
| use crate::{ | ||||
|     DeBrujinAst, | ||||
|     parse::Constant, | ||||
|     types::{Ident, PrimitiveType, TaggedType, Type}, | ||||
|     vec_map::VecMap, | ||||
| }; | ||||
| 
 | ||||
| use super::InferError; | ||||
| 
 | ||||
| pub fn infer_type( | ||||
|     gamma: &HashMap<Ident, TaggedType>, | ||||
|     ast: DeBrujinAst, | ||||
| ) -> Result<TaggedType, InferError> { | ||||
|     infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast) | ||||
| } | ||||
| 
 | ||||
| fn infer_type_debrujin_int( | ||||
|     gamma_free: &HashMap<Ident, TaggedType>, | ||||
|     mut gamma_bound: Rc<VecMap<usize, TaggedType>>, | ||||
|     ast: DeBrujinAst, | ||||
| ) -> Result<TaggedType, InferError> { | ||||
|     match ast { | ||||
|         DeBrujinAst::Abstraction(_, arg_type, ast) => { | ||||
|             let gamma_ref = Rc::make_mut(&mut gamma_bound); | ||||
|             gamma_ref.map_keys(|i| *i += 1); | ||||
|             gamma_ref.insert(1, arg_type.clone().unwrap().into()); | ||||
| 
 | ||||
|             let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?; | ||||
|             // TODO: Fix this hack
 | ||||
|             let typ = arg_type.unwrap().make_arrow(out_type); | ||||
|             Ok(typ) | ||||
|         } | ||||
|         DeBrujinAst::Application(lhs, rhs) => { | ||||
|             let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?; | ||||
|             let Some((in_type, out_type)) = left_type.arrow() else { | ||||
|                 return Err(InferError::NotAFunction); | ||||
|             }; | ||||
|             let Some(right_type) = | ||||
|                 infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete() | ||||
|             else { | ||||
|                 return Err(InferError::ExpectedConreteType); | ||||
|             }; | ||||
|             let out_type = match in_type.matches(&right_type) { | ||||
|                 (true, None) => out_type, | ||||
|                 (true, Some(a)) => out_type.specialize(&a, &right_type), | ||||
|                 (false, _) => return Err(InferError::MismatchedType), | ||||
|             }; | ||||
|             Ok(out_type) | ||||
|         } | ||||
|         DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext), | ||||
|         // compiler bug if not present
 | ||||
|         DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()), | ||||
|         DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()), | ||||
|         DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()), | ||||
|         DeBrujinAst::Constant(_) => unreachable!(), | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,329 @@ | ||||
| use std::{cell::RefCell, clone, collections::HashMap, rc::Rc}; | ||||
| 
 | ||||
| use crate::{ | ||||
|     Ast, DeBrujinAst, | ||||
|     multi_map::MultiMap, | ||||
|     parse::Constant, | ||||
|     types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, | ||||
|     vec_map::VecMap, | ||||
| }; | ||||
| 
 | ||||
| use super::InferError; | ||||
| 
 | ||||
| type TypeVar = TaggedType; | ||||
| 
 | ||||
| pub struct TypeVarCtx { | ||||
|     counter: RefCell<usize>, | ||||
| } | ||||
| 
 | ||||
| impl TypeVarCtx { | ||||
|     pub fn new() -> Self { | ||||
|         Self { | ||||
|             counter: RefCell::new(0), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn get_var(&self) -> TaggedType { | ||||
|         let mut num = self.counter.borrow_mut(); | ||||
|         let res = format!("?typ_{num}"); | ||||
|         *num += 1; | ||||
|         Type::TypeVariable(res).into() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq)] | ||||
| pub enum TypeVarAst { | ||||
|     Abstraction(TypeVar, Ident, TaggedType, Box<TypeVarAst>), // \0:1.2
 | ||||
|     Application(TypeVar, Box<TypeVarAst>, Box<TypeVarAst>),   // 0 1
 | ||||
|     FreeVariable(TypeVar, Ident),                             // x
 | ||||
|     BoundVariable(TypeVar, usize),                            // 1
 | ||||
|     Constant(TypeVar, Constant),                              // true | false | n
 | ||||
| } | ||||
| 
 | ||||
| pub type Constraints = MultiMap<TypeVar, TaggedType>; | ||||
| 
 | ||||
| pub(super) fn step_1( | ||||
|     ast: DeBrujinAst, | ||||
|     gamma_free: &HashMap<Ident, TaggedType>, | ||||
|     mut gamma_bound: Rc<VecMap<usize, TaggedType>>, | ||||
|     constraints: Rc<RefCell<Constraints>>, | ||||
|     ctx: &TypeVarCtx, | ||||
| ) -> Result<(TypeVarAst, TypeVar), InferError> { | ||||
|     match ast { | ||||
|         DeBrujinAst::Abstraction(i, Some(typ), ast) => { | ||||
|             let var = ctx.get_var(); | ||||
| 
 | ||||
|             let gamma_ref = Rc::make_mut(&mut gamma_bound); | ||||
|             gamma_ref.map_keys(|i| *i += 1); | ||||
|             gamma_ref.insert(1, typ.clone()); | ||||
| 
 | ||||
|             let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?; | ||||
|             // RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var));
 | ||||
|             Ok(( | ||||
|                 TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)), | ||||
|                 var, | ||||
|             )) | ||||
|         } | ||||
|         DeBrujinAst::Abstraction(i, None, ast) => { | ||||
|             let var = ctx.get_var(); | ||||
| 
 | ||||
|             let typ = ctx.get_var(); | ||||
|             let gamma_ref = Rc::make_mut(&mut gamma_bound); | ||||
|             gamma_ref.map_keys(|i| *i += 1); | ||||
|             gamma_ref.insert(1, typ.clone()); | ||||
| 
 | ||||
|             let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?; | ||||
|             // RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var));
 | ||||
|             Ok(( | ||||
|                 TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)), | ||||
|                 var, | ||||
|             )) | ||||
|         } | ||||
|         DeBrujinAst::Application(lhs, rhs) => { | ||||
|             let var = ctx.get_var(); | ||||
|             let (lhs, lhs_var) = step_1( | ||||
|                 *lhs, | ||||
|                 gamma_free, | ||||
|                 gamma_bound.clone(), | ||||
|                 constraints.clone(), | ||||
|                 ctx, | ||||
|             )?; | ||||
|             let (rhs, rhs_var) = step_1( | ||||
|                 *rhs, | ||||
|                 gamma_free, | ||||
|                 gamma_bound.clone(), | ||||
|                 constraints.clone(), | ||||
|                 ctx, | ||||
|             )?; | ||||
|             RefCell::borrow_mut(&constraints).insert(lhs_var, rhs_var.make_arrow(var.clone())); | ||||
|             Ok(( | ||||
|                 TypeVarAst::Application(var.clone(), Box::new(lhs), Box::new(rhs)), | ||||
|                 var, | ||||
|             )) | ||||
|         } | ||||
|         DeBrujinAst::FreeVariable(v) => { | ||||
|             let var = ctx.get_var(); | ||||
|             let typ = gamma_free | ||||
|                 .get(&v) | ||||
|                 .cloned() | ||||
|                 .ok_or(InferError::NotInContext)?; | ||||
| 
 | ||||
|             RefCell::borrow_mut(&constraints).insert(var.clone(), typ); | ||||
| 
 | ||||
|             Ok((TypeVarAst::FreeVariable(var.clone(), v), var)) | ||||
|         } | ||||
|         DeBrujinAst::BoundVariable(i) => { | ||||
|             let var = ctx.get_var(); | ||||
|             let typ = gamma_bound | ||||
|                 .get(&i) | ||||
|                 .cloned() | ||||
|                 .ok_or(InferError::NotInContext)?; | ||||
| 
 | ||||
|             RefCell::borrow_mut(&constraints).insert(var.clone(), typ); | ||||
| 
 | ||||
|             Ok((TypeVarAst::BoundVariable(var.clone(), i), var)) | ||||
|         } | ||||
|         DeBrujinAst::Constant(constant) => { | ||||
|             let var = ctx.get_var(); | ||||
|             let typ = match constant { | ||||
|                 Constant::Nat(_) => Type::Primitive(PrimitiveType::Nat), | ||||
|                 Constant::Float(_) => Type::Primitive(PrimitiveType::Float), | ||||
|                 Constant::Bool(_) => Type::Primitive(PrimitiveType::Bool), | ||||
|             } | ||||
|             .into(); | ||||
| 
 | ||||
|             RefCell::borrow_mut(&constraints).insert(var.clone(), typ); | ||||
| 
 | ||||
|             Ok((TypeVarAst::Constant(var.clone(), constant), var)) | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub(super) fn step_2( | ||||
|     mut constraints: Constraints, | ||||
| ) -> Result<Option<Box<dyn FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError>>>, InferError> { | ||||
|     if let Some((s, t)) = constraints.pop() { | ||||
|         if s == t { | ||||
|             step_2(constraints) | ||||
|         } else if s.type_var().is_some_and(|x| !t.name_used(&x)) { | ||||
|             let Some(x) = s.type_var() else { | ||||
|                 unreachable!() | ||||
|             }; | ||||
|             constraints.try_map(|(k, v)| { | ||||
|                 Ok(( | ||||
|                     k.subst_typevar(&x, 0, t.clone())?, | ||||
|                     v.subst_typevar(&x, 0, t.clone())?, | ||||
|                 )) | ||||
|             })?; | ||||
|             let subst = step_2(constraints)?; | ||||
|             Ok(Some(subst_comp(x, t, subst))) | ||||
|         } else if t.type_var().is_some_and(|x| !s.name_used(&x)) { | ||||
|             let Some(x) = t.type_var() else { | ||||
|                 unreachable!() | ||||
|             }; | ||||
|             constraints.try_map(|(k, v)| { | ||||
|                 Ok(( | ||||
|                     k.subst_typevar(&x, 0, s.clone())?, | ||||
|                     v.subst_typevar(&x, 0, s.clone())?, | ||||
|                 )) | ||||
|             })?; | ||||
| 
 | ||||
|             let subst = step_2(constraints)?; | ||||
|             Ok(Some(subst_comp(x, t, subst))) | ||||
|         } else if let (Some((s_lhs, s_rhs)), Some((t_lhs, t_rhs))) = (s.arrow(), t.arrow()) { | ||||
|             constraints.insert(s_lhs, t_lhs); | ||||
|             constraints.insert(s_rhs, t_rhs); | ||||
|             step_2(constraints) | ||||
|         } else { | ||||
|             panic!() | ||||
|         } | ||||
|     } else { | ||||
|         Ok(None) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl TypeVarAst { | ||||
|     pub fn subst(self, var: &str, subst: TaggedType) -> Result<TypeVarAst, InferError> { | ||||
|         match self { | ||||
|             TypeVarAst::Abstraction(tagged_type1, ident, tagged_type2, ast) => { | ||||
|                 Ok(TypeVarAst::Abstraction( | ||||
|                     tagged_type1.subst_typevar(var, 0, subst.clone())?, | ||||
|                     ident, | ||||
|                     tagged_type2.subst_typevar(var, 0, subst.clone())?, | ||||
|                     Box::new(ast.subst(var, subst)?), | ||||
|                 )) | ||||
|             } | ||||
|             TypeVarAst::Application(tagged_type, lhs, rhs) => Ok(TypeVarAst::Application( | ||||
|                 tagged_type.subst_typevar(var, 0, subst.clone())?, | ||||
|                 Box::new(lhs.subst(var, subst.clone())?), | ||||
|                 Box::new(rhs.subst(var, subst)?), | ||||
|             )), | ||||
|             TypeVarAst::FreeVariable(tagged_type, x) => Ok(TypeVarAst::FreeVariable( | ||||
|                 tagged_type.subst_typevar(var, 0, subst)?, | ||||
|                 x, | ||||
|             )), | ||||
|             TypeVarAst::BoundVariable(tagged_type, i) => Ok(TypeVarAst::BoundVariable( | ||||
|                 tagged_type.subst_typevar(var, 0, subst)?, | ||||
|                 i, | ||||
|             )), | ||||
|             TypeVarAst::Constant(tagged_type, constant) => Ok(TypeVarAst::Constant( | ||||
|                 tagged_type.subst_typevar(var, 0, subst)?, | ||||
|                 constant, | ||||
|             )), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| pub fn subst_comp<'a, F>( | ||||
|     var: String, | ||||
|     subst: TaggedType, | ||||
|     then: Option<F>, | ||||
| ) -> Box<dyn FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError> + 'a> | ||||
| where | ||||
|     F: FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError> + 'a, | ||||
| { | ||||
|     Box::new(move |ast: TypeVarAst| -> Result<TypeVarAst, InferError> { | ||||
|         let ast = ast.subst(&var, subst)?; | ||||
|         if let Some(f) = then { f(ast) } else { Ok(ast) } | ||||
|     }) | ||||
| } | ||||
| 
 | ||||
| pub fn infer_type( | ||||
|     gamma: &HashMap<Ident, TaggedType>, | ||||
|     ast: DeBrujinAst, | ||||
| ) -> Result<TaggedType, InferError> { | ||||
|     let gamma_bound = Rc::new(VecMap::new()); | ||||
|     let constraints = Rc::new(RefCell::new(MultiMap::new())); | ||||
|     let ctx = TypeVarCtx::new(); | ||||
| 
 | ||||
|     let (ast, _) = step_1(ast, gamma, gamma_bound, constraints.clone(), &ctx)?; | ||||
|     constraints.clone(); | ||||
|     let res = step_2(constraints.take())?.unwrap(); | ||||
| 
 | ||||
|     let ast = res(ast)?; | ||||
| 
 | ||||
|     fn get_type(ast: TypeVarAst) -> TaggedType { | ||||
|         match ast { | ||||
|             TypeVarAst::Abstraction(_, _, typ, ast) => typ.make_arrow(get_type(*ast)), | ||||
|             TypeVarAst::Application(tagged_type, _, _) => tagged_type, | ||||
|             TypeVarAst::FreeVariable(tagged_type, _) => tagged_type, | ||||
|             TypeVarAst::BoundVariable(tagged_type, _) => tagged_type, | ||||
|             TypeVarAst::Constant(tagged_type, constant) => tagged_type, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     let typ = get_type(ast); | ||||
| 
 | ||||
|     let mut typ = typ; | ||||
| 
 | ||||
|     for free_var in typ.free_vars() { | ||||
|         typ = TaggedType::Tagged(TypeTag::Any, free_var, Box::new(typ)); | ||||
|     } | ||||
| 
 | ||||
|     Ok(typ.normalise()) | ||||
| } | ||||
| 
 | ||||
| impl TaggedType { | ||||
|     pub fn subst_typevar( | ||||
|         self, | ||||
|         var: &str, | ||||
|         diversifier: usize, | ||||
|         mut subst: TaggedType, | ||||
|     ) -> Result<TaggedType, InferError> { | ||||
|         match self { | ||||
|             TaggedType::Tagged(t, i, tagged_type) => { | ||||
|                 if subst.name_used(&i) { | ||||
|                     subst.map_name(|f| { | ||||
|                         if *f == i { | ||||
|                             *f = format!("{i}+{diversifier}") | ||||
|                         } | ||||
|                     }); | ||||
|                 } | ||||
|                 Ok(TaggedType::Tagged( | ||||
|                     t, | ||||
|                     i, | ||||
|                     Box::new(tagged_type.subst_typevar(var, diversifier + 1, subst)?), | ||||
|                 )) | ||||
|             } | ||||
|             TaggedType::Concrete(t) => subst.subst_into(var, diversifier, t), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn subst_into( | ||||
|         self, | ||||
|         var: &str, | ||||
|         diversifier: usize, | ||||
|         mut target: Type, | ||||
|     ) -> Result<TaggedType, InferError> { | ||||
|         match self { | ||||
|             TaggedType::Tagged(t, i, tagged_type) => { | ||||
|                 if target.name_used(&i) { | ||||
|                     target.map_name(&|f| { | ||||
|                         if *f == i { | ||||
|                             *f = format!("{i}+{diversifier}") | ||||
|                         } | ||||
|                     }); | ||||
|                 } | ||||
|                 Ok(TaggedType::Tagged( | ||||
|                     t, | ||||
|                     i, | ||||
|                     Box::new(tagged_type.subst_into(var, diversifier + 1, target)?), | ||||
|                 )) | ||||
|             } | ||||
|             TaggedType::Concrete(c) => Ok(TaggedType::Concrete(target.subst_typevar(var, &c)?)), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Type { | ||||
|     pub fn subst_typevar(self, var: &str, subst: &Type) -> Result<Type, InferError> { | ||||
|         match self { | ||||
|             Type::TypeVariable(v) if v == var => Ok(subst.clone()), | ||||
|             Type::Arrow(lhs, rhs) => Ok(Type::Arrow( | ||||
|                 Box::new(lhs.subst_typevar(var, subst)?), | ||||
|                 Box::new(rhs.subst_typevar(var, subst)?), | ||||
|             )), | ||||
|             t => Ok(t), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,86 @@ | ||||
| use std::{fmt::Debug, mem}; | ||||
| 
 | ||||
| pub struct MultiMap<K: Eq, V> { | ||||
|     map: Vec<(K, V)>, | ||||
| } | ||||
| 
 | ||||
| impl<K: Eq, V> MultiMap<K, V> { | ||||
|     pub fn new() -> Self { | ||||
|         Self { map: Vec::new() } | ||||
|     } | ||||
| 
 | ||||
|     pub fn insert(&mut self, key: K, val: V) { | ||||
|         self.map.push((key, val)); | ||||
|     } | ||||
| 
 | ||||
|     pub fn get(&self, key: &K) -> Vec<&V> { | ||||
|         self.map | ||||
|             .iter() | ||||
|             .filter_map(|(k, v)| if k == key { Some(v) } else { None }) | ||||
|             .collect() | ||||
|     } | ||||
| 
 | ||||
|     pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) { | ||||
|         self.map.iter_mut().for_each(|(k, _)| f(k)) | ||||
|     } | ||||
| 
 | ||||
|     pub fn map<F: Fn(&mut (K, V))>(&mut self, f: F) { | ||||
|         self.map.iter_mut().for_each(f); | ||||
|     } | ||||
| 
 | ||||
|     pub fn try_map<E, F: Fn((K, V)) -> Result<(K, V), E>>(&mut self, f: F) -> Result<(), E> { | ||||
|         let vec = mem::take(&mut self.map); | ||||
|         let vec = vec.into_iter().map(f).collect::<Result<Vec<(K, V)>, _>>()?; | ||||
|         self.map = vec; | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     pub fn find<F: FnMut(&&(K, V)) -> bool>(&self, f: F) -> Option<&(K, V)> { | ||||
|         self.map.iter().find(f) | ||||
|     } | ||||
| 
 | ||||
|     pub fn find_remove<F: FnMut(&&(K, V)) -> bool>(&mut self, mut f: F) -> Option<(K, V)> { | ||||
|         let idx = self | ||||
|             .map | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .find_map(|(n, ref kv)| if f(kv) { Some(n) } else { None })?; | ||||
|         Some(self.map.swap_remove(idx)) | ||||
|     } | ||||
| 
 | ||||
|     pub fn len(&self) -> usize { | ||||
|         self.map.len() | ||||
|     } | ||||
| 
 | ||||
|     pub fn is_empty(&self) -> bool { | ||||
|         self.map.is_empty() | ||||
|     } | ||||
| 
 | ||||
|     pub fn pop(&mut self) -> Option<(K, V)> { | ||||
|         self.map.pop() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<K: Eq, V> Default for MultiMap<K, V> { | ||||
|     fn default() -> Self { | ||||
|         Self { | ||||
|             map: Vec::default(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<K: Eq + Clone, V: Clone> Clone for MultiMap<K, V> { | ||||
|     fn clone(&self) -> Self { | ||||
|         Self { | ||||
|             map: self.map.clone(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<K: Eq + Debug, V: Debug> Debug for MultiMap<K, V> { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         f.debug_map() | ||||
|             .entries(self.map.iter().map(|&(ref k, ref v)| (k, v))) | ||||
|             .finish() | ||||
|     } | ||||
| } | ||||
					Loading…
					
					
				
		Reference in new issue