parent
							
								
									134f803b88
								
							
						
					
					
						commit
						6a46c6ca52
					
				| @ -1,37 +1,75 @@ | ||||
| use std::{collections::HashMap, error::Error, rc::Rc}; | ||||
| use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc}; | ||||
| 
 | ||||
| use crate::{Ast, Constant, Ident, PrimitiveType, Type}; | ||||
| use crate::{ | ||||
|     DeBrujinAst, Ident, PrimitiveType, Type, | ||||
|     parse::{Ast, Constant}, | ||||
|     types::{TaggedType, TypeTag}, | ||||
|     vec_map::VecMap, | ||||
| }; | ||||
| 
 | ||||
| #[derive(Debug)] | ||||
| pub enum InferError { | ||||
|     NotAFunction, | ||||
|     MismatchedType, | ||||
|     NotInContext, | ||||
|     ExpectedConreteType, | ||||
|     ExpectedTypeWithTag, | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod test; | ||||
| 
 | ||||
| pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> { | ||||
| pub fn infer_type( | ||||
|     gamma: &HashMap<Ident, TaggedType>, | ||||
|     ast: DeBrujinAst, | ||||
| ) -> Result<TaggedType, InferError> { | ||||
|     infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast) | ||||
| } | ||||
| 
 | ||||
| fn infer_type_debrujin_int( | ||||
|     gamma_free: &HashMap<Ident, TaggedType>, | ||||
|     mut gamma_bound: Rc<VecMap<usize, TaggedType>>, | ||||
|     ast: DeBrujinAst, | ||||
| ) -> Result<TaggedType, InferError> { | ||||
|     match ast { | ||||
|         Ast::Abstraction(arg, arg_type, ast) => { | ||||
|             Rc::make_mut(&mut gamma).insert(arg, arg_type.clone()); | ||||
|             let out_type = infer_type(gamma, *ast)?; | ||||
|             Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type))) | ||||
|         DeBrujinAst::Abstraction(_, arg_type, ast) => { | ||||
|             let gamma_ref = Rc::make_mut(&mut gamma_bound); | ||||
|             gamma_ref.map_keys(|i| *i += 1); | ||||
|             gamma_ref.insert(1, arg_type.clone().into()); | ||||
| 
 | ||||
|             let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?; | ||||
|             let typ = Type::Arrow( | ||||
|                 Box::new(arg_type), | ||||
|                 Box::new( | ||||
|                     out_type | ||||
|                         .to_concrete() | ||||
|                         .ok_or(InferError::ExpectedConreteType)?, | ||||
|                 ), | ||||
|             ); | ||||
|             Ok(typ.into()) | ||||
|         } | ||||
|         Ast::Application(left, right) => { | ||||
|             let left_type = infer_type(gamma.clone(), *left)?; | ||||
|             let Type::Arrow(in_type, out_type) = left_type else { | ||||
|         DeBrujinAst::Application(lhs, rhs) => { | ||||
|             let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?; | ||||
|             let Some((in_type, out_type)) = left_type.arrow() else { | ||||
|                 return Err(InferError::NotAFunction); | ||||
|             }; | ||||
|             let right_type = infer_type(gamma, *right)?; | ||||
|             if *in_type != right_type { | ||||
|                 return Err(InferError::MismatchedType); | ||||
|             } | ||||
|             Ok(*out_type) | ||||
|             let Some(right_type) = | ||||
|                 infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete() | ||||
|             else { | ||||
|                 return Err(InferError::ExpectedConreteType); | ||||
|             }; | ||||
|             let out_type = match in_type.matches(&right_type) { | ||||
|                 (true, None) => out_type, | ||||
|                 (true, Some(a)) => out_type.specialize(&a, &right_type), | ||||
|                 (false, _) => return Err(InferError::MismatchedType), | ||||
|             }; | ||||
|             Ok(out_type) | ||||
|         } | ||||
|         Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext), | ||||
|         Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)), | ||||
|         Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)), | ||||
|         DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext), | ||||
|         // compiler bug if not present
 | ||||
|         DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()), | ||||
|         DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()), | ||||
|         DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()), | ||||
|         DeBrujinAst::Constant(_) => unreachable!(), | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -1,49 +1,73 @@ | ||||
| use std::{collections::HashMap, rc::Rc}; | ||||
| 
 | ||||
| use crate::{Ast, Constant, PrimitiveType, Type}; | ||||
| 
 | ||||
| use super::infer_type; | ||||
| use crate::{ | ||||
|     DeBrujinAst, PrimitiveType, Type, | ||||
|     inference::infer_type, | ||||
|     parse::{Ast, Constant}, | ||||
|     types::{TaggedType, TypeTag}, | ||||
| }; | ||||
| 
 | ||||
| #[test] | ||||
| fn infer_id_type() { | ||||
|     let ast = Ast::Abstraction( | ||||
|         "x".to_string(), | ||||
|         Type::Primitive(PrimitiveType::Nat), | ||||
|         Box::new(Ast::Variable("x".to_string())), | ||||
| fn infer_add_nat() { | ||||
|     let ast: DeBrujinAst = Ast::Application( | ||||
|         Box::new(Ast::Application( | ||||
|             Box::new(Ast::Variable("add".to_string())), | ||||
|             Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|         )), | ||||
|         Box::new(Ast::Constant(Constant::Nat(7))), | ||||
|     ) | ||||
|     .into(); | ||||
| 
 | ||||
|     let mut gamma = HashMap::new(); | ||||
|     gamma.insert( | ||||
|         "add".to_string(), | ||||
|         crate::types::TaggedType::Tagged( | ||||
|             TypeTag::Num, | ||||
|             "a".to_string(), | ||||
|             Box::new(TaggedType::Concrete(Type::Arrow( | ||||
|                 Box::new(Type::Generic("a".to_string())), | ||||
|                 Box::new(Type::Arrow( | ||||
|                     Box::new(Type::Generic("a".to_string())), | ||||
|                     Box::new(Type::Generic("a".to_string())), | ||||
|                 )), | ||||
|             ))), | ||||
|         ), | ||||
|     ); | ||||
| 
 | ||||
|     let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap(); | ||||
|     let infered = infer_type(&gamma, ast).unwrap(); | ||||
|     assert_eq!( | ||||
|         infered, | ||||
|         Type::Arrow( | ||||
|             Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|             Box::new(Type::Primitive(PrimitiveType::Nat)) | ||||
|         ) | ||||
|     ) | ||||
|         TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat)) | ||||
|     ); | ||||
| } | ||||
| 
 | ||||
| #[test] | ||||
| fn infer_addition_result_type() { | ||||
|     let ast = Ast::Application( | ||||
|         Box::new(Ast::Application( | ||||
|             Box::new(Ast::Variable("add".to_string())), | ||||
|             Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|         )), | ||||
|         Box::new(Ast::Constant(Constant::Nat(7))), | ||||
|     ); | ||||
| fn infer_add_nat_partial() { | ||||
|     let ast: DeBrujinAst = Ast::Application( | ||||
|         Box::new(Ast::Variable("add".to_string())), | ||||
|         Box::new(Ast::Constant(Constant::Nat(5))), | ||||
|     ) | ||||
|     .into(); | ||||
| 
 | ||||
|     let mut gamma = HashMap::new(); | ||||
|     gamma.insert( | ||||
|         "add".to_string(), | ||||
|         Type::Arrow( | ||||
|             Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|             Box::new(Type::Arrow( | ||||
|                 Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|                 Box::new(Type::Primitive(PrimitiveType::Nat)), | ||||
|             )), | ||||
|         crate::types::TaggedType::Tagged( | ||||
|             TypeTag::Num, | ||||
|             "a".to_string(), | ||||
|             Box::new(TaggedType::Concrete(Type::Arrow( | ||||
|                 Box::new(Type::Generic("a".to_string())), | ||||
|                 Box::new(Type::Arrow( | ||||
|                     Box::new(Type::Generic("a".to_string())), | ||||
|                     Box::new(Type::Generic("a".to_string())), | ||||
|                 )), | ||||
|             ))), | ||||
|         ), | ||||
|     ); | ||||
| 
 | ||||
|     let infered = infer_type(Rc::new(gamma), ast).unwrap(); | ||||
|     assert_eq!(infered, Type::Primitive(PrimitiveType::Nat)); | ||||
|     let infered = infer_type(&gamma, ast).unwrap(); | ||||
|     assert_eq!( | ||||
|         infered, | ||||
|         TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)) | ||||
|     ); | ||||
| } | ||||
|  | ||||
| @ -0,0 +1,211 @@ | ||||
| use std::{fmt::Display, str}; | ||||
| 
 | ||||
| pub type Ident = String; | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| pub enum PrimitiveType { | ||||
|     Nat, | ||||
|     Bool, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| pub enum TypeTag { | ||||
|     Num, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| pub enum TaggedType { | ||||
|     Tagged(TypeTag, Ident, Box<TaggedType>), | ||||
|     Concrete(Type), | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||||
| pub enum Type { | ||||
|     Generic(Ident),              // a
 | ||||
|     Primitive(PrimitiveType),    // Bool | Nat
 | ||||
|     Arrow(Box<Type>, Box<Type>), // 0 -> 1
 | ||||
| } | ||||
| 
 | ||||
| impl From<Type> for TaggedType { | ||||
|     fn from(value: Type) -> Self { | ||||
|         Self::Concrete(value) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl TaggedType { | ||||
|     pub fn r#type(&self) -> &Type { | ||||
|         match self { | ||||
|             TaggedType::Tagged(_, _, tagged_type) => tagged_type.r#type(), | ||||
|             TaggedType::Concrete(t) => t, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn type_mut(&mut self) -> &mut Type { | ||||
|         match self { | ||||
|             TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_mut(), | ||||
|             TaggedType::Concrete(t) => t, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn to_concrete(self) -> Option<Type> { | ||||
|         match self { | ||||
|             TaggedType::Tagged(type_tag, _, tagged_type) => None, | ||||
|             TaggedType::Concrete(t) => { | ||||
|                 if t.is_concrete() { | ||||
|                     Some(t) | ||||
|                 } else { | ||||
|                     None | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn matches(&self, typ: &Type) -> (bool, Option<Ident>) { | ||||
|         match self { | ||||
|             TaggedType::Tagged(type_tag, ident, tagged_type) => { | ||||
|                 if typ.has_tag(type_tag) { | ||||
|                     (true, Some(ident.clone())) | ||||
|                 } else { | ||||
|                     tagged_type.matches(typ) | ||||
|                 } | ||||
|             } | ||||
|             TaggedType::Concrete(c) => (c == typ, None), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn arrow(self) -> Option<(TaggedType, TaggedType)> { | ||||
|         match self { | ||||
|             TaggedType::Tagged(type_tag, ident, tagged_type) => { | ||||
|                 let (lhs, rhs) = tagged_type.arrow()?; | ||||
|                 Some(( | ||||
|                     TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs)) | ||||
|                         .clear_unused_names(), | ||||
|                     TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(), | ||||
|                 )) | ||||
|             } | ||||
|             TaggedType::Concrete(Type::Arrow(lhs, rhs)) => { | ||||
|                 Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs))) | ||||
|             } | ||||
|             TaggedType::Concrete(_) => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType { | ||||
|         match self { | ||||
|             TaggedType::Tagged(_, i, tagged_type) if i == ident => { | ||||
|                 tagged_type.specialize(ident, typ) | ||||
|             } | ||||
|             TaggedType::Tagged(t, i, tagged_type) => { | ||||
|                 TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ))) | ||||
|             } | ||||
|             TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn name_used(&self, ident: &str) -> bool { | ||||
|         match self { | ||||
|             TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident), | ||||
|             TaggedType::Concrete(c) => c.name_used(ident), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn clear_unused_names(self) -> TaggedType { | ||||
|         match self { | ||||
|             TaggedType::Tagged(type_tag, ident, tagged_type) => { | ||||
|                 if tagged_type.name_used(&ident) { | ||||
|                     TaggedType::Tagged(type_tag, ident, tagged_type) | ||||
|                 } else { | ||||
|                     *tagged_type | ||||
|                 } | ||||
|             } | ||||
|             t => t, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<PrimitiveType> for Type { | ||||
|     fn from(value: PrimitiveType) -> Self { | ||||
|         Self::Primitive(value) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Type { | ||||
|     pub fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self { | ||||
|         Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) | ||||
|     } | ||||
| 
 | ||||
|     pub fn is_concrete(&self) -> bool { | ||||
|         match self { | ||||
|             Type::Generic(_) => false, | ||||
|             Type::Primitive(primitive_type) => true, | ||||
|             Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn has_tag(&self, tag: &TypeTag) -> bool { | ||||
|         match self { | ||||
|             Type::Generic(_) => false, | ||||
|             Type::Primitive(primitive_type) => match (primitive_type, tag) { | ||||
|                 (PrimitiveType::Nat, TypeTag::Num) => true, | ||||
|                 _ => false, | ||||
|             }, | ||||
|             Type::Arrow(_, _) => false, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn name_used(&self, ident: &str) -> bool { | ||||
|         match self { | ||||
|             Type::Generic(i) if *i == ident => true, | ||||
|             _ => false, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn specialize(self, ident: &str, typ: &Type) -> Type { | ||||
|         match self { | ||||
|             Type::Generic(i) if i == ident => typ.clone(), | ||||
|             Type::Arrow(lhs, rhs) => Type::Arrow( | ||||
|                 Box::new(lhs.specialize(ident, typ)), | ||||
|                 Box::new(rhs.specialize(ident, typ)), | ||||
|             ), | ||||
|             t => t, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for TypeTag { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             TypeTag::Num => write!(f, "Num"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for TaggedType { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             TaggedType::Tagged(type_tag, ident, tagged_type) => { | ||||
|                 write!(f, "{type_tag} {ident} => {tagged_type}") | ||||
|             } | ||||
|             TaggedType::Concrete(t) => write!(f, "{t}"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for Type { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             Type::Generic(i) => write!(f, "{i}"), | ||||
|             Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), | ||||
|             Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Display for PrimitiveType { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         match self { | ||||
|             PrimitiveType::Nat => write!(f, "Nat"), | ||||
|             PrimitiveType::Bool => write!(f, "Bool"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -0,0 +1,38 @@ | ||||
| use std::mem; | ||||
| 
 | ||||
| pub struct VecMap<K: Eq, V> { | ||||
|     map: Vec<(K, V)>, | ||||
| } | ||||
| 
 | ||||
| impl<K: Eq, V> VecMap<K, V> { | ||||
|     pub fn new() -> Self { | ||||
|         Self { map: Vec::new() } | ||||
|     } | ||||
| 
 | ||||
|     pub fn insert(&mut self, key: K, val: V) -> Option<(K, V)> { | ||||
|         if let Some(entry) = self.map.iter_mut().find(|(k, _)| *k == key) { | ||||
|             Some(mem::replace(entry, (key, val))) | ||||
|         } else { | ||||
|             self.map.push((key, val)); | ||||
|             None | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     pub fn get(&self, key: &K) -> Option<&V> { | ||||
|         self.map | ||||
|             .iter() | ||||
|             .find_map(|(k, v)| if k == key { Some(v) } else { None }) | ||||
|     } | ||||
| 
 | ||||
|     pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) { | ||||
|         self.map.iter_mut().for_each(|(k, _)| f(k)) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<K: Eq + Clone, V: Clone> Clone for VecMap<K, V> { | ||||
|     fn clone(&self) -> Self { | ||||
|         Self { | ||||
|             map: self.map.clone(), | ||||
|         } | ||||
|     } | ||||
| } | ||||
					Loading…
					
					
				
		Reference in new issue