From 6a46c6ca52fefec790a06fb201d230e4116179d2 Mon Sep 17 00:00:00 2001 From: Avery Date: Fri, 14 Mar 2025 21:37:30 +0100 Subject: [PATCH] Type tags, fix deBrujin beta reduction, DeBrujinAst -> Ast, type inference on DeBrujinAst --- src/exec/mod.rs | 88 ++++++++++-------- src/exec/test.rs | 63 +++++++------ src/inference/mod.rs | 74 +++++++++++---- src/inference/test.rs | 84 +++++++++++------ src/lib.rs | 84 +---------------- src/main.rs | 14 ++- src/parse/mod.rs | 38 +++++++- src/parse/test.rs | 7 +- src/types/mod.rs | 211 ++++++++++++++++++++++++++++++++++++++++++ src/vec_map.rs | 38 ++++++++ 10 files changed, 498 insertions(+), 203 deletions(-) create mode 100644 src/types/mod.rs create mode 100644 src/vec_map.rs diff --git a/src/exec/mod.rs b/src/exec/mod.rs index 3ef9bc7..2e30ec8 100644 --- a/src/exec/mod.rs +++ b/src/exec/mod.rs @@ -3,52 +3,27 @@ mod test; use std::{collections::HashMap, rc::Rc}; -use crate::{Ast, Constant, Ident, Type}; +use crate::{ + Ident, Type, + parse::{Ast, Constant}, + vec_map::VecMap, +}; -impl Ast { - pub fn beta_reduce(self) -> Ast { - match self { - Ast::Application(lhs, rhs) => match *lhs { - Ast::Abstraction(var, _, ast) => ast.subst(var, *rhs).beta_reduce(), - lhs => Ast::Application(Box::new(lhs), rhs), - }, - t => t, - } - } - - fn subst(self, var: Ident, subst: Ast) -> Ast { - match self { - Ast::Abstraction(var1, typ, ast) => { - if var != var1 { - Ast::Abstraction(var1, typ, Box::new(ast.subst(var, subst))) - } else { - Ast::Abstraction(var1, typ, ast) - } - } - Ast::Application(lhs, rhs) => Ast::Application( - Box::new(lhs.subst(var.clone(), subst.clone())), - Box::new(rhs.subst(var, subst)), - ), - Ast::Variable(v) if v == var => subst, - Ast::Variable(v) => Ast::Variable(v), - Ast::Constant(constant) => Ast::Constant(constant), - } +impl Into for Ast { + fn into(self) -> DeBrujinAst { + self.to_de_brujin_inter(Rc::new(HashMap::new())) } } impl Ast { - pub fn to_de_brujin(self) -> DeBrujinAst { - self.to_de_brujin_inter(Rc::new(HashMap::new())) - } - fn to_de_brujin_inter(self, mut gamma: Rc>) -> DeBrujinAst { match self { - Ast::Abstraction(var, _, ast) => { + Ast::Abstraction(var, t, ast) => { let gamma_ref = Rc::make_mut(&mut gamma); gamma_ref.values_mut().for_each(|v| *v += 1); - gamma_ref.insert(var, 1); + gamma_ref.insert(var.clone(), 1); - DeBrujinAst::Abstraction(Box::new(ast.to_de_brujin_inter(gamma))) + DeBrujinAst::Abstraction(var, t, Box::new(ast.to_de_brujin_inter(gamma))) } Ast::Application(lhs, rhs) => DeBrujinAst::Application( Box::new(lhs.to_de_brujin_inter(gamma.clone())), @@ -68,18 +43,51 @@ impl Ast { #[derive(Debug, Clone, PartialEq, Eq)] pub enum DeBrujinAst { - Abstraction(Box), // \:1.2 + Abstraction(Ident, Type, Box), // \:1.2 Application(Box, Box), // 0 1 FreeVariable(String), // x BoundVariable(usize), // 1 Constant(Constant), // true | false | n } +impl Into for DeBrujinAst { + fn into(self) -> Ast { + self.to_ast(Rc::new(VecMap::new())) + } +} + impl DeBrujinAst { + // Vec<(usize, String)> as opposed to a HashMap here since we need to mutate + // all the keys every time recuse into an Abstraction, which would be suuper + // expensive if not impossible with a HashMap + fn to_ast(self, mut gamma: Rc>) -> Ast { + match self { + DeBrujinAst::Abstraction(i, t, de_brujin_ast) => { + let gamma_ref = Rc::make_mut(&mut gamma); + gamma_ref.map_keys(|i| *i += 1); + gamma_ref.insert(1, i.clone()); + + Ast::Abstraction(i, t, Box::new(de_brujin_ast.to_ast(gamma))) + } + DeBrujinAst::Application(lhs, rhs) => Ast::Application( + Box::new(lhs.to_ast(gamma.clone())), + Box::new(rhs.to_ast(gamma)), + ), + DeBrujinAst::FreeVariable(i) => Ast::Variable(i), + DeBrujinAst::BoundVariable(n) => Ast::Variable( + gamma + .get(&n) + .unwrap() // Compiler bug if panics + .clone(), + ), + DeBrujinAst::Constant(constant) => Ast::Constant(constant), + } + } + pub fn beta_reduce(self) -> DeBrujinAst { match self { DeBrujinAst::Application(lhs, rhs) => match *lhs { - DeBrujinAst::Abstraction(ast) => ast.subst_bound(1, *rhs).beta_reduce(), + DeBrujinAst::Abstraction(_, _, ast) => ast.subst_bound(1, *rhs).beta_reduce(), lhs => DeBrujinAst::Application(Box::new(lhs), rhs), }, a => a, @@ -88,7 +96,9 @@ impl DeBrujinAst { fn subst_bound(self, depth: usize, subst: DeBrujinAst) -> DeBrujinAst { match self { - DeBrujinAst::Abstraction(ast) => ast.subst_bound(depth + 1, subst), + DeBrujinAst::Abstraction(i, t, ast) => { + DeBrujinAst::Abstraction(i, t, Box::new(ast.subst_bound(depth + 1, subst))) + } DeBrujinAst::Application(lhs, rhs) => DeBrujinAst::Application( Box::new(lhs.subst_bound(depth, subst.clone())), Box::new(rhs.subst_bound(depth, subst)), diff --git a/src/exec/test.rs b/src/exec/test.rs index bfd8256..f638bbf 100644 --- a/src/exec/test.rs +++ b/src/exec/test.rs @@ -1,27 +1,8 @@ -use crate::{Ast, Constant, PrimitiveType, Type, exec::DeBrujinAst as DBAst}; - -#[test] -fn beta_reduce() { - let input = Ast::Application( - Box::new(Ast::Abstraction( - "x".to_string(), - Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), - Box::new(Ast::Application( - Box::new(Ast::Variable("x".to_string())), - Box::new(Ast::Constant(Constant::Nat(5))), - )), - )), - Box::new(Ast::Variable("y".to_string())), - ); - let reduced = input.beta_reduce(); - assert_eq!( - reduced, - Ast::Application( - Box::new(Ast::Variable("y".to_string())), - Box::new(Ast::Constant(Constant::Nat(5))), - ), - ) -} +use crate::{ + PrimitiveType, Type, + exec::DeBrujinAst as DBAst, + parse::{Ast, Constant}, +}; #[test] fn to_de_brujin_ast_simple() { @@ -34,12 +15,18 @@ fn to_de_brujin_ast_simple() { Box::new(Ast::Variable("x".to_string())), )), ); - let de_brujin = input.to_de_brujin(); + let de_brujin: DBAst = input.into(); assert_eq!( de_brujin, - DBAst::Abstraction(Box::new(DBAst::Abstraction(Box::new( - DBAst::BoundVariable(1) - )))) + DBAst::Abstraction( + "x".to_string(), + PrimitiveType::Nat.into(), + Box::new(DBAst::Abstraction( + "x".to_string(), + PrimitiveType::Nat.into(), + Box::new(DBAst::BoundVariable(1)) + )) + ) ) } @@ -56,7 +43,7 @@ fn de_brujin_beta_reduce() { )), Box::new(Ast::Variable("y".to_string())), ); - let dbast = input.to_de_brujin(); + let dbast: DBAst = input.into(); let reduced = dbast.beta_reduce(); assert_eq!( reduced, @@ -66,3 +53,21 @@ fn de_brujin_beta_reduce() { ), ) } + +#[test] +fn to_and_from_de_brujin_is_id() { + let input = Ast::Application( + Box::new(Ast::Abstraction( + "x".to_string(), + Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat), + Box::new(Ast::Application( + Box::new(Ast::Variable("x".to_string())), + Box::new(Ast::Constant(Constant::Nat(5))), + )), + )), + Box::new(Ast::Variable("y".to_string())), + ); + let dbast: DBAst = input.clone().into(); + let output: Ast = dbast.into(); + assert_eq!(input, output); +} diff --git a/src/inference/mod.rs b/src/inference/mod.rs index 37c4fdf..204c348 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -1,37 +1,75 @@ -use std::{collections::HashMap, error::Error, rc::Rc}; +use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc}; -use crate::{Ast, Constant, Ident, PrimitiveType, Type}; +use crate::{ + DeBrujinAst, Ident, PrimitiveType, Type, + parse::{Ast, Constant}, + types::{TaggedType, TypeTag}, + vec_map::VecMap, +}; #[derive(Debug)] pub enum InferError { NotAFunction, MismatchedType, NotInContext, + ExpectedConreteType, + ExpectedTypeWithTag, } #[cfg(test)] mod test; -pub fn infer_type(mut gamma: Rc>, ast: Ast) -> Result { +pub fn infer_type( + gamma: &HashMap, + ast: DeBrujinAst, +) -> Result { + infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast) +} + +fn infer_type_debrujin_int( + gamma_free: &HashMap, + mut gamma_bound: Rc>, + ast: DeBrujinAst, +) -> Result { match ast { - Ast::Abstraction(arg, arg_type, ast) => { - Rc::make_mut(&mut gamma).insert(arg, arg_type.clone()); - let out_type = infer_type(gamma, *ast)?; - Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type))) + DeBrujinAst::Abstraction(_, arg_type, ast) => { + let gamma_ref = Rc::make_mut(&mut gamma_bound); + gamma_ref.map_keys(|i| *i += 1); + gamma_ref.insert(1, arg_type.clone().into()); + + let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?; + let typ = Type::Arrow( + Box::new(arg_type), + Box::new( + out_type + .to_concrete() + .ok_or(InferError::ExpectedConreteType)?, + ), + ); + Ok(typ.into()) } - Ast::Application(left, right) => { - let left_type = infer_type(gamma.clone(), *left)?; - let Type::Arrow(in_type, out_type) = left_type else { + DeBrujinAst::Application(lhs, rhs) => { + let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?; + let Some((in_type, out_type)) = left_type.arrow() else { return Err(InferError::NotAFunction); }; - let right_type = infer_type(gamma, *right)?; - if *in_type != right_type { - return Err(InferError::MismatchedType); - } - Ok(*out_type) + let Some(right_type) = + infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete() + else { + return Err(InferError::ExpectedConreteType); + }; + let out_type = match in_type.matches(&right_type) { + (true, None) => out_type, + (true, Some(a)) => out_type.specialize(&a, &right_type), + (false, _) => return Err(InferError::MismatchedType), + }; + Ok(out_type) } - Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext), - Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)), - Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)), + DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext), + // compiler bug if not present + DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()), + DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()), + DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()), + DeBrujinAst::Constant(_) => unreachable!(), } } diff --git a/src/inference/test.rs b/src/inference/test.rs index 37fe4eb..ad50c5e 100644 --- a/src/inference/test.rs +++ b/src/inference/test.rs @@ -1,49 +1,73 @@ use std::{collections::HashMap, rc::Rc}; -use crate::{Ast, Constant, PrimitiveType, Type}; - -use super::infer_type; +use crate::{ + DeBrujinAst, PrimitiveType, Type, + inference::infer_type, + parse::{Ast, Constant}, + types::{TaggedType, TypeTag}, +}; #[test] -fn infer_id_type() { - let ast = Ast::Abstraction( - "x".to_string(), - Type::Primitive(PrimitiveType::Nat), - Box::new(Ast::Variable("x".to_string())), +fn infer_add_nat() { + let ast: DeBrujinAst = Ast::Application( + Box::new(Ast::Application( + Box::new(Ast::Variable("add".to_string())), + Box::new(Ast::Constant(Constant::Nat(5))), + )), + Box::new(Ast::Constant(Constant::Nat(7))), + ) + .into(); + + let mut gamma = HashMap::new(); + gamma.insert( + "add".to_string(), + crate::types::TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::Arrow( + Box::new(Type::Generic("a".to_string())), + Box::new(Type::Arrow( + Box::new(Type::Generic("a".to_string())), + Box::new(Type::Generic("a".to_string())), + )), + ))), + ), ); - let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap(); + let infered = infer_type(&gamma, ast).unwrap(); assert_eq!( infered, - Type::Arrow( - Box::new(Type::Primitive(PrimitiveType::Nat)), - Box::new(Type::Primitive(PrimitiveType::Nat)) - ) - ) + TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat)) + ); } #[test] -fn infer_addition_result_type() { - let ast = Ast::Application( - Box::new(Ast::Application( - Box::new(Ast::Variable("add".to_string())), - Box::new(Ast::Constant(Constant::Nat(5))), - )), - Box::new(Ast::Constant(Constant::Nat(7))), - ); +fn infer_add_nat_partial() { + let ast: DeBrujinAst = Ast::Application( + Box::new(Ast::Variable("add".to_string())), + Box::new(Ast::Constant(Constant::Nat(5))), + ) + .into(); let mut gamma = HashMap::new(); gamma.insert( "add".to_string(), - Type::Arrow( - Box::new(Type::Primitive(PrimitiveType::Nat)), - Box::new(Type::Arrow( - Box::new(Type::Primitive(PrimitiveType::Nat)), - Box::new(Type::Primitive(PrimitiveType::Nat)), - )), + crate::types::TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::Arrow( + Box::new(Type::Generic("a".to_string())), + Box::new(Type::Arrow( + Box::new(Type::Generic("a".to_string())), + Box::new(Type::Generic("a".to_string())), + )), + ))), ), ); - let infered = infer_type(Rc::new(gamma), ast).unwrap(); - assert_eq!(infered, Type::Primitive(PrimitiveType::Nat)); + let infered = infer_type(&gamma, ast).unwrap(); + assert_eq!( + infered, + TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat)) + ); } diff --git a/src/lib.rs b/src/lib.rs index 07fc847..3153f3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,84 +5,10 @@ use std::fmt::Display; mod exec; mod inference; mod parse; +mod types; +mod vec_map; +pub use exec::DeBrujinAst; pub use inference::infer_type; -pub use parse::{ParseError, is_ident, parse, parse_type, sexpr::parse_string}; - -type Ident = String; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PrimitiveType { - Nat, - Bool, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Type { - Primitive(PrimitiveType), // Bool | Nat - Arrow(Box, Box), // 0 -> 1 -} - -impl From for Type { - fn from(value: PrimitiveType) -> Self { - Type::Primitive(value) - } -} - -impl Type { - fn arrow, T2: Into>(t1: T1, t2: T2) -> Self { - Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Constant { - Nat(usize), - Bool(bool), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Ast { - Abstraction(Ident, Type, Box), // \0:1.2 - Application(Box, Box), // 0 1 - Variable(Ident), // x - Constant(Constant), // true | false | n -} - -impl Display for Ast { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"), - Ast::Application(lhs, rhs) => write!(f, "({lhs} {rhs})"), - Ast::Variable(v) => write!(f, "{v}"), - Ast::Constant(constant) => write!(f, "{constant}"), - } - } -} - -impl Display for Type { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), - Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), - } - } -} - -impl Display for PrimitiveType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PrimitiveType::Nat => write!(f, "Nat"), - PrimitiveType::Bool => write!(f, "Bool"), - } - } -} - -impl Display for Constant { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Constant::Nat(n) => write!(f, "{n}"), - Constant::Bool(b) => write!(f, "{b}"), - } - } -} +pub use parse::{Ast, ParseError, is_ident, parse, parse_type, sexpr::parse_string}; +use types::{Ident, PrimitiveType, Type}; diff --git a/src/main.rs b/src/main.rs index fee1d16..a7cfaa3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,9 @@ use std::{ rc::Rc, }; -use stlc_type_inference::{infer_type, is_ident, parse, parse_string, parse_type}; +use stlc_type_inference::{ + Ast, DeBrujinAst, infer_type, is_ident, parse, parse_string, parse_type, +}; macro_rules! repl_err { ($err:expr) => {{ @@ -16,7 +18,7 @@ macro_rules! repl_err { } fn main() { - let mut gamma = Rc::new(HashMap::new()); + let mut gamma = HashMap::new(); print!("> "); stdout().flush().unwrap(); for line in stdin().lines() { @@ -25,7 +27,7 @@ fn main() { if let Some((cmd, expr)) = tl.split_once(' ') { match cmd { "t" => match parse(expr) { - Ok(a) => match infer_type(gamma.clone(), a) { + Ok(a) => match infer_type(&gamma, a.into()) { Ok(t) => println!("{t}"), Err(e) => repl_err!("Could not infer type {e:?}"), }, @@ -49,7 +51,7 @@ fn main() { }; println!("Added {ident} with type {typ} to the context"); - Rc::make_mut(&mut gamma).insert(ident, typ); + gamma.insert(ident, typ.into()); } } c => println!("Unknown command {c}"), @@ -67,11 +69,13 @@ fn main() { Err(e) => repl_err!("Parse error {e:?}"), }; - let typ = match infer_type(gamma.clone(), ast.clone()) { + let ast: DeBrujinAst = ast.into(); + let typ = match infer_type(&gamma, ast.clone()) { Ok(t) => t, Err(e) => repl_err!("Could not infer type {e:?}"), }; let ast = ast.beta_reduce(); + let ast: Ast = ast.into(); println!("{ast} : {typ}") } print!("> "); diff --git a/src/parse/mod.rs b/src/parse/mod.rs index 65edc05..f29703b 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -1,6 +1,8 @@ +use std::fmt::Display; + use sexpr::{Sexpr, parse_string}; -use crate::{Ast, Constant, PrimitiveType, Type}; +use crate::{PrimitiveType, Type, types::Ident}; pub mod sexpr; #[cfg(test)] @@ -27,6 +29,40 @@ pub enum ParseError { ExpectedOneOf(Vec, String), } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Constant { + Nat(usize), + Bool(bool), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Ast { + Abstraction(Ident, Type, Box), // \0:1.2 + Application(Box, Box), // 0 1 + Variable(Ident), // x + Constant(Constant), // true | false | n +} + +impl Display for Ast { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"), + Ast::Application(lhs, rhs) => write!(f, "({lhs} {rhs})"), + Ast::Variable(v) => write!(f, "{v}"), + Ast::Constant(constant) => write!(f, "{constant}"), + } + } +} + +impl Display for Constant { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Constant::Nat(n) => write!(f, "{n}"), + Constant::Bool(b) => write!(f, "{b}"), + } + } +} + fn expect_symbol(ast: Option) -> Result { match ast { Some(Sexpr::Symbol(s)) => Ok(s), diff --git a/src/parse/test.rs b/src/parse/test.rs index 61f2ef7..a250a00 100644 --- a/src/parse/test.rs +++ b/src/parse/test.rs @@ -1,6 +1,9 @@ use crate::{ - Ast, Constant, PrimitiveType, Type, - parse::sexpr::{Sexpr, parse_string}, + PrimitiveType, Type, + parse::{ + Ast, Constant, + sexpr::{Sexpr, parse_string}, + }, }; use super::{parse, parse_type}; diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 0000000..a274778 --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,211 @@ +use std::{fmt::Display, str}; + +pub type Ident = String; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PrimitiveType { + Nat, + Bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TypeTag { + Num, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TaggedType { + Tagged(TypeTag, Ident, Box), + Concrete(Type), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Type { + Generic(Ident), // a + Primitive(PrimitiveType), // Bool | Nat + Arrow(Box, Box), // 0 -> 1 +} + +impl From for TaggedType { + fn from(value: Type) -> Self { + Self::Concrete(value) + } +} + +impl TaggedType { + pub fn r#type(&self) -> &Type { + match self { + TaggedType::Tagged(_, _, tagged_type) => tagged_type.r#type(), + TaggedType::Concrete(t) => t, + } + } + + pub fn type_mut(&mut self) -> &mut Type { + match self { + TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_mut(), + TaggedType::Concrete(t) => t, + } + } + + pub fn to_concrete(self) -> Option { + match self { + TaggedType::Tagged(type_tag, _, tagged_type) => None, + TaggedType::Concrete(t) => { + if t.is_concrete() { + Some(t) + } else { + None + } + } + } + } + + pub fn matches(&self, typ: &Type) -> (bool, Option) { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + if typ.has_tag(type_tag) { + (true, Some(ident.clone())) + } else { + tagged_type.matches(typ) + } + } + TaggedType::Concrete(c) => (c == typ, None), + } + } + + pub fn arrow(self) -> Option<(TaggedType, TaggedType)> { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + let (lhs, rhs) = tagged_type.arrow()?; + Some(( + TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs)) + .clear_unused_names(), + TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(), + )) + } + TaggedType::Concrete(Type::Arrow(lhs, rhs)) => { + Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs))) + } + TaggedType::Concrete(_) => None, + } + } + + pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType { + match self { + TaggedType::Tagged(_, i, tagged_type) if i == ident => { + tagged_type.specialize(ident, typ) + } + TaggedType::Tagged(t, i, tagged_type) => { + TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ))) + } + TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)), + } + } + + fn name_used(&self, ident: &str) -> bool { + match self { + TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident), + TaggedType::Concrete(c) => c.name_used(ident), + } + } + + fn clear_unused_names(self) -> TaggedType { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + if tagged_type.name_used(&ident) { + TaggedType::Tagged(type_tag, ident, tagged_type) + } else { + *tagged_type + } + } + t => t, + } + } +} + +impl From for Type { + fn from(value: PrimitiveType) -> Self { + Self::Primitive(value) + } +} + +impl Type { + pub fn arrow, T2: Into>(t1: T1, t2: T2) -> Self { + Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) + } + + pub fn is_concrete(&self) -> bool { + match self { + Type::Generic(_) => false, + Type::Primitive(primitive_type) => true, + Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(), + } + } + + pub fn has_tag(&self, tag: &TypeTag) -> bool { + match self { + Type::Generic(_) => false, + Type::Primitive(primitive_type) => match (primitive_type, tag) { + (PrimitiveType::Nat, TypeTag::Num) => true, + _ => false, + }, + Type::Arrow(_, _) => false, + } + } + + fn name_used(&self, ident: &str) -> bool { + match self { + Type::Generic(i) if *i == ident => true, + _ => false, + } + } + + fn specialize(self, ident: &str, typ: &Type) -> Type { + match self { + Type::Generic(i) if i == ident => typ.clone(), + Type::Arrow(lhs, rhs) => Type::Arrow( + Box::new(lhs.specialize(ident, typ)), + Box::new(rhs.specialize(ident, typ)), + ), + t => t, + } + } +} + +impl Display for TypeTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeTag::Num => write!(f, "Num"), + } + } +} + +impl Display for TaggedType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + write!(f, "{type_tag} {ident} => {tagged_type}") + } + TaggedType::Concrete(t) => write!(f, "{t}"), + } + } +} + +impl Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Type::Generic(i) => write!(f, "{i}"), + Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), + Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), + } + } +} + +impl Display for PrimitiveType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PrimitiveType::Nat => write!(f, "Nat"), + PrimitiveType::Bool => write!(f, "Bool"), + } + } +} diff --git a/src/vec_map.rs b/src/vec_map.rs new file mode 100644 index 0000000..dde0f68 --- /dev/null +++ b/src/vec_map.rs @@ -0,0 +1,38 @@ +use std::mem; + +pub struct VecMap { + map: Vec<(K, V)>, +} + +impl VecMap { + pub fn new() -> Self { + Self { map: Vec::new() } + } + + pub fn insert(&mut self, key: K, val: V) -> Option<(K, V)> { + if let Some(entry) = self.map.iter_mut().find(|(k, _)| *k == key) { + Some(mem::replace(entry, (key, val))) + } else { + self.map.push((key, val)); + None + } + } + + pub fn get(&self, key: &K) -> Option<&V> { + self.map + .iter() + .find_map(|(k, v)| if k == key { Some(v) } else { None }) + } + + pub fn map_keys(&mut self, f: F) { + self.map.iter_mut().for_each(|(k, _)| f(k)) + } +} + +impl Clone for VecMap { + fn clone(&self) -> Self { + Self { + map: self.map.clone(), + } + } +}