parent
134f803b88
commit
6a46c6ca52
@ -1,37 +1,75 @@
|
|||||||
use std::{collections::HashMap, error::Error, rc::Rc};
|
use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc};
|
||||||
|
|
||||||
use crate::{Ast, Constant, Ident, PrimitiveType, Type};
|
use crate::{
|
||||||
|
DeBrujinAst, Ident, PrimitiveType, Type,
|
||||||
|
parse::{Ast, Constant},
|
||||||
|
types::{TaggedType, TypeTag},
|
||||||
|
vec_map::VecMap,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum InferError {
|
pub enum InferError {
|
||||||
NotAFunction,
|
NotAFunction,
|
||||||
MismatchedType,
|
MismatchedType,
|
||||||
NotInContext,
|
NotInContext,
|
||||||
|
ExpectedConreteType,
|
||||||
|
ExpectedTypeWithTag,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> {
|
pub fn infer_type(
|
||||||
|
gamma: &HashMap<Ident, TaggedType>,
|
||||||
|
ast: DeBrujinAst,
|
||||||
|
) -> Result<TaggedType, InferError> {
|
||||||
|
infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_type_debrujin_int(
|
||||||
|
gamma_free: &HashMap<Ident, TaggedType>,
|
||||||
|
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
|
||||||
|
ast: DeBrujinAst,
|
||||||
|
) -> Result<TaggedType, InferError> {
|
||||||
match ast {
|
match ast {
|
||||||
Ast::Abstraction(arg, arg_type, ast) => {
|
DeBrujinAst::Abstraction(_, arg_type, ast) => {
|
||||||
Rc::make_mut(&mut gamma).insert(arg, arg_type.clone());
|
let gamma_ref = Rc::make_mut(&mut gamma_bound);
|
||||||
let out_type = infer_type(gamma, *ast)?;
|
gamma_ref.map_keys(|i| *i += 1);
|
||||||
Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type)))
|
gamma_ref.insert(1, arg_type.clone().into());
|
||||||
|
|
||||||
|
let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?;
|
||||||
|
let typ = Type::Arrow(
|
||||||
|
Box::new(arg_type),
|
||||||
|
Box::new(
|
||||||
|
out_type
|
||||||
|
.to_concrete()
|
||||||
|
.ok_or(InferError::ExpectedConreteType)?,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
Ok(typ.into())
|
||||||
}
|
}
|
||||||
Ast::Application(left, right) => {
|
DeBrujinAst::Application(lhs, rhs) => {
|
||||||
let left_type = infer_type(gamma.clone(), *left)?;
|
let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?;
|
||||||
let Type::Arrow(in_type, out_type) = left_type else {
|
let Some((in_type, out_type)) = left_type.arrow() else {
|
||||||
return Err(InferError::NotAFunction);
|
return Err(InferError::NotAFunction);
|
||||||
};
|
};
|
||||||
let right_type = infer_type(gamma, *right)?;
|
let Some(right_type) =
|
||||||
if *in_type != right_type {
|
infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete()
|
||||||
return Err(InferError::MismatchedType);
|
else {
|
||||||
}
|
return Err(InferError::ExpectedConreteType);
|
||||||
Ok(*out_type)
|
};
|
||||||
|
let out_type = match in_type.matches(&right_type) {
|
||||||
|
(true, None) => out_type,
|
||||||
|
(true, Some(a)) => out_type.specialize(&a, &right_type),
|
||||||
|
(false, _) => return Err(InferError::MismatchedType),
|
||||||
|
};
|
||||||
|
Ok(out_type)
|
||||||
}
|
}
|
||||||
Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext),
|
DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext),
|
||||||
Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)),
|
// compiler bug if not present
|
||||||
Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)),
|
DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()),
|
||||||
|
DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()),
|
||||||
|
DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()),
|
||||||
|
DeBrujinAst::Constant(_) => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,49 +1,73 @@
|
|||||||
use std::{collections::HashMap, rc::Rc};
|
use std::{collections::HashMap, rc::Rc};
|
||||||
|
|
||||||
use crate::{Ast, Constant, PrimitiveType, Type};
|
use crate::{
|
||||||
|
DeBrujinAst, PrimitiveType, Type,
|
||||||
use super::infer_type;
|
inference::infer_type,
|
||||||
|
parse::{Ast, Constant},
|
||||||
|
types::{TaggedType, TypeTag},
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn infer_id_type() {
|
fn infer_add_nat() {
|
||||||
let ast = Ast::Abstraction(
|
let ast: DeBrujinAst = Ast::Application(
|
||||||
"x".to_string(),
|
Box::new(Ast::Application(
|
||||||
Type::Primitive(PrimitiveType::Nat),
|
Box::new(Ast::Variable("add".to_string())),
|
||||||
Box::new(Ast::Variable("x".to_string())),
|
Box::new(Ast::Constant(Constant::Nat(5))),
|
||||||
|
)),
|
||||||
|
Box::new(Ast::Constant(Constant::Nat(7))),
|
||||||
|
)
|
||||||
|
.into();
|
||||||
|
|
||||||
|
let mut gamma = HashMap::new();
|
||||||
|
gamma.insert(
|
||||||
|
"add".to_string(),
|
||||||
|
crate::types::TaggedType::Tagged(
|
||||||
|
TypeTag::Num,
|
||||||
|
"a".to_string(),
|
||||||
|
Box::new(TaggedType::Concrete(Type::Arrow(
|
||||||
|
Box::new(Type::Generic("a".to_string())),
|
||||||
|
Box::new(Type::Arrow(
|
||||||
|
Box::new(Type::Generic("a".to_string())),
|
||||||
|
Box::new(Type::Generic("a".to_string())),
|
||||||
|
)),
|
||||||
|
))),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap();
|
let infered = infer_type(&gamma, ast).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
infered,
|
infered,
|
||||||
Type::Arrow(
|
TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat))
|
||||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
);
|
||||||
Box::new(Type::Primitive(PrimitiveType::Nat))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn infer_addition_result_type() {
|
fn infer_add_nat_partial() {
|
||||||
let ast = Ast::Application(
|
let ast: DeBrujinAst = Ast::Application(
|
||||||
Box::new(Ast::Application(
|
|
||||||
Box::new(Ast::Variable("add".to_string())),
|
Box::new(Ast::Variable("add".to_string())),
|
||||||
Box::new(Ast::Constant(Constant::Nat(5))),
|
Box::new(Ast::Constant(Constant::Nat(5))),
|
||||||
)),
|
)
|
||||||
Box::new(Ast::Constant(Constant::Nat(7))),
|
.into();
|
||||||
);
|
|
||||||
|
|
||||||
let mut gamma = HashMap::new();
|
let mut gamma = HashMap::new();
|
||||||
gamma.insert(
|
gamma.insert(
|
||||||
"add".to_string(),
|
"add".to_string(),
|
||||||
Type::Arrow(
|
crate::types::TaggedType::Tagged(
|
||||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
TypeTag::Num,
|
||||||
|
"a".to_string(),
|
||||||
|
Box::new(TaggedType::Concrete(Type::Arrow(
|
||||||
|
Box::new(Type::Generic("a".to_string())),
|
||||||
Box::new(Type::Arrow(
|
Box::new(Type::Arrow(
|
||||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
Box::new(Type::Generic("a".to_string())),
|
||||||
Box::new(Type::Primitive(PrimitiveType::Nat)),
|
Box::new(Type::Generic("a".to_string())),
|
||||||
)),
|
)),
|
||||||
|
))),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
let infered = infer_type(Rc::new(gamma), ast).unwrap();
|
let infered = infer_type(&gamma, ast).unwrap();
|
||||||
assert_eq!(infered, Type::Primitive(PrimitiveType::Nat));
|
assert_eq!(
|
||||||
|
infered,
|
||||||
|
TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,211 @@
|
|||||||
|
use std::{fmt::Display, str};
|
||||||
|
|
||||||
|
pub type Ident = String;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PrimitiveType {
|
||||||
|
Nat,
|
||||||
|
Bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum TypeTag {
|
||||||
|
Num,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum TaggedType {
|
||||||
|
Tagged(TypeTag, Ident, Box<TaggedType>),
|
||||||
|
Concrete(Type),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum Type {
|
||||||
|
Generic(Ident), // a
|
||||||
|
Primitive(PrimitiveType), // Bool | Nat
|
||||||
|
Arrow(Box<Type>, Box<Type>), // 0 -> 1
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Type> for TaggedType {
|
||||||
|
fn from(value: Type) -> Self {
|
||||||
|
Self::Concrete(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TaggedType {
|
||||||
|
pub fn r#type(&self) -> &Type {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(_, _, tagged_type) => tagged_type.r#type(),
|
||||||
|
TaggedType::Concrete(t) => t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn type_mut(&mut self) -> &mut Type {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_mut(),
|
||||||
|
TaggedType::Concrete(t) => t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_concrete(self) -> Option<Type> {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(type_tag, _, tagged_type) => None,
|
||||||
|
TaggedType::Concrete(t) => {
|
||||||
|
if t.is_concrete() {
|
||||||
|
Some(t)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn matches(&self, typ: &Type) -> (bool, Option<Ident>) {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||||
|
if typ.has_tag(type_tag) {
|
||||||
|
(true, Some(ident.clone()))
|
||||||
|
} else {
|
||||||
|
tagged_type.matches(typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TaggedType::Concrete(c) => (c == typ, None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn arrow(self) -> Option<(TaggedType, TaggedType)> {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||||
|
let (lhs, rhs) = tagged_type.arrow()?;
|
||||||
|
Some((
|
||||||
|
TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs))
|
||||||
|
.clear_unused_names(),
|
||||||
|
TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
TaggedType::Concrete(Type::Arrow(lhs, rhs)) => {
|
||||||
|
Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs)))
|
||||||
|
}
|
||||||
|
TaggedType::Concrete(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(_, i, tagged_type) if i == ident => {
|
||||||
|
tagged_type.specialize(ident, typ)
|
||||||
|
}
|
||||||
|
TaggedType::Tagged(t, i, tagged_type) => {
|
||||||
|
TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ)))
|
||||||
|
}
|
||||||
|
TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name_used(&self, ident: &str) -> bool {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident),
|
||||||
|
TaggedType::Concrete(c) => c.name_used(ident),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_unused_names(self) -> TaggedType {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||||
|
if tagged_type.name_used(&ident) {
|
||||||
|
TaggedType::Tagged(type_tag, ident, tagged_type)
|
||||||
|
} else {
|
||||||
|
*tagged_type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t => t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PrimitiveType> for Type {
|
||||||
|
fn from(value: PrimitiveType) -> Self {
|
||||||
|
Self::Primitive(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Type {
|
||||||
|
pub fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
|
||||||
|
Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_concrete(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Type::Generic(_) => false,
|
||||||
|
Type::Primitive(primitive_type) => true,
|
||||||
|
Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_tag(&self, tag: &TypeTag) -> bool {
|
||||||
|
match self {
|
||||||
|
Type::Generic(_) => false,
|
||||||
|
Type::Primitive(primitive_type) => match (primitive_type, tag) {
|
||||||
|
(PrimitiveType::Nat, TypeTag::Num) => true,
|
||||||
|
_ => false,
|
||||||
|
},
|
||||||
|
Type::Arrow(_, _) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name_used(&self, ident: &str) -> bool {
|
||||||
|
match self {
|
||||||
|
Type::Generic(i) if *i == ident => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn specialize(self, ident: &str, typ: &Type) -> Type {
|
||||||
|
match self {
|
||||||
|
Type::Generic(i) if i == ident => typ.clone(),
|
||||||
|
Type::Arrow(lhs, rhs) => Type::Arrow(
|
||||||
|
Box::new(lhs.specialize(ident, typ)),
|
||||||
|
Box::new(rhs.specialize(ident, typ)),
|
||||||
|
),
|
||||||
|
t => t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for TypeTag {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TypeTag::Num => write!(f, "Num"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for TaggedType {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
TaggedType::Tagged(type_tag, ident, tagged_type) => {
|
||||||
|
write!(f, "{type_tag} {ident} => {tagged_type}")
|
||||||
|
}
|
||||||
|
TaggedType::Concrete(t) => write!(f, "{t}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for Type {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Type::Generic(i) => write!(f, "{i}"),
|
||||||
|
Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
|
||||||
|
Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for PrimitiveType {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
PrimitiveType::Nat => write!(f, "Nat"),
|
||||||
|
PrimitiveType::Bool => write!(f, "Bool"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
use std::mem;
|
||||||
|
|
||||||
|
pub struct VecMap<K: Eq, V> {
|
||||||
|
map: Vec<(K, V)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K: Eq, V> VecMap<K, V> {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self { map: Vec::new() }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&mut self, key: K, val: V) -> Option<(K, V)> {
|
||||||
|
if let Some(entry) = self.map.iter_mut().find(|(k, _)| *k == key) {
|
||||||
|
Some(mem::replace(entry, (key, val)))
|
||||||
|
} else {
|
||||||
|
self.map.push((key, val));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, key: &K) -> Option<&V> {
|
||||||
|
self.map
|
||||||
|
.iter()
|
||||||
|
.find_map(|(k, v)| if k == key { Some(v) } else { None })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) {
|
||||||
|
self.map.iter_mut().for_each(|(k, _)| f(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<K: Eq + Clone, V: Clone> Clone for VecMap<K, V> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
map: self.map.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue