Inference by unification

main
Avery 1 month ago
parent 305e95846d
commit 766da2593e
Signed by: Avery
GPG Key ID: 4E53F4CB69B2CC8D

@ -9,11 +9,11 @@ use super::DeBrujinAst;
#[derive(Clone)] #[derive(Clone)]
pub enum DeBrujinBuiltInAst { pub enum DeBrujinBuiltInAst {
Abstraction(Ident, TaggedType, Box<DeBrujinBuiltInAst>), // \:1.2 Abstraction(Ident, Option<TaggedType>, Box<DeBrujinBuiltInAst>), // \:1.2
Application(Box<DeBrujinBuiltInAst>, Box<DeBrujinBuiltInAst>), // 0 1 Application(Box<DeBrujinBuiltInAst>, Box<DeBrujinBuiltInAst>), // 0 1
FreeVariable(String), // x FreeVariable(String), // x
BoundVariable(usize), // 1 BoundVariable(usize), // 1
Constant(Constant), // true | false | n Constant(Constant), // true | false | n
Builtin(Rc<dyn Builtin>), Builtin(Rc<dyn Builtin>),
} }

@ -48,11 +48,11 @@ impl Ast {
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum DeBrujinAst { pub enum DeBrujinAst {
Abstraction(Ident, TaggedType, Box<DeBrujinAst>), // \:1.2 Abstraction(Ident, Option<TaggedType>, Box<DeBrujinAst>), // \:1.2
Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1 Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1
FreeVariable(String), // x FreeVariable(String), // x
BoundVariable(usize), // 1 BoundVariable(usize), // 1
Constant(Constant), // true | false | n Constant(Constant), // true | false | n
} }
impl Into<Ast> for DeBrujinAst { impl Into<Ast> for DeBrujinAst {

@ -12,10 +12,10 @@ use super::builtins::Builtin;
fn to_de_brujin_ast_simple() { fn to_de_brujin_ast_simple() {
let input = Ast::Abstraction( let input = Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Variable("x".to_string())),
)), )),
); );
@ -24,10 +24,10 @@ fn to_de_brujin_ast_simple() {
de_brujin, de_brujin,
DBAst::Abstraction( DBAst::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(DBAst::Abstraction( Box::new(DBAst::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(DBAst::BoundVariable(1)) Box::new(DBAst::BoundVariable(1))
)) ))
) )
@ -39,7 +39,7 @@ fn de_brujin_beta_reduce() {
let input = Ast::Application( let input = Ast::Application(
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into(), Some(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()),
Box::new(Ast::Application( Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))), Box::new(Ast::Constant(Constant::Nat(5))),
@ -63,7 +63,7 @@ fn to_and_from_de_brujin_is_id() {
let input = Ast::Application( let input = Ast::Application(
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into(), Some(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()),
Box::new(Ast::Application( Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))), Box::new(Ast::Constant(Constant::Nat(5))),

@ -1,11 +1,11 @@
use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc}; use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc};
use crate::{ use crate::types::{TaggedType, TypeTag};
DeBrujinAst,
parse::{Ast, Constant}, pub mod recursive;
types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, #[cfg(test)]
vec_map::VecMap, mod test;
}; pub mod unification;
#[derive(Debug)] #[derive(Debug)]
pub enum InferError { pub enum InferError {
@ -14,60 +14,6 @@ pub enum InferError {
NotInContext, NotInContext,
ExpectedConreteType, ExpectedConreteType,
ExpectedTypeWithTag, ExpectedTypeWithTag,
} DoesNotFitTag(TypeTag, TaggedType),
ConfilictingBind,
#[cfg(test)]
mod test;
pub fn infer_type(
gamma: &HashMap<Ident, TaggedType>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast)
}
fn infer_type_debrujin_int(
gamma_free: &HashMap<Ident, TaggedType>,
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
match ast {
DeBrujinAst::Abstraction(_, arg_type, ast) => {
let gamma_ref = Rc::make_mut(&mut gamma_bound);
gamma_ref.map_keys(|i| *i += 1);
gamma_ref.insert(1, arg_type.clone().into());
let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?;
// TODO: Fix this hack
let typ = arg_type.make_arrow(
out_type
.to_concrete()
.ok_or(InferError::ExpectedConreteType)?,
);
Ok(typ)
}
DeBrujinAst::Application(lhs, rhs) => {
let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?;
let Some((in_type, out_type)) = left_type.arrow() else {
return Err(InferError::NotAFunction);
};
let Some(right_type) =
infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete()
else {
return Err(InferError::ExpectedConreteType);
};
let out_type = match in_type.matches(&right_type) {
(true, None) => out_type,
(true, Some(a)) => out_type.specialize(&a, &right_type),
(false, _) => return Err(InferError::MismatchedType),
};
Ok(out_type)
}
DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext),
// compiler bug if not present
DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()),
DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()),
DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()),
DeBrujinAst::Constant(_) => unreachable!(),
}
} }

@ -0,0 +1,59 @@
use std::{collections::HashMap, rc::Rc};
use crate::{
DeBrujinAst,
parse::Constant,
types::{Ident, PrimitiveType, TaggedType, Type},
vec_map::VecMap,
};
use super::InferError;
pub fn infer_type(
gamma: &HashMap<Ident, TaggedType>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast)
}
fn infer_type_debrujin_int(
gamma_free: &HashMap<Ident, TaggedType>,
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
match ast {
DeBrujinAst::Abstraction(_, arg_type, ast) => {
let gamma_ref = Rc::make_mut(&mut gamma_bound);
gamma_ref.map_keys(|i| *i += 1);
gamma_ref.insert(1, arg_type.clone().unwrap().into());
let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?;
// TODO: Fix this hack
let typ = arg_type.unwrap().make_arrow(out_type);
Ok(typ)
}
DeBrujinAst::Application(lhs, rhs) => {
let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?;
let Some((in_type, out_type)) = left_type.arrow() else {
return Err(InferError::NotAFunction);
};
let Some(right_type) =
infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete()
else {
return Err(InferError::ExpectedConreteType);
};
let out_type = match in_type.matches(&right_type) {
(true, None) => out_type,
(true, Some(a)) => out_type.specialize(&a, &right_type),
(false, _) => return Err(InferError::MismatchedType),
};
Ok(out_type)
}
DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext),
// compiler bug if not present
DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()),
DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()),
DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()),
DeBrujinAst::Constant(_) => unreachable!(),
}
}

@ -1,14 +1,21 @@
use std::{collections::HashMap, rc::Rc}; use std::{cell::RefCell, collections::HashMap, rc::Rc};
use crate::{ use crate::{
DeBrujinAst, DeBrujinAst,
inference::infer_type, inference::{
recursive::infer_type as infer_type_rec, unification::Constraints,
unification::infer_type as infer_type_uni,
},
multi_map::MultiMap,
parse::{Ast, Constant}, parse::{Ast, Constant},
types::{PrimitiveType, TaggedType, Type, TypeTag}, types::{PrimitiveType, TaggedType, Type, TypeTag},
vec_map::VecMap,
}; };
use super::unification::{TypeVarAst, TypeVarCtx};
#[test] #[test]
fn infer_add_nat() { fn infer_add_nat_rec() {
let ast: DeBrujinAst = Ast::Application( let ast: DeBrujinAst = Ast::Application(
Box::new(Ast::Application( Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())), Box::new(Ast::Variable("add".to_string())),
@ -25,16 +32,16 @@ fn infer_add_nat() {
TypeTag::Num, TypeTag::Num,
"a".to_string(), "a".to_string(),
Box::new(TaggedType::Concrete(Type::Arrow( Box::new(TaggedType::Concrete(Type::Arrow(
Box::new(Type::Generic("a".to_string())), Box::new(Type::TypeVariable("a".to_string())),
Box::new(Type::Arrow( Box::new(Type::Arrow(
Box::new(Type::Generic("a".to_string())), Box::new(Type::TypeVariable("a".to_string())),
Box::new(Type::Generic("a".to_string())), Box::new(Type::TypeVariable("a".to_string())),
)), )),
))), ))),
), ),
); );
let infered = infer_type(&gamma, ast).unwrap(); let infered = infer_type_rec(&gamma, ast).unwrap();
assert_eq!( assert_eq!(
infered, infered,
TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat)) TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat))
@ -42,7 +49,7 @@ fn infer_add_nat() {
} }
#[test] #[test]
fn infer_add_nat_partial() { fn infer_add_nat_partial_rec() {
let ast: DeBrujinAst = Ast::Application( let ast: DeBrujinAst = Ast::Application(
Box::new(Ast::Variable("add".to_string())), Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))), Box::new(Ast::Constant(Constant::Nat(5))),
@ -56,18 +63,163 @@ fn infer_add_nat_partial() {
TypeTag::Num, TypeTag::Num,
"a".to_string(), "a".to_string(),
Box::new(TaggedType::Concrete(Type::Arrow( Box::new(TaggedType::Concrete(Type::Arrow(
Box::new(Type::Generic("a".to_string())), Box::new(Type::TypeVariable("a".to_string())),
Box::new(Type::Arrow( Box::new(Type::Arrow(
Box::new(Type::Generic("a".to_string())), Box::new(Type::TypeVariable("a".to_string())),
Box::new(Type::Generic("a".to_string())), Box::new(Type::TypeVariable("a".to_string())),
)), )),
))), ))),
), ),
); );
let infered = infer_type(&gamma, ast).unwrap(); let infered = infer_type_rec(&gamma, ast).unwrap();
assert_eq!( assert_eq!(
infered, infered,
TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)) TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
); );
} }
#[test]
fn infer_add_nat_uni() {
let ast: DeBrujinAst = Ast::Application(
Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
Box::new(Ast::Constant(Constant::Nat(7))),
)
.into();
let mut gamma = HashMap::new();
gamma.insert(
"add".to_string(),
TaggedType::Concrete(Type::arrow(
PrimitiveType::Nat,
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
)),
);
let infered = infer_type_rec(&gamma, ast).unwrap();
assert_eq!(
infered,
TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat))
);
}
#[test]
fn infer_add_nat_partial_uni() {
let ast: DeBrujinAst = Ast::Application(
Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)
.into();
let mut gamma = HashMap::new();
gamma.insert(
"add".to_string(),
TaggedType::Concrete(Type::arrow(
PrimitiveType::Nat,
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
)),
);
let typ = infer_type_uni(&gamma, ast).unwrap();
assert_eq!(
typ,
TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
);
}
#[test]
fn infer_id_uni() {
let ast: DeBrujinAst = Ast::Abstraction(
"x".to_string(),
None,
Box::new(Ast::Variable("x".to_string())),
)
.into();
let mut gamma = HashMap::new();
let typ = infer_type_uni(&gamma, ast).unwrap();
assert_eq!(
typ,
TaggedType::Tagged(
TypeTag::Any,
"?typ_1".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("?typ_1", "?typ_1")))
)
);
}
#[test]
fn subst_type_var() {
let typ = TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("a", "?typ_4"))),
);
let subst = TaggedType::Tagged(
TypeTag::Any,
"b".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("b", "b"))),
);
let typ = typ.subst_typevar("?typ_4", 0, subst).unwrap();
assert_eq!(
typ,
TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Tagged(
TypeTag::Any,
"b".to_string(),
Box::new(TaggedType::Concrete(Type::arrow(
"a",
Type::arrow("b", "b")
)))
))
)
);
let typ = TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("a", "?typ_4"))),
);
let subst = TaggedType::Tagged(
TypeTag::Any,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("a", "a"))),
);
let typ = typ.subst_typevar("?typ_4", 0, subst).unwrap();
assert!(match typ {
TaggedType::Tagged(TypeTag::Num, a, tagged_type) => match *tagged_type {
TaggedType::Tagged(TypeTag::Any, b, tagged_type) => match *tagged_type {
TaggedType::Concrete(typ) => match typ {
Type::Arrow(lhs, rhs) => match (*lhs, *rhs) {
(Type::TypeVariable(a1), Type::Arrow(b1, b2)) if a1 == a => {
match (*b1, *b2) {
(Type::TypeVariable(b1), Type::TypeVariable(b2)) => {
(b1 == b) && (b2 == b)
}
_ => false,
}
}
_ => false,
},
_ => false,
},
_ => false,
},
_ => false,
},
_ => false,
});
}

@ -0,0 +1,329 @@
use std::{cell::RefCell, clone, collections::HashMap, rc::Rc};
use crate::{
Ast, DeBrujinAst,
multi_map::MultiMap,
parse::Constant,
types::{Ident, PrimitiveType, TaggedType, Type, TypeTag},
vec_map::VecMap,
};
use super::InferError;
type TypeVar = TaggedType;
pub struct TypeVarCtx {
counter: RefCell<usize>,
}
impl TypeVarCtx {
pub fn new() -> Self {
Self {
counter: RefCell::new(0),
}
}
pub fn get_var(&self) -> TaggedType {
let mut num = self.counter.borrow_mut();
let res = format!("?typ_{num}");
*num += 1;
Type::TypeVariable(res).into()
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum TypeVarAst {
Abstraction(TypeVar, Ident, TaggedType, Box<TypeVarAst>), // \0:1.2
Application(TypeVar, Box<TypeVarAst>, Box<TypeVarAst>), // 0 1
FreeVariable(TypeVar, Ident), // x
BoundVariable(TypeVar, usize), // 1
Constant(TypeVar, Constant), // true | false | n
}
pub type Constraints = MultiMap<TypeVar, TaggedType>;
pub(super) fn step_1(
ast: DeBrujinAst,
gamma_free: &HashMap<Ident, TaggedType>,
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
constraints: Rc<RefCell<Constraints>>,
ctx: &TypeVarCtx,
) -> Result<(TypeVarAst, TypeVar), InferError> {
match ast {
DeBrujinAst::Abstraction(i, Some(typ), ast) => {
let var = ctx.get_var();
let gamma_ref = Rc::make_mut(&mut gamma_bound);
gamma_ref.map_keys(|i| *i += 1);
gamma_ref.insert(1, typ.clone());
let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?;
// RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var));
Ok((
TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)),
var,
))
}
DeBrujinAst::Abstraction(i, None, ast) => {
let var = ctx.get_var();
let typ = ctx.get_var();
let gamma_ref = Rc::make_mut(&mut gamma_bound);
gamma_ref.map_keys(|i| *i += 1);
gamma_ref.insert(1, typ.clone());
let (ast, rhs_var) = step_1(*ast, gamma_free, gamma_bound, constraints.clone(), ctx)?;
// RefCell::borrow_mut(&constraints).insert(var.clone(), typ.clone().make_arrow(rhs_var));
Ok((
TypeVarAst::Abstraction(var.clone(), i, typ, Box::new(ast)),
var,
))
}
DeBrujinAst::Application(lhs, rhs) => {
let var = ctx.get_var();
let (lhs, lhs_var) = step_1(
*lhs,
gamma_free,
gamma_bound.clone(),
constraints.clone(),
ctx,
)?;
let (rhs, rhs_var) = step_1(
*rhs,
gamma_free,
gamma_bound.clone(),
constraints.clone(),
ctx,
)?;
RefCell::borrow_mut(&constraints).insert(lhs_var, rhs_var.make_arrow(var.clone()));
Ok((
TypeVarAst::Application(var.clone(), Box::new(lhs), Box::new(rhs)),
var,
))
}
DeBrujinAst::FreeVariable(v) => {
let var = ctx.get_var();
let typ = gamma_free
.get(&v)
.cloned()
.ok_or(InferError::NotInContext)?;
RefCell::borrow_mut(&constraints).insert(var.clone(), typ);
Ok((TypeVarAst::FreeVariable(var.clone(), v), var))
}
DeBrujinAst::BoundVariable(i) => {
let var = ctx.get_var();
let typ = gamma_bound
.get(&i)
.cloned()
.ok_or(InferError::NotInContext)?;
RefCell::borrow_mut(&constraints).insert(var.clone(), typ);
Ok((TypeVarAst::BoundVariable(var.clone(), i), var))
}
DeBrujinAst::Constant(constant) => {
let var = ctx.get_var();
let typ = match constant {
Constant::Nat(_) => Type::Primitive(PrimitiveType::Nat),
Constant::Float(_) => Type::Primitive(PrimitiveType::Float),
Constant::Bool(_) => Type::Primitive(PrimitiveType::Bool),
}
.into();
RefCell::borrow_mut(&constraints).insert(var.clone(), typ);
Ok((TypeVarAst::Constant(var.clone(), constant), var))
}
}
}
pub(super) fn step_2(
mut constraints: Constraints,
) -> Result<Option<Box<dyn FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError>>>, InferError> {
if let Some((s, t)) = constraints.pop() {
if s == t {
step_2(constraints)
} else if s.type_var().is_some_and(|x| !t.name_used(&x)) {
let Some(x) = s.type_var() else {
unreachable!()
};
constraints.try_map(|(k, v)| {
Ok((
k.subst_typevar(&x, 0, t.clone())?,
v.subst_typevar(&x, 0, t.clone())?,
))
})?;
let subst = step_2(constraints)?;
Ok(Some(subst_comp(x, t, subst)))
} else if t.type_var().is_some_and(|x| !s.name_used(&x)) {
let Some(x) = t.type_var() else {
unreachable!()
};
constraints.try_map(|(k, v)| {
Ok((
k.subst_typevar(&x, 0, s.clone())?,
v.subst_typevar(&x, 0, s.clone())?,
))
})?;
let subst = step_2(constraints)?;
Ok(Some(subst_comp(x, t, subst)))
} else if let (Some((s_lhs, s_rhs)), Some((t_lhs, t_rhs))) = (s.arrow(), t.arrow()) {
constraints.insert(s_lhs, t_lhs);
constraints.insert(s_rhs, t_rhs);
step_2(constraints)
} else {
panic!()
}
} else {
Ok(None)
}
}
impl TypeVarAst {
pub fn subst(self, var: &str, subst: TaggedType) -> Result<TypeVarAst, InferError> {
match self {
TypeVarAst::Abstraction(tagged_type1, ident, tagged_type2, ast) => {
Ok(TypeVarAst::Abstraction(
tagged_type1.subst_typevar(var, 0, subst.clone())?,
ident,
tagged_type2.subst_typevar(var, 0, subst.clone())?,
Box::new(ast.subst(var, subst)?),
))
}
TypeVarAst::Application(tagged_type, lhs, rhs) => Ok(TypeVarAst::Application(
tagged_type.subst_typevar(var, 0, subst.clone())?,
Box::new(lhs.subst(var, subst.clone())?),
Box::new(rhs.subst(var, subst)?),
)),
TypeVarAst::FreeVariable(tagged_type, x) => Ok(TypeVarAst::FreeVariable(
tagged_type.subst_typevar(var, 0, subst)?,
x,
)),
TypeVarAst::BoundVariable(tagged_type, i) => Ok(TypeVarAst::BoundVariable(
tagged_type.subst_typevar(var, 0, subst)?,
i,
)),
TypeVarAst::Constant(tagged_type, constant) => Ok(TypeVarAst::Constant(
tagged_type.subst_typevar(var, 0, subst)?,
constant,
)),
}
}
}
pub fn subst_comp<'a, F>(
var: String,
subst: TaggedType,
then: Option<F>,
) -> Box<dyn FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError> + 'a>
where
F: FnOnce(TypeVarAst) -> Result<TypeVarAst, InferError> + 'a,
{
Box::new(move |ast: TypeVarAst| -> Result<TypeVarAst, InferError> {
let ast = ast.subst(&var, subst)?;
if let Some(f) = then { f(ast) } else { Ok(ast) }
})
}
pub fn infer_type(
gamma: &HashMap<Ident, TaggedType>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
let gamma_bound = Rc::new(VecMap::new());
let constraints = Rc::new(RefCell::new(MultiMap::new()));
let ctx = TypeVarCtx::new();
let (ast, _) = step_1(ast, gamma, gamma_bound, constraints.clone(), &ctx)?;
constraints.clone();
let res = step_2(constraints.take())?.unwrap();
let ast = res(ast)?;
fn get_type(ast: TypeVarAst) -> TaggedType {
match ast {
TypeVarAst::Abstraction(_, _, typ, ast) => typ.make_arrow(get_type(*ast)),
TypeVarAst::Application(tagged_type, _, _) => tagged_type,
TypeVarAst::FreeVariable(tagged_type, _) => tagged_type,
TypeVarAst::BoundVariable(tagged_type, _) => tagged_type,
TypeVarAst::Constant(tagged_type, constant) => tagged_type,
}
}
let typ = get_type(ast);
let mut typ = typ;
for free_var in typ.free_vars() {
typ = TaggedType::Tagged(TypeTag::Any, free_var, Box::new(typ));
}
Ok(typ.normalise())
}
impl TaggedType {
pub fn subst_typevar(
self,
var: &str,
diversifier: usize,
mut subst: TaggedType,
) -> Result<TaggedType, InferError> {
match self {
TaggedType::Tagged(t, i, tagged_type) => {
if subst.name_used(&i) {
subst.map_name(|f| {
if *f == i {
*f = format!("{i}+{diversifier}")
}
});
}
Ok(TaggedType::Tagged(
t,
i,
Box::new(tagged_type.subst_typevar(var, diversifier + 1, subst)?),
))
}
TaggedType::Concrete(t) => subst.subst_into(var, diversifier, t),
}
}
fn subst_into(
self,
var: &str,
diversifier: usize,
mut target: Type,
) -> Result<TaggedType, InferError> {
match self {
TaggedType::Tagged(t, i, tagged_type) => {
if target.name_used(&i) {
target.map_name(&|f| {
if *f == i {
*f = format!("{i}+{diversifier}")
}
});
}
Ok(TaggedType::Tagged(
t,
i,
Box::new(tagged_type.subst_into(var, diversifier + 1, target)?),
))
}
TaggedType::Concrete(c) => Ok(TaggedType::Concrete(target.subst_typevar(var, &c)?)),
}
}
}
impl Type {
pub fn subst_typevar(self, var: &str, subst: &Type) -> Result<Type, InferError> {
match self {
Type::TypeVariable(v) if v == var => Ok(subst.clone()),
Type::Arrow(lhs, rhs) => Ok(Type::Arrow(
Box::new(lhs.subst_typevar(var, subst)?),
Box::new(rhs.subst_typevar(var, subst)?),
)),
t => Ok(t),
}
}
}

@ -4,11 +4,13 @@ use std::fmt::Display;
mod exec; mod exec;
mod inference; mod inference;
mod multi_map;
mod parse; mod parse;
mod types; mod types;
mod vec_map; mod vec_map;
pub use exec::{Builtin, DeBrujinAst, builtin}; pub use exec::{Builtin, DeBrujinAst, builtin};
pub use inference::infer_type; pub use inference::recursive::infer_type as infer_type_rec;
pub use inference::unification::infer_type as infer_type_uni;
use lalrpop_util::lalrpop_mod; use lalrpop_util::lalrpop_mod;
pub use parse::{Ast, parse_ast_str, parse_type_str}; pub use parse::{Ast, parse_ast_str, parse_type_str};

@ -5,7 +5,7 @@ use std::{
}; };
use stlc_type_inference::{ use stlc_type_inference::{
Ast, Builtin, DeBrujinAst, builtin, infer_type, parse_ast_str, parse_type_str, Ast, Builtin, DeBrujinAst, builtin, infer_type_uni, parse_ast_str, parse_type_str,
}; };
macro_rules! repl_err { macro_rules! repl_err {
@ -35,7 +35,7 @@ fn main() {
if let Some((cmd, expr)) = tl.split_once(' ') { if let Some((cmd, expr)) = tl.split_once(' ') {
match cmd { match cmd {
"t" => match parse_ast_str(expr) { "t" => match parse_ast_str(expr) {
Ok(a) => match infer_type(&gamma, a.into()) { Ok(a) => match infer_type_uni(&gamma, a.into()) {
Ok(t) => println!("{t}"), Ok(t) => println!("{t}"),
Err(e) => repl_err!("Could not infer type {e:?}"), Err(e) => repl_err!("Could not infer type {e:?}"),
}, },
@ -77,7 +77,7 @@ fn main() {
}; };
let ast: DeBrujinAst = ast.into(); let ast: DeBrujinAst = ast.into();
let typ = match infer_type(&gamma, ast.clone()) { let typ = match infer_type_uni(&gamma, ast.clone()) {
Ok(t) => t, Ok(t) => t,
Err(e) => repl_err!("Could not infer type {e:?}"), Err(e) => repl_err!("Could not infer type {e:?}"),
}; };

@ -0,0 +1,86 @@
use std::{fmt::Debug, mem};
pub struct MultiMap<K: Eq, V> {
map: Vec<(K, V)>,
}
impl<K: Eq, V> MultiMap<K, V> {
pub fn new() -> Self {
Self { map: Vec::new() }
}
pub fn insert(&mut self, key: K, val: V) {
self.map.push((key, val));
}
pub fn get(&self, key: &K) -> Vec<&V> {
self.map
.iter()
.filter_map(|(k, v)| if k == key { Some(v) } else { None })
.collect()
}
pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) {
self.map.iter_mut().for_each(|(k, _)| f(k))
}
pub fn map<F: Fn(&mut (K, V))>(&mut self, f: F) {
self.map.iter_mut().for_each(f);
}
pub fn try_map<E, F: Fn((K, V)) -> Result<(K, V), E>>(&mut self, f: F) -> Result<(), E> {
let vec = mem::take(&mut self.map);
let vec = vec.into_iter().map(f).collect::<Result<Vec<(K, V)>, _>>()?;
self.map = vec;
Ok(())
}
pub fn find<F: FnMut(&&(K, V)) -> bool>(&self, f: F) -> Option<&(K, V)> {
self.map.iter().find(f)
}
pub fn find_remove<F: FnMut(&&(K, V)) -> bool>(&mut self, mut f: F) -> Option<(K, V)> {
let idx = self
.map
.iter()
.enumerate()
.find_map(|(n, ref kv)| if f(kv) { Some(n) } else { None })?;
Some(self.map.swap_remove(idx))
}
pub fn len(&self) -> usize {
self.map.len()
}
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
pub fn pop(&mut self) -> Option<(K, V)> {
self.map.pop()
}
}
impl<K: Eq, V> Default for MultiMap<K, V> {
fn default() -> Self {
Self {
map: Vec::default(),
}
}
}
impl<K: Eq + Clone, V: Clone> Clone for MultiMap<K, V> {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
}
}
}
impl<K: Eq + Debug, V: Debug> Debug for MultiMap<K, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_map()
.entries(self.map.iter().map(|&(ref k, ref v)| (k, v)))
.finish()
}
}

@ -40,10 +40,10 @@ pub enum Constant {
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum Ast { pub enum Ast {
Abstraction(Ident, TaggedType, Box<Ast>), // \0:1.2 Abstraction(Ident, Option<TaggedType>, Box<Ast>), // \0:1.2
Application(Box<Ast>, Box<Ast>), // 0 1 Application(Box<Ast>, Box<Ast>), // 0 1
Variable(Ident), // x Variable(Ident), // x
Constant(Constant), // true | false | n Constant(Constant), // true | false | n
} }
pub fn parse_ast_str(src: &str) -> Result<Ast, LALRPopError<usize, tokenize::Token, ParseError>> { pub fn parse_ast_str(src: &str) -> Result<Ast, LALRPopError<usize, tokenize::Token, ParseError>> {
@ -84,7 +84,8 @@ impl Type {
impl Display for Ast { impl Display for Ast {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Ast::Abstraction(var, typ, ast) => write!(f, "(\\{var}:{typ}.{ast})"), Ast::Abstraction(var, Some(typ), ast) => write!(f, "(\\{var}:{typ}.{ast})"),
Ast::Abstraction(var, None, ast) => write!(f, "(\\{var}.{ast})"),
Ast::Application(lhs, rhs) => write!(f, "{lhs} {rhs}"), Ast::Application(lhs, rhs) => write!(f, "{lhs} {rhs}"),
Ast::Variable(v) => write!(f, "{v}"), Ast::Variable(v) => write!(f, "{v}"),
Ast::Constant(constant) => write!(f, "{constant}"), Ast::Constant(constant) => write!(f, "{constant}"),

@ -93,7 +93,7 @@ fn parse_abstraction() {
ast, ast,
Ast::Abstraction( Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(Ast::Variable("x".to_string())) Box::new(Ast::Variable("x".to_string()))
) )
); );
@ -103,10 +103,10 @@ fn parse_abstraction() {
ast, ast,
Ast::Abstraction( Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"y".to_string(), "y".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(Ast::Variable("x".to_string())) Box::new(Ast::Variable("x".to_string()))
)) ))
) )
@ -119,18 +119,18 @@ fn parse_abstraction() {
ast, ast,
Ast::Abstraction( Ast::Abstraction(
"x".to_string(), "x".to_string(),
TaggedType::Tagged( Some(TaggedType::Tagged(
TypeTag::Any, TypeTag::Any,
"a".to_string(), "a".to_string(),
Box::new(TaggedType::Concrete("a".into())) Box::new(TaggedType::Concrete("a".into()))
), )),
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"y".to_string(), "y".to_string(),
TaggedType::Tagged( Some(TaggedType::Tagged(
TypeTag::Any, TypeTag::Any,
"b".to_string(), "b".to_string(),
Box::new(TaggedType::Concrete("b".into())) Box::new(TaggedType::Concrete("b".into()))
), )),
Box::new(Ast::Variable("x".to_string())) Box::new(Ast::Variable("x".to_string()))
)) ))
) )
@ -146,7 +146,7 @@ fn parse_application() {
Ast::Application( Ast::Application(
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"x".to_string(), "x".to_string(),
Type::Primitive(PrimitiveType::Nat).into(), Some(Type::Primitive(PrimitiveType::Nat).into()),
Box::new(Ast::Variable("x".to_string())) Box::new(Ast::Variable("x".to_string()))
)), )),
Box::new(Ast::Constant(Constant::Nat(5))) Box::new(Ast::Constant(Constant::Nat(5)))
@ -212,18 +212,18 @@ fn parse_application() {
Box::new(Ast::Application( Box::new(Ast::Application(
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"x".to_string(), "x".to_string(),
TaggedType::Tagged( Some(TaggedType::Tagged(
TypeTag::Any, TypeTag::Any,
"a".to_string(), "a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("a", "a"))) Box::new(TaggedType::Concrete(Type::arrow("a", "a")))
), )),
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"y".to_string(), "y".to_string(),
TaggedType::Tagged( Some(TaggedType::Tagged(
TypeTag::Num, TypeTag::Num,
"a".to_string(), "a".to_string(),
Box::new(TaggedType::Concrete("a".into())) Box::new(TaggedType::Concrete("a".into()))
), )),
Box::new(Ast::Application( Box::new(Ast::Application(
Box::new(Ast::Application( Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())), Box::new(Ast::Variable("add".to_string())),
@ -238,11 +238,11 @@ fn parse_application() {
)), )),
Box::new(Ast::Abstraction( Box::new(Ast::Abstraction(
"x".to_string(), "x".to_string(),
TaggedType::Tagged( Some(TaggedType::Tagged(
TypeTag::Any, TypeTag::Any,
"a".to_string(), "a".to_string(),
Box::new(TaggedType::Concrete("a".into())) Box::new(TaggedType::Concrete("a".into()))
), )),
Box::new(Ast::Variable("x".to_string())) Box::new(Ast::Variable("x".to_string()))
)) ))
)), )),

@ -45,7 +45,8 @@ extern {
pub Ast: Ast = { pub Ast: Ast = {
Term => <>, Term => <>,
r"\" <x:Ident> ":" <t:TaggedType> "." <ast:Ast> => Ast::Abstraction(x, t, Box::new(ast)), r"\" <x:Ident> ":" <t:TaggedType> "." <ast:Ast> => Ast::Abstraction(x, Some(t), Box::new(ast)),
r"\" <x:Ident> "." <ast:Ast> => Ast::Abstraction(x, None, Box::new(ast)),
}; };

@ -1,29 +1,35 @@
use std::{fmt::Display, str}; use std::{
collections::{HashMap, HashSet},
fmt::Display,
mem,
rc::Rc,
str,
};
pub type Ident = String; pub type Ident = String;
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PrimitiveType { pub enum PrimitiveType {
Nat, Nat,
Bool, Bool,
Float, Float,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeTag { pub enum TypeTag {
Num, Num,
Any, Any,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TaggedType { pub enum TaggedType {
Tagged(TypeTag, Ident, Box<TaggedType>), Tagged(TypeTag, Ident, Box<TaggedType>),
Concrete(Type), Concrete(Type),
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Type { pub enum Type {
Generic(Ident), // a TypeVariable(Ident), // a
Primitive(PrimitiveType), // Bool | Nat Primitive(PrimitiveType), // Bool | Nat
Arrow(Box<Type>, Box<Type>), // 0 -> 1 Arrow(Box<Type>, Box<Type>), // 0 -> 1
} }
@ -34,6 +40,18 @@ impl From<Type> for TaggedType {
} }
} }
impl TypeTag {
// If one tag tightens the other return that tag, otherwise none,
fn tightens(self, other: Self) -> Option<Self> {
match (self, other) {
(TypeTag::Num, TypeTag::Num) => Some(Self::Num),
(TypeTag::Num, TypeTag::Any) => Some(Self::Num),
(TypeTag::Any, TypeTag::Num) => Some(Self::Num),
(TypeTag::Any, TypeTag::Any) => Some(Self::Any),
}
}
}
impl TaggedType { impl TaggedType {
pub fn r#type(&self) -> &Type { pub fn r#type(&self) -> &Type {
match self { match self {
@ -52,13 +70,7 @@ impl TaggedType {
pub fn to_concrete(self) -> Option<Type> { pub fn to_concrete(self) -> Option<Type> {
match self { match self {
TaggedType::Tagged(type_tag, _, tagged_type) => None, TaggedType::Tagged(type_tag, _, tagged_type) => None,
TaggedType::Concrete(t) => { TaggedType::Concrete(t) => Some(t),
if t.is_concrete() {
Some(t)
} else {
None
}
}
} }
} }
@ -92,12 +104,21 @@ impl TaggedType {
} }
} }
pub fn make_arrow(self, rhs: Type) -> Self { pub fn make_arrow(self, rhs: TaggedType) -> Self {
match self { match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => { TaggedType::Tagged(type_tag, ident, tagged_type) => {
TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow(rhs))) TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow(rhs)))
} }
TaggedType::Concrete(t) => Type::arrow(t, rhs).into(), TaggedType::Concrete(t) => rhs.make_arrow_lhs(t),
}
}
fn make_arrow_lhs(self, lhs: Type) -> Self {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow_lhs(lhs)))
}
TaggedType::Concrete(t) => Type::arrow(lhs, t).into(),
} }
} }
@ -113,20 +134,85 @@ impl TaggedType {
} }
} }
fn name_used(&self, ident: &str) -> bool { pub fn map_name<F: Fn(&mut String)>(&mut self, f: F) {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
f(ident);
tagged_type.map_name(f);
}
TaggedType::Concrete(t) => t.map_name(&f),
}
}
pub fn type_var(&self) -> Option<String> {
match self {
TaggedType::Tagged(type_tag, _, tagged_type) => tagged_type.type_var(),
TaggedType::Concrete(c) => c.type_var(),
}
}
pub fn free_vars(&self) -> HashSet<String> {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
let mut vars = tagged_type.free_vars();
vars.retain(|v| v != ident);
vars
}
TaggedType::Concrete(c) => c.type_vars(),
}
}
pub fn name_used(&self, ident: &str) -> bool {
match self { match self {
TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident), TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident),
TaggedType::Concrete(c) => c.name_used(ident), TaggedType::Concrete(c) => c.name_used(ident),
} }
} }
pub fn is_concrete(&self) -> bool {
match self {
TaggedType::Tagged(type_tag, _, tagged_type) => tagged_type.is_concrete(),
TaggedType::Concrete(c) => c.is_concrete(),
}
}
pub fn normalise(self) -> TaggedType {
let only_used = self.clear_unused_names();
fn dedup(this: TaggedType, mut used: Rc<HashMap<String, TypeTag>>) -> TaggedType {
match this {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
if !used.contains_key(&ident) {
Rc::make_mut(&mut used).insert(ident, type_tag);
dedup(*tagged_type, used)
} else {
let Some(tag) = TypeTag::tightens(
Rc::make_mut(&mut used).remove(&ident).unwrap(),
type_tag,
) else {
todo!()
};
Rc::make_mut(&mut used).insert(ident.clone(), tag.clone());
dedup(*tagged_type, used)
}
}
TaggedType::Concrete(c) => mem::take(Rc::make_mut(&mut used))
.into_iter()
.fold(TaggedType::Concrete(c), |acc, (i, t)| {
TaggedType::Tagged(t, i, Box::new(acc))
}),
}
}
dedup(only_used, Rc::new(HashMap::new()))
}
fn clear_unused_names(self) -> TaggedType { fn clear_unused_names(self) -> TaggedType {
match self { match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => { TaggedType::Tagged(type_tag, ident, tagged_type) => {
if tagged_type.name_used(&ident) { if tagged_type.name_used(&ident) {
TaggedType::Tagged(type_tag, ident, tagged_type) TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.clear_unused_names()))
} else { } else {
*tagged_type tagged_type.clear_unused_names()
} }
} }
t => t, t => t,
@ -142,7 +228,7 @@ impl From<PrimitiveType> for Type {
impl From<&str> for Type { impl From<&str> for Type {
fn from(value: &str) -> Self { fn from(value: &str) -> Self {
Self::Generic(value.to_string()) Self::TypeVariable(value.to_string())
} }
} }
@ -153,7 +239,7 @@ impl Type {
pub fn is_concrete(&self) -> bool { pub fn is_concrete(&self) -> bool {
match self { match self {
Type::Generic(_) => false, Type::TypeVariable(_) => false,
Type::Primitive(primitive_type) => true, Type::Primitive(primitive_type) => true,
Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(), Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(),
} }
@ -161,7 +247,7 @@ impl Type {
pub fn has_tag(&self, tag: &TypeTag) -> bool { pub fn has_tag(&self, tag: &TypeTag) -> bool {
match self { match self {
Type::Generic(_) => false, Type::TypeVariable(_) => false,
Type::Primitive(primitive_type) => match (primitive_type, tag) { Type::Primitive(primitive_type) => match (primitive_type, tag) {
(_, TypeTag::Any) => true, (_, TypeTag::Any) => true,
(PrimitiveType::Nat, TypeTag::Num) => true, (PrimitiveType::Nat, TypeTag::Num) => true,
@ -171,17 +257,54 @@ impl Type {
} }
} }
fn name_used(&self, ident: &str) -> bool { pub fn map_name<F: Fn(&mut String)>(&mut self, f: &F) {
match self {
Type::TypeVariable(v) => f(v),
Type::Arrow(lhs, rhs) => {
lhs.map_name(f);
rhs.map_name(f);
}
_ => {}
}
}
pub fn name_used(&self, ident: &str) -> bool {
match self { match self {
Type::Generic(i) if *i == ident => true, Type::TypeVariable(i) if *i == ident => true,
Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident), Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident),
_ => false, _ => false,
} }
} }
pub fn type_var(&self) -> Option<String> {
match self {
Type::TypeVariable(v) => Some(v.clone()),
Type::Primitive(primitive_type) => None,
Type::Arrow(_, _) => None,
}
}
pub fn type_vars(&self) -> HashSet<String> {
match self {
Type::TypeVariable(v) => {
let mut set = HashSet::new();
set.insert(v.to_string());
set
}
Type::Primitive(primitive_type) => HashSet::new(),
Type::Arrow(lhs, rhs) => {
let mut vars = lhs.type_vars();
for var in rhs.type_vars().into_iter() {
vars.insert(var);
}
vars
}
}
}
fn specialize(self, ident: &str, typ: &Type) -> Type { fn specialize(self, ident: &str, typ: &Type) -> Type {
match self { match self {
Type::Generic(i) if i == ident => typ.clone(), Type::TypeVariable(i) if i == ident => typ.clone(),
Type::Arrow(lhs, rhs) => Type::Arrow( Type::Arrow(lhs, rhs) => Type::Arrow(
Box::new(lhs.specialize(ident, typ)), Box::new(lhs.specialize(ident, typ)),
Box::new(rhs.specialize(ident, typ)), Box::new(rhs.specialize(ident, typ)),
@ -214,7 +337,7 @@ impl Display for TaggedType {
impl Display for Type { impl Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Type::Generic(i) => write!(f, "{i}"), Type::TypeVariable(i) => write!(f, "{i}"),
Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
} }

@ -29,6 +29,14 @@ impl<K: Eq, V> VecMap<K, V> {
} }
} }
impl<K: Eq, V> Default for VecMap<K, V> {
fn default() -> Self {
Self {
map: Vec::default(),
}
}
}
impl<K: Eq + Clone, V: Clone> Clone for VecMap<K, V> { impl<K: Eq + Clone, V: Clone> Clone for VecMap<K, V> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {

Loading…
Cancel
Save