parent
134f803b88
commit
6a46c6ca52
@ -1,37 +1,75 @@
|
||||
use std::{collections::HashMap, error::Error, rc::Rc};
|
||||
use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc};
|
||||
|
||||
use crate::{Ast, Constant, Ident, PrimitiveType, Type};
|
||||
use crate::{
|
||||
DeBrujinAst, Ident, PrimitiveType, Type,
|
||||
parse::{Ast, Constant},
|
||||
types::{TaggedType, TypeTag},
|
||||
vec_map::VecMap,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum InferError {
|
||||
NotAFunction,
|
||||
MismatchedType,
|
||||
NotInContext,
|
||||
ExpectedConreteType,
|
||||
ExpectedTypeWithTag,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> {
|
||||
pub fn infer_type(
|
||||
gamma: &HashMap<Ident, TaggedType>,
|
||||
ast: DeBrujinAst,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast)
|
||||
}
|
||||
|
||||
fn infer_type_debrujin_int(
|
||||
gamma_free: &HashMap<Ident, TaggedType>,
|
||||
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
|
||||
ast: DeBrujinAst,
|
||||
) -> Result<TaggedType, InferError> {
|
||||
match ast {
|
||||
Ast::Abstraction(arg, arg_type, ast) => {
|
||||
Rc::make_mut(&mut gamma).insert(arg, arg_type.clone());
|
||||
let out_type = infer_type(gamma, *ast)?;
|
||||
Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type)))
|
||||
DeBrujinAst::Abstraction(_, arg_type, ast) => {
|
||||
let gamma_ref = Rc::make_mut(&mut gamma_bound);
|
||||
gamma_ref.map_keys(|i| *i += 1);
|
||||
gamma_ref.insert(1, arg_type.clone().into());
|
||||
|
||||
let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?;
|
||||
let typ = Type::Arrow(
|
||||
Box::new(arg_type),
|
||||
Box::new(
|
||||
out_type
|
||||
.to_concrete()
|
||||
.ok_or(InferError::ExpectedConreteType)?,
|
||||
),
|
||||
);
|
||||
Ok(typ.into())
|
||||
}
|
||||
Ast::Application(left, right) => {
|
||||
let left_type = infer_type(gamma.clone(), *left)?;
|
||||
let Type::Arrow(in_type, out_type) = left_type else {
|
||||
DeBrujinAst::Application(lhs, rhs) => {
|
||||
let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?;
|
||||
let Some((in_type, out_type)) = left_type.arrow() else {
|
||||
return Err(InferError::NotAFunction);
|
||||
};
|
||||
let right_type = infer_type(gamma, *right)?;
|
||||
if *in_type != right_type {
|
||||
return Err(InferError::MismatchedType);
|
||||
}
|
||||
Ok(*out_type)
|
||||
let Some(right_type) =
|
||||
infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete()
|
||||
else {
|
||||
return Err(InferError::ExpectedConreteType);
|
||||
};
|
||||
let out_type = match in_type.matches(&right_type) {
|
||||
(true, None) => out_type,
|
||||
(true, Some(a)) => out_type.specialize(&a, &right_type),
|
||||
(false, _) => return Err(InferError::MismatchedType),
|
||||
};
|
||||
Ok(out_type)
|
||||
}
|
||||
Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext),
|
||||
Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)),
|
||||
Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)),
|
||||
DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext),
|
||||
// compiler bug if not present
|
||||
DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()),
|
||||
DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()),
|
||||
DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()),
|
||||
DeBrujinAst::Constant(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -1,49 +1,73 @@
|
||||
use std::{collections::HashMap, rc::Rc};
|
||||
|
||||
use crate::{Ast, Constant, PrimitiveType, Type};
|
||||
|
||||
use super::infer_type;
|
||||
use crate::{
|
||||
DeBrujinAst, PrimitiveType, Type,
|
||||
inference::infer_type,
|
||||
parse::{Ast, Constant},
|
||||
types::{TaggedType, TypeTag},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn infer_id_type() {
|
||||
let ast = Ast::Abstraction(
|
||||
"x".to_string(),
|
||||
Type::Primitive(PrimitiveType::Nat),
|
||||
Box::new(Ast::Variable("x".to_string())),
|
||||
fn infer_add_nat() {
|
||||
let ast: DeBrujinAst = Ast::Application(
|
||||
Box::new(Ast::Application(
|
||||
Box::new(Ast::Variable("add".to_string())),
|
||||
Box::new(Ast::Constant(Constant::Nat(5))),
|
||||
)),
|
||||
Box::new(Ast::Constant(Constant::Nat(7))),
|
||||
)
|
||||
.into();
|
||||
|
||||
let mut gamma = HashMap::new();
|
||||
gamma.insert(
|
||||
"add".to_string(),
|
||||
crate::types::TaggedType::Tagged(
|
||||
TypeTag::Num,
|
||||
"a".to_string(),
|
||||
Box::new(TaggedType::Concrete(Type::Arrow(
|
||||
Box::new(Type::Generic("a".to_string())),
|
||||
Box::new(Type::Arrow(
|
||||
Box::new(Type::Generic("a".to_string())),
|
||||
Box::new(Type::Generic("a".to_string())),
|
||||
)),
|
||||
))),
|
||||
),
|
||||
);
|
||||
|
||||
let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap();
|
||||
let infered = infer_type(&gamma, ast).unwrap();
|
||||
assert_eq!(
|
||||
infered,
|
||||
Type::Arrow(
|
||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
||||
Box::new(Type::Primitive(PrimitiveType::Nat))
|
||||
)
|
||||
)
|
||||
TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_addition_result_type() {
|
||||
let ast = Ast::Application(
|
||||
Box::new(Ast::Application(
|
||||
Box::new(Ast::Variable("add".to_string())),
|
||||
Box::new(Ast::Constant(Constant::Nat(5))),
|
||||
)),
|
||||
Box::new(Ast::Constant(Constant::Nat(7))),
|
||||
);
|
||||
fn infer_add_nat_partial() {
|
||||
let ast: DeBrujinAst = Ast::Application(
|
||||
Box::new(Ast::Variable("add".to_string())),
|
||||
Box::new(Ast::Constant(Constant::Nat(5))),
|
||||
)
|
||||
.into();
|
||||
|
||||
let mut gamma = HashMap::new();
|
||||
gamma.insert(
|
||||
"add".to_string(),
|
||||
Type::Arrow(
|
||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
||||
Box::new(Type::Arrow(
|
||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
||||
)),
|
||||
crate::types::TaggedType::Tagged(
|
||||
TypeTag::Num,
|
||||
"a".to_string(),
|
||||
Box::new(TaggedType::Concrete(Type::Arrow(
|
||||
Box::new(Type::Generic("a".to_string())),
|
||||
Box::new(Type::Arrow(
|
||||
Box::new(Type::Generic("a".to_string())),
|
||||
Box::new(Type::Generic("a".to_string())),
|
||||
)),
|
||||
))),
|
||||
),
|
||||
);
|
||||
|
||||
let infered = infer_type(Rc::new(gamma), ast).unwrap();
|
||||
assert_eq!(infered, Type::Primitive(PrimitiveType::Nat));
|
||||
let infered = infer_type(&gamma, ast).unwrap();
|
||||
assert_eq!(
|
||||
infered,
|
||||
TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
|
||||
);
|
||||
}
|
||||
|
@ -0,0 +1,211 @@
|
||||
use std::{fmt::Display, str};
|
||||
|
||||
pub type Ident = String;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PrimitiveType {
|
||||
Nat,
|
||||
Bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TypeTag {
|
||||
Num,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TaggedType {
|
||||
Tagged(TypeTag, Ident, Box<TaggedType>),
|
||||
Concrete(Type),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Type {
|
||||
Generic(Ident), // a
|
||||
Primitive(PrimitiveType), // Bool | Nat
|
||||
Arrow(Box<Type>, Box<Type>), // 0 -> 1
|
||||
}
|
||||
|
||||
impl From<Type> for TaggedType {
|
||||
fn from(value: Type) -> Self {
|
||||
Self::Concrete(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl TaggedType {
|
||||
pub fn r#type(&self) -> &Type {
|
||||
match self {
|
||||
TaggedType::Tagged(_, _, tagged_type) => tagged_type.r#type(),
|
||||
TaggedType::Concrete(t) => t,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn type_mut(&mut self) -> &mut Type {
|
||||
match self {
|
||||
TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_mut(),
|
||||
TaggedType::Concrete(t) => t,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_concrete(self) -> Option<Type> {
|
||||
match self {
|
||||
TaggedType::Tagged(type_tag, _, tagged_type) => None,
|
||||
TaggedType::Concrete(t) => {
|
||||
if t.is_concrete() {
|
||||
Some(t)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn matches(&self, typ: &Type) -> (bool, Option<Ident>) {
|
||||
match self {
|
||||
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||
if typ.has_tag(type_tag) {
|
||||
(true, Some(ident.clone()))
|
||||
} else {
|
||||
tagged_type.matches(typ)
|
||||
}
|
||||
}
|
||||
TaggedType::Concrete(c) => (c == typ, None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn arrow(self) -> Option<(TaggedType, TaggedType)> {
|
||||
match self {
|
||||
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||
let (lhs, rhs) = tagged_type.arrow()?;
|
||||
Some((
|
||||
TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs))
|
||||
.clear_unused_names(),
|
||||
TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(),
|
||||
))
|
||||
}
|
||||
TaggedType::Concrete(Type::Arrow(lhs, rhs)) => {
|
||||
Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs)))
|
||||
}
|
||||
TaggedType::Concrete(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType {
|
||||
match self {
|
||||
TaggedType::Tagged(_, i, tagged_type) if i == ident => {
|
||||
tagged_type.specialize(ident, typ)
|
||||
}
|
||||
TaggedType::Tagged(t, i, tagged_type) => {
|
||||
TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ)))
|
||||
}
|
||||
TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)),
|
||||
}
|
||||
}
|
||||
|
||||
fn name_used(&self, ident: &str) -> bool {
|
||||
match self {
|
||||
TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident),
|
||||
TaggedType::Concrete(c) => c.name_used(ident),
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_unused_names(self) -> TaggedType {
|
||||
match self {
|
||||
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||
if tagged_type.name_used(&ident) {
|
||||
TaggedType::Tagged(type_tag, ident, tagged_type)
|
||||
} else {
|
||||
*tagged_type
|
||||
}
|
||||
}
|
||||
t => t,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PrimitiveType> for Type {
|
||||
fn from(value: PrimitiveType) -> Self {
|
||||
Self::Primitive(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
|
||||
Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
|
||||
}
|
||||
|
||||
pub fn is_concrete(&self) -> bool {
|
||||
match self {
|
||||
Type::Generic(_) => false,
|
||||
Type::Primitive(primitive_type) => true,
|
||||
Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_tag(&self, tag: &TypeTag) -> bool {
|
||||
match self {
|
||||
Type::Generic(_) => false,
|
||||
Type::Primitive(primitive_type) => match (primitive_type, tag) {
|
||||
(PrimitiveType::Nat, TypeTag::Num) => true,
|
||||
_ => false,
|
||||
},
|
||||
Type::Arrow(_, _) => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn name_used(&self, ident: &str) -> bool {
|
||||
match self {
|
||||
Type::Generic(i) if *i == ident => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn specialize(self, ident: &str, typ: &Type) -> Type {
|
||||
match self {
|
||||
Type::Generic(i) if i == ident => typ.clone(),
|
||||
Type::Arrow(lhs, rhs) => Type::Arrow(
|
||||
Box::new(lhs.specialize(ident, typ)),
|
||||
Box::new(rhs.specialize(ident, typ)),
|
||||
),
|
||||
t => t,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TypeTag {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TypeTag::Num => write!(f, "Num"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TaggedType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||
write!(f, "{type_tag} {ident} => {tagged_type}")
|
||||
}
|
||||
TaggedType::Concrete(t) => write!(f, "{t}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Type {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Type::Generic(i) => write!(f, "{i}"),
|
||||
Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
|
||||
Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PrimitiveType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
PrimitiveType::Nat => write!(f, "Nat"),
|
||||
PrimitiveType::Bool => write!(f, "Bool"),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
use std::mem;
|
||||
|
||||
pub struct VecMap<K: Eq, V> {
|
||||
map: Vec<(K, V)>,
|
||||
}
|
||||
|
||||
impl<K: Eq, V> VecMap<K, V> {
|
||||
pub fn new() -> Self {
|
||||
Self { map: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: K, val: V) -> Option<(K, V)> {
|
||||
if let Some(entry) = self.map.iter_mut().find(|(k, _)| *k == key) {
|
||||
Some(mem::replace(entry, (key, val)))
|
||||
} else {
|
||||
self.map.push((key, val));
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, key: &K) -> Option<&V> {
|
||||
self.map
|
||||
.iter()
|
||||
.find_map(|(k, v)| if k == key { Some(v) } else { None })
|
||||
}
|
||||
|
||||
pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) {
|
||||
self.map.iter_mut().for_each(|(k, _)| f(k))
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Eq + Clone, V: Clone> Clone for VecMap<K, V> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
map: self.map.clone(),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in new issue