parent
305e95846d
commit
766da2593e
@ -0,0 +1,59 @@
|
||||
use std::{collections::HashMap, rc::Rc};
|
||||
|
||||
use crate::{
|
||||
DeBrujinAst,
|
||||
parse::Constant,
|
||||
types::{Ident, PrimitiveType, TaggedType, Type},
|
||||
vec_map::VecMap,
|
||||
};
|
||||
|
||||
use super::InferError;
|
||||
|
||||
pub fn infer_type(
|
||||
gamma: &HashMap<Ident, TaggedType>,
|
||||
ast: DeBrujinAst,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast)
|
||||
}
|
||||
|
||||
fn infer_type_debrujin_int(
|
||||
gamma_free: &HashMap<Ident, TaggedType>,
|
||||
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
|
||||
ast: DeBrujinAst,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
match ast {
|
||||
DeBrujinAst::Abstraction(_, arg_type, ast) => {
|
||||
let gamma_ref = Rc::make_mut(&mut gamma_bound);
|
||||
gamma_ref.map_keys(|i| *i += 1);
|
||||
gamma_ref.insert(1, arg_type.clone().unwrap().into());
|
||||
|
||||
let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?;
|
||||
// TODO: Fix this hack
|
||||
let typ = arg_type.unwrap().make_arrow(out_type);
|
||||
Ok(typ)
|
||||
}
|
||||
DeBrujinAst::Application(lhs, rhs) => {
|
||||
let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?;
|
||||
let Some((in_type, out_type)) = left_type.arrow() else {
|
||||
return Err(InferError::NotAFunction);
|
||||
};
|
||||
let Some(right_type) =
|
||||
infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete()
|
||||
else {
|
||||
return Err(InferError::ExpectedConreteType);
|
||||
};
|
||||
let out_type = match in_type.matches(&right_type) {
|
||||
(true, None) => out_type,
|
||||
(true, Some(a)) => out_type.specialize(&a, &right_type),
|
||||
(false, _) => return Err(InferError::MismatchedType),
|
||||
};
|
||||
Ok(out_type)
|
||||
}
|
||||
DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext),
|
||||
// compiler bug if not present
|
||||
DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()),
|
||||
DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()),
|
||||
DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()),
|
||||
DeBrujinAst::Constant(_) => unreachable!(),
|
||||
}
|
||||
}
|
@ -0,0 +1,329 @@
|
||||
use std::{cell::RefCell, clone, collections::HashMap, rc::Rc};
|
||||
|
||||
use crate::{
|
||||
Ast, DeBrujinAst,
|
||||
multi_map::MultiMap,
|
||||
parse::Constant,
|
||||
types::{Ident, PrimitiveType, TaggedType, Type, TypeTag},
|
||||
vec_map::VecMap,
|
||||
};
|
||||
|
||||
use super::InferError;
|
||||
|
||||
type TypeVar = TaggedType;
|
||||
|
||||
pub struct TypeVarCtx {
|
||||
counter: RefCell<usize>,
|
||||
}
|
||||
|
||||
impl TypeVarCtx {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
counter: RefCell::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_var(&self) -> TaggedType {
|
||||
let mut num = self.counter.borrow_mut();
|
||||
let res = format!("?typ_{num}");
|
||||
*num += 1;
|
||||
Type::TypeVariable(res).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum TypeVarAst {
|
||||
Abstraction(TypeVar, Ident, TaggedType, Box<TypeVarAst>), // \0:1.2
|
||||
Application(TypeVar, Box<TypeVarAst>, Box<TypeVarAst>), // 0 1
|
||||
FreeVariable(TypeVar, Ident), // x
|
||||
BoundVariable(TypeVar, usize), // 1
|
||||
Constant(TypeVar, Constant), // true | false | n
|
||||
}
|
||||
|
||||
pub type Constraints = MultiMap<TypeVar, TaggedType>;
|
||||
|
||||
pub(super) fn step_1(
|
||||
ast: DeBrujinAst,
|
||||
gamma_free: &HashMap<Ident, TaggedType>,
|
||||
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
|
||||
constraints: Rc<RefCell<Constraints>>,
|
||||
ctx: &TypeVarCtx,
|
||||
) -> Result<(TypeVarAst, TypeVar), InferError> {
|
||||
match ast {
|
||||
DeBrujinAst::Abstraction(i, Some(typ), ast) => {
|
||||
let var = ctx.get_var();
|
||||
|
||||
let gamma_ref = Rc::make_mut(&mut gamma_bound);
|
||||
gamma_ref.map_keys(|i| *i += 1);
|
||||
gamma_ref.insert(1, typ.clone());
|
||||
|
||||
let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?;
|
||||
// RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var));
|
||||
Ok((
|
||||
TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)),
|
||||
var,
|
||||
))
|
||||
}
|
||||
DeBrujinAst::Abstraction(i, None, ast) => {
|
||||
let var = ctx.get_var();
|
||||
|
||||
let typ = ctx.get_var();
|
||||
let gamma_ref = Rc::make_mut(&mut gamma_bound);
|
||||
gamma_ref.map_keys(|i| *i += 1);
|
||||
gamma_ref.insert(1, typ.clone());
|
||||
|
||||
let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?;
|
||||
// RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var));
|
||||
Ok((
|
||||
TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)),
|
||||
var,
|
||||
))
|
||||
}
|
||||
DeBrujinAst::Application(lhs, rhs) => {
|
||||
let var = ctx.get_var();
|
||||
let (lhs, lhs_var) = step_1(
|
||||
*lhs,
|
||||
gamma_free,
|
||||
gamma_bound.clone(),
|
||||
constraints.clone(),
|
||||
ctx,
|
||||
)?;
|
||||
let (rhs, rhs_var) = step_1(
|
||||
*rhs,
|
||||
gamma_free,
|
||||
gamma_bound.clone(),
|
||||
constraints.clone(),
|
||||
ctx,
|
||||
)?;
|
||||
RefCell::borrow_mut(&constraints).insert(lhs_var, rhs_var.make_arrow(var.clone()));
|
||||
Ok((
|
||||
TypeVarAst::Application(var.clone(), Box::new(lhs), Box::new(rhs)),
|
||||
var,
|
||||
))
|
||||
}
|
||||
DeBrujinAst::FreeVariable(v) => {
|
||||
let var = ctx.get_var();
|
||||
let typ = gamma_free
|
||||
.get(&v)
|
||||
.cloned()
|
||||
.ok_or(InferError::NotInContext)?;
|
||||
|
||||
RefCell::borrow_mut(&constraints).insert(var.clone(), typ);
|
||||
|
||||
Ok((TypeVarAst::FreeVariable(var.clone(), v), var))
|
||||
}
|
||||
DeBrujinAst::BoundVariable(i) => {
|
||||
let var = ctx.get_var();
|
||||
let typ = gamma_bound
|
||||
.get(&i)
|
||||
.cloned()
|
||||
.ok_or(InferError::NotInContext)?;
|
||||
|
||||
RefCell::borrow_mut(&constraints).insert(var.clone(), typ);
|
||||
|
||||
Ok((TypeVarAst::BoundVariable(var.clone(), i), var))
|
||||
}
|
||||
DeBrujinAst::Constant(constant) => {
|
||||
let var = ctx.get_var();
|
||||
let typ = match constant {
|
||||
Constant::Nat(_) => Type::Primitive(PrimitiveType::Nat),
|
||||
Constant::Float(_) => Type::Primitive(PrimitiveType::Float),
|
||||
Constant::Bool(_) => Type::Primitive(PrimitiveType::Bool),
|
||||
}
|
||||
.into();
|
||||
|
||||
RefCell::borrow_mut(&constraints).insert(var.clone(), typ);
|
||||
|
||||
Ok((TypeVarAst::Constant(var.clone(), constant), var))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn step_2(
|
||||
mut constraints: Constraints,
|
||||
) -> Result<Option<Box<dyn FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError>>>, InferError> {
|
||||
if let Some((s, t)) = constraints.pop() {
|
||||
if s == t {
|
||||
step_2(constraints)
|
||||
} else if s.type_var().is_some_and(|x| !t.name_used(&x)) {
|
||||
let Some(x) = s.type_var() else {
|
||||
unreachable!()
|
||||
};
|
||||
constraints.try_map(|(k, v)| {
|
||||
Ok((
|
||||
k.subst_typevar(&x, 0, t.clone())?,
|
||||
v.subst_typevar(&x, 0, t.clone())?,
|
||||
))
|
||||
})?;
|
||||
let subst = step_2(constraints)?;
|
||||
Ok(Some(subst_comp(x, t, subst)))
|
||||
} else if t.type_var().is_some_and(|x| !s.name_used(&x)) {
|
||||
let Some(x) = t.type_var() else {
|
||||
unreachable!()
|
||||
};
|
||||
constraints.try_map(|(k, v)| {
|
||||
Ok((
|
||||
k.subst_typevar(&x, 0, s.clone())?,
|
||||
v.subst_typevar(&x, 0, s.clone())?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let subst = step_2(constraints)?;
|
||||
Ok(Some(subst_comp(x, t, subst)))
|
||||
} else if let (Some((s_lhs, s_rhs)), Some((t_lhs, t_rhs))) = (s.arrow(), t.arrow()) {
|
||||
constraints.insert(s_lhs, t_lhs);
|
||||
constraints.insert(s_rhs, t_rhs);
|
||||
step_2(constraints)
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeVarAst {
|
||||
pub fn subst(self, var: &str, subst: TaggedType) -> Result<TypeVarAst, InferError> {
|
||||
match self {
|
||||
TypeVarAst::Abstraction(tagged_type1, ident, tagged_type2, ast) => {
|
||||
Ok(TypeVarAst::Abstraction(
|
||||
tagged_type1.subst_typevar(var, 0, subst.clone())?,
|
||||
ident,
|
||||
tagged_type2.subst_typevar(var, 0, subst.clone())?,
|
||||
Box::new(ast.subst(var, subst)?),
|
||||
))
|
||||
}
|
||||
TypeVarAst::Application(tagged_type, lhs, rhs) => Ok(TypeVarAst::Application(
|
||||
tagged_type.subst_typevar(var, 0, subst.clone())?,
|
||||
Box::new(lhs.subst(var, subst.clone())?),
|
||||
Box::new(rhs.subst(var, subst)?),
|
||||
)),
|
||||
TypeVarAst::FreeVariable(tagged_type, x) => Ok(TypeVarAst::FreeVariable(
|
||||
tagged_type.subst_typevar(var, 0, subst)?,
|
||||
x,
|
||||
)),
|
||||
TypeVarAst::BoundVariable(tagged_type, i) => Ok(TypeVarAst::BoundVariable(
|
||||
tagged_type.subst_typevar(var, 0, subst)?,
|
||||
i,
|
||||
)),
|
||||
TypeVarAst::Constant(tagged_type, constant) => Ok(TypeVarAst::Constant(
|
||||
tagged_type.subst_typevar(var, 0, subst)?,
|
||||
constant,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subst_comp<'a, F>(
|
||||
var: String,
|
||||
subst: TaggedType,
|
||||
then: Option<F>,
|
||||
) -> Box<dyn FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError> + 'a>
|
||||
where
|
||||
F: FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError> + 'a,
|
||||
{
|
||||
Box::new(move |ast: TypeVarAst| -> Result<TypeVarAst, InferError> {
|
||||
let ast = ast.subst(&var, subst)?;
|
||||
if let Some(f) = then { f(ast) } else { Ok(ast) }
|
||||
})
|
||||
}
|
||||
|
||||
pub fn infer_type(
|
||||
gamma: &HashMap<Ident, TaggedType>,
|
||||
ast: DeBrujinAst,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
let gamma_bound = Rc::new(VecMap::new());
|
||||
let constraints = Rc::new(RefCell::new(MultiMap::new()));
|
||||
let ctx = TypeVarCtx::new();
|
||||
|
||||
let (ast, _) = step_1(ast, gamma, gamma_bound, constraints.clone(), &ctx)?;
|
||||
constraints.clone();
|
||||
let res = step_2(constraints.take())?.unwrap();
|
||||
|
||||
let ast = res(ast)?;
|
||||
|
||||
fn get_type(ast: TypeVarAst) -> TaggedType {
|
||||
match ast {
|
||||
TypeVarAst::Abstraction(_, _, typ, ast) => typ.make_arrow(get_type(*ast)),
|
||||
TypeVarAst::Application(tagged_type, _, _) => tagged_type,
|
||||
TypeVarAst::FreeVariable(tagged_type, _) => tagged_type,
|
||||
TypeVarAst::BoundVariable(tagged_type, _) => tagged_type,
|
||||
TypeVarAst::Constant(tagged_type, constant) => tagged_type,
|
||||
}
|
||||
}
|
||||
|
||||
let typ = get_type(ast);
|
||||
|
||||
let mut typ = typ;
|
||||
|
||||
for free_var in typ.free_vars() {
|
||||
typ = TaggedType::Tagged(TypeTag::Any, free_var, Box::new(typ));
|
||||
}
|
||||
|
||||
Ok(typ.normalise())
|
||||
}
|
||||
|
||||
impl TaggedType {
|
||||
pub fn subst_typevar(
|
||||
self,
|
||||
var: &str,
|
||||
diversifier: usize,
|
||||
mut subst: TaggedType,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
match self {
|
||||
TaggedType::Tagged(t, i, tagged_type) => {
|
||||
if subst.name_used(&i) {
|
||||
subst.map_name(|f| {
|
||||
if *f == i {
|
||||
*f = format!("{i}+{diversifier}")
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(TaggedType::Tagged(
|
||||
t,
|
||||
i,
|
||||
Box::new(tagged_type.subst_typevar(var, diversifier + 1, subst)?),
|
||||
))
|
||||
}
|
||||
TaggedType::Concrete(t) => subst.subst_into(var, diversifier, t),
|
||||
}
|
||||
}
|
||||
|
||||
fn subst_into(
|
||||
self,
|
||||
var: &str,
|
||||
diversifier: usize,
|
||||
mut target: Type,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
match self {
|
||||
TaggedType::Tagged(t, i, tagged_type) => {
|
||||
if target.name_used(&i) {
|
||||
target.map_name(&|f| {
|
||||
if *f == i {
|
||||
*f = format!("{i}+{diversifier}")
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(TaggedType::Tagged(
|
||||
t,
|
||||
i,
|
||||
Box::new(tagged_type.subst_into(var, diversifier + 1, target)?),
|
||||
))
|
||||
}
|
||||
TaggedType::Concrete(c) => Ok(TaggedType::Concrete(target.subst_typevar(var, &c)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub fn subst_typevar(self, var: &str, subst: &Type) -> Result<Type, InferError> {
|
||||
match self {
|
||||
Type::TypeVariable(v) if v == var => Ok(subst.clone()),
|
||||
Type::Arrow(lhs, rhs) => Ok(Type::Arrow(
|
||||
Box::new(lhs.subst_typevar(var, subst)?),
|
||||
Box::new(rhs.subst_typevar(var, subst)?),
|
||||
)),
|
||||
t => Ok(t),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
use std::{fmt::Debug, mem};
|
||||
|
||||
pub struct MultiMap<K: Eq, V> {
|
||||
map: Vec<(K, V)>,
|
||||
}
|
||||
|
||||
impl<K: Eq, V> MultiMap<K, V> {
|
||||
pub fn new() -> Self {
|
||||
Self { map: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: K, val: V) {
|
||||
self.map.push((key, val));
|
||||
}
|
||||
|
||||
pub fn get(&self, key: &K) -> Vec<&V> {
|
||||
self.map
|
||||
.iter()
|
||||
.filter_map(|(k, v)| if k == key { Some(v) } else { None })
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) {
|
||||
self.map.iter_mut().for_each(|(k, _)| f(k))
|
||||
}
|
||||
|
||||
pub fn map<F: Fn(&mut (K, V))>(&mut self, f: F) {
|
||||
self.map.iter_mut().for_each(f);
|
||||
}
|
||||
|
||||
pub fn try_map<E, F: Fn((K, V)) -> Result<(K, V), E>>(&mut self, f: F) -> Result<(), E> {
|
||||
let vec = mem::take(&mut self.map);
|
||||
let vec = vec.into_iter().map(f).collect::<Result<Vec<(K, V)>, _>>()?;
|
||||
self.map = vec;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn find<F: FnMut(&&(K, V)) -> bool>(&self, f: F) -> Option<&(K, V)> {
|
||||
self.map.iter().find(f)
|
||||
}
|
||||
|
||||
pub fn find_remove<F: FnMut(&&(K, V)) -> bool>(&mut self, mut f: F) -> Option<(K, V)> {
|
||||
let idx = self
|
||||
.map
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find_map(|(n, ref kv)| if f(kv) { Some(n) } else { None })?;
|
||||
Some(self.map.swap_remove(idx))
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.map.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.map.is_empty()
|
||||
}
|
||||
|
||||
pub fn pop(&mut self) -> Option<(K, V)> {
|
||||
self.map.pop()
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Eq, V> Default for MultiMap<K, V> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
map: Vec::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Eq + Clone, V: Clone> Clone for MultiMap<K, V> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
map: self.map.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Eq + Debug, V: Debug> Debug for MultiMap<K, V> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_map()
|
||||
.entries(self.map.iter().map(|&(ref k, ref v)| (k, v)))
|
||||
.finish()
|
||||
}
|
||||
}
|
Loading…
Reference in new issue