Type tags, fix deBrujin beta reduction, DeBrujinAst -> Ast, type inference on DeBrujinAst

main
Avery 1 month ago
parent 134f803b88
commit 6a46c6ca52
Signed by: Avery
GPG Key ID: 4E53F4CB69B2CC8D

@ -3,52 +3,27 @@ mod test;
use std::{collections::HashMap, rc::Rc}; use std::{collections::HashMap, rc::Rc};
use crate::{Ast, Constant, Ident, Type}; use crate::{
Ident, Type,
parse::{Ast, Constant},
vec_map::VecMap,
};
impl Ast { impl Into<DeBrujinAst> for Ast {
pub fn beta_reduce(self) -> Ast { fn into(self) -> DeBrujinAst {
match self { self.to_de_brujin_inter(Rc::new(HashMap::new()))
Ast::Application(lhs, rhs) => match *lhs {
Ast::Abstraction(var, _, ast) => ast.subst(var, *rhs).beta_reduce(),
lhs => Ast::Application(Box::new(lhs), rhs),
},
t => t,
}
}
fn subst(self, var: Ident, subst: Ast) -> Ast {
match self {
Ast::Abstraction(var1, typ, ast) => {
if var != var1 {
Ast::Abstraction(var1, typ, Box::new(ast.subst(var, subst)))
} else {
Ast::Abstraction(var1, typ, ast)
}
}
Ast::Application(lhs, rhs) => Ast::Application(
Box::new(lhs.subst(var.clone(), subst.clone())),
Box::new(rhs.subst(var, subst)),
),
Ast::Variable(v) if v == var => subst,
Ast::Variable(v) => Ast::Variable(v),
Ast::Constant(constant) => Ast::Constant(constant),
}
} }
} }
impl Ast { impl Ast {
pub fn to_de_brujin(self) -> DeBrujinAst {
self.to_de_brujin_inter(Rc::new(HashMap::new()))
}
fn to_de_brujin_inter(self, mut gamma: Rc<HashMap<String, usize>>) -> DeBrujinAst { fn to_de_brujin_inter(self, mut gamma: Rc<HashMap<String, usize>>) -> DeBrujinAst {
match self { match self {
Ast::Abstraction(var, _, ast) => { Ast::Abstraction(var, t, ast) => {
let gamma_ref = Rc::make_mut(&mut gamma); let gamma_ref = Rc::make_mut(&mut gamma);
gamma_ref.values_mut().for_each(|v| *v += 1); gamma_ref.values_mut().for_each(|v| *v += 1);
gamma_ref.insert(var, 1); gamma_ref.insert(var.clone(), 1);
DeBrujinAst::Abstraction(Box::new(ast.to_de_brujin_inter(gamma))) DeBrujinAst::Abstraction(var, t, Box::new(ast.to_de_brujin_inter(gamma)))
} }
Ast::Application(lhs, rhs) => DeBrujinAst::Application( Ast::Application(lhs, rhs) => DeBrujinAst::Application(
Box::new(lhs.to_de_brujin_inter(gamma.clone())), Box::new(lhs.to_de_brujin_inter(gamma.clone())),
@ -68,18 +43,51 @@ impl Ast {
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeBrujinAst { pub enum DeBrujinAst {
Abstraction(Box<DeBrujinAst>), // \:1.2 Abstraction(Ident, Type, Box<DeBrujinAst>), // \:1.2
Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1 Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1
FreeVariable(String), // x FreeVariable(String), // x
BoundVariable(usize), // 1 BoundVariable(usize), // 1
Constant(Constant), // true | false | n Constant(Constant), // true | false | n
} }
impl Into<Ast> for DeBrujinAst {
fn into(self) -> Ast {
self.to_ast(Rc::new(VecMap::new()))
}
}
impl DeBrujinAst { impl DeBrujinAst {
// Vec<(usize, String)> as opposed to a HashMap here since we need to mutate
// all the keys every time recuse into an Abstraction, which would be suuper
// expensive if not impossible with a HashMap
fn to_ast(self, mut gamma: Rc<VecMap<usize, String>>) -> Ast {
match self {
DeBrujinAst::Abstraction(i, t, de_brujin_ast) => {
let gamma_ref = Rc::make_mut(&mut gamma);
gamma_ref.map_keys(|i| *i += 1);
gamma_ref.insert(1, i.clone());
Ast::Abstraction(i, t, Box::new(de_brujin_ast.to_ast(gamma)))
}
DeBrujinAst::Application(lhs, rhs) => Ast::Application(
Box::new(lhs.to_ast(gamma.clone())),
Box::new(rhs.to_ast(gamma)),
),
DeBrujinAst::FreeVariable(i) => Ast::Variable(i),
DeBrujinAst::BoundVariable(n) => Ast::Variable(
gamma
.get(&n)
.unwrap() // Compiler bug if panics
.clone(),
),
DeBrujinAst::Constant(constant) => Ast::Constant(constant),
}
}
pub fn beta_reduce(self) -> DeBrujinAst { pub fn beta_reduce(self) -> DeBrujinAst {
match self { match self {
DeBrujinAst::Application(lhs, rhs) => match *lhs { DeBrujinAst::Application(lhs, rhs) => match *lhs {
DeBrujinAst::Abstraction(ast) => ast.subst_bound(1, *rhs).beta_reduce(), DeBrujinAst::Abstraction(_, _, ast) => ast.subst_bound(1, *rhs).beta_reduce(),
lhs => DeBrujinAst::Application(Box::new(lhs), rhs), lhs => DeBrujinAst::Application(Box::new(lhs), rhs),
}, },
a => a, a => a,
@ -88,7 +96,9 @@ impl DeBrujinAst {
fn subst_bound(self, depth: usize, subst: DeBrujinAst) -> DeBrujinAst { fn subst_bound(self, depth: usize, subst: DeBrujinAst) -> DeBrujinAst {
match self { match self {
DeBrujinAst::Abstraction(ast) => ast.subst_bound(depth + 1, subst), DeBrujinAst::Abstraction(i, t, ast) => {
DeBrujinAst::Abstraction(i, t, Box::new(ast.subst_bound(depth + 1, subst)))
}
DeBrujinAst::Application(lhs, rhs) => DeBrujinAst::Application( DeBrujinAst::Application(lhs, rhs) => DeBrujinAst::Application(
Box::new(lhs.subst_bound(depth, subst.clone())), Box::new(lhs.subst_bound(depth, subst.clone())),
Box::new(rhs.subst_bound(depth, subst)), Box::new(rhs.subst_bound(depth, subst)),

@ -1,27 +1,8 @@
use crate::{Ast, Constant, PrimitiveType, Type, exec::DeBrujinAst as DBAst}; use crate::{
PrimitiveType, Type,
#[test] exec::DeBrujinAst as DBAst,
fn beta_reduce() { parse::{Ast, Constant},
let input = Ast::Application( };
Box::new(Ast::Abstraction(
"x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
)),
Box::new(Ast::Variable("y".to_string())),
);
let reduced = input.beta_reduce();
assert_eq!(
reduced,
Ast::Application(
Box::new(Ast::Variable("y".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
),
)
}
#[test] #[test]
fn to_de_brujin_ast_simple() { fn to_de_brujin_ast_simple() {
@ -34,12 +15,18 @@ fn to_de_brujin_ast_simple() {
Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Variable("x".to_string())),
)), )),
); );
let de_brujin = input.to_de_brujin(); let de_brujin: DBAst = input.into();
assert_eq!( assert_eq!(
de_brujin, de_brujin,
DBAst::Abstraction(Box::new(DBAst::Abstraction(Box::new( DBAst::Abstraction(
DBAst::BoundVariable(1) "x".to_string(),
)))) PrimitiveType::Nat.into(),
Box::new(DBAst::Abstraction(
"x".to_string(),
PrimitiveType::Nat.into(),
Box::new(DBAst::BoundVariable(1))
))
)
) )
} }
@ -56,7 +43,7 @@ fn de_brujin_beta_reduce() {
)), )),
Box::new(Ast::Variable("y".to_string())), Box::new(Ast::Variable("y".to_string())),
); );
let dbast = input.to_de_brujin(); let dbast: DBAst = input.into();
let reduced = dbast.beta_reduce(); let reduced = dbast.beta_reduce();
assert_eq!( assert_eq!(
reduced, reduced,
@ -66,3 +53,21 @@ fn de_brujin_beta_reduce() {
), ),
) )
} }
#[test]
fn to_and_from_de_brujin_is_id() {
let input = Ast::Application(
Box::new(Ast::Abstraction(
"x".to_string(),
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat),
Box::new(Ast::Application(
Box::new(Ast::Variable("x".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
)),
Box::new(Ast::Variable("y".to_string())),
);
let dbast: DBAst = input.clone().into();
let output: Ast = dbast.into();
assert_eq!(input, output);
}

@ -1,37 +1,75 @@
use std::{collections::HashMap, error::Error, rc::Rc}; use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc};
use crate::{Ast, Constant, Ident, PrimitiveType, Type}; use crate::{
DeBrujinAst, Ident, PrimitiveType, Type,
parse::{Ast, Constant},
types::{TaggedType, TypeTag},
vec_map::VecMap,
};
#[derive(Debug)] #[derive(Debug)]
pub enum InferError { pub enum InferError {
NotAFunction, NotAFunction,
MismatchedType, MismatchedType,
NotInContext, NotInContext,
ExpectedConreteType,
ExpectedTypeWithTag,
} }
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub fn infer_type(mut gamma: Rc<HashMap<Ident, Type>>, ast: Ast) -> Result<Type, InferError> { pub fn infer_type(
gamma: &HashMap<Ident, TaggedType>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
infer_type_debrujin_int(gamma, Rc::new(VecMap::new()), ast)
}
fn infer_type_debrujin_int(
gamma_free: &HashMap<Ident, TaggedType>,
mut gamma_bound: Rc<VecMap<usize, TaggedType>>,
ast: DeBrujinAst,
) -> Result<TaggedType, InferError> {
match ast { match ast {
Ast::Abstraction(arg, arg_type, ast) => { DeBrujinAst::Abstraction(_, arg_type, ast) => {
Rc::make_mut(&mut gamma).insert(arg, arg_type.clone()); let gamma_ref = Rc::make_mut(&mut gamma_bound);
let out_type = infer_type(gamma, *ast)?; gamma_ref.map_keys(|i| *i += 1);
Ok(Type::Arrow(Box::new(arg_type), Box::new(out_type))) gamma_ref.insert(1, arg_type.clone().into());
let out_type = infer_type_debrujin_int(gamma_free, gamma_bound, *ast)?;
let typ = Type::Arrow(
Box::new(arg_type),
Box::new(
out_type
.to_concrete()
.ok_or(InferError::ExpectedConreteType)?,
),
);
Ok(typ.into())
} }
Ast::Application(left, right) => { DeBrujinAst::Application(lhs, rhs) => {
let left_type = infer_type(gamma.clone(), *left)?; let left_type = infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *lhs)?;
let Type::Arrow(in_type, out_type) = left_type else { let Some((in_type, out_type)) = left_type.arrow() else {
return Err(InferError::NotAFunction); return Err(InferError::NotAFunction);
}; };
let right_type = infer_type(gamma, *right)?; let Some(right_type) =
if *in_type != right_type { infer_type_debrujin_int(gamma_free, gamma_bound.clone(), *rhs)?.to_concrete()
return Err(InferError::MismatchedType); else {
} return Err(InferError::ExpectedConreteType);
Ok(*out_type) };
let out_type = match in_type.matches(&right_type) {
(true, None) => out_type,
(true, Some(a)) => out_type.specialize(&a, &right_type),
(false, _) => return Err(InferError::MismatchedType),
};
Ok(out_type)
} }
Ast::Variable(var) => gamma.get(&var).cloned().ok_or(InferError::NotInContext), DeBrujinAst::FreeVariable(x) => gamma_free.get(&x).cloned().ok_or(InferError::NotInContext),
Ast::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat)), // compiler bug if not present
Ast::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool)), DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()),
DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()),
DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()),
DeBrujinAst::Constant(_) => unreachable!(),
} }
} }

@ -1,49 +1,73 @@
use std::{collections::HashMap, rc::Rc}; use std::{collections::HashMap, rc::Rc};
use crate::{Ast, Constant, PrimitiveType, Type}; use crate::{
DeBrujinAst, PrimitiveType, Type,
use super::infer_type; inference::infer_type,
parse::{Ast, Constant},
types::{TaggedType, TypeTag},
};
#[test] #[test]
fn infer_id_type() { fn infer_add_nat() {
let ast = Ast::Abstraction( let ast: DeBrujinAst = Ast::Application(
"x".to_string(), Box::new(Ast::Application(
Type::Primitive(PrimitiveType::Nat), Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Variable("x".to_string())), Box::new(Ast::Constant(Constant::Nat(5))),
)),
Box::new(Ast::Constant(Constant::Nat(7))),
)
.into();
let mut gamma = HashMap::new();
gamma.insert(
"add".to_string(),
crate::types::TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::Arrow(
Box::new(Type::Generic("a".to_string())),
Box::new(Type::Arrow(
Box::new(Type::Generic("a".to_string())),
Box::new(Type::Generic("a".to_string())),
)),
))),
),
); );
let infered = infer_type(Rc::new(HashMap::new()), ast).unwrap(); let infered = infer_type(&gamma, ast).unwrap();
assert_eq!( assert_eq!(
infered, infered,
Type::Arrow( TaggedType::Concrete(Type::Primitive(PrimitiveType::Nat))
Box::new(Type::Primitive(PrimitiveType::Nat)), );
Box::new(Type::Primitive(PrimitiveType::Nat))
)
)
} }
#[test] #[test]
fn infer_addition_result_type() { fn infer_add_nat_partial() {
let ast = Ast::Application( let ast: DeBrujinAst = Ast::Application(
Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())), Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))), Box::new(Ast::Constant(Constant::Nat(5))),
)), )
Box::new(Ast::Constant(Constant::Nat(7))), .into();
);
let mut gamma = HashMap::new(); let mut gamma = HashMap::new();
gamma.insert( gamma.insert(
"add".to_string(), "add".to_string(),
Type::Arrow( crate::types::TaggedType::Tagged(
Box::new(Type::Primitive(PrimitiveType::Nat)), TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::Arrow(
Box::new(Type::Generic("a".to_string())),
Box::new(Type::Arrow( Box::new(Type::Arrow(
Box::new(Type::Primitive(PrimitiveType::Nat)), Box::new(Type::Generic("a".to_string())),
Box::new(Type::Primitive(PrimitiveType::Nat)), Box::new(Type::Generic("a".to_string())),
)), )),
))),
), ),
); );
let infered = infer_type(Rc::new(gamma), ast).unwrap(); let infered = infer_type(&gamma, ast).unwrap();
assert_eq!(infered, Type::Primitive(PrimitiveType::Nat)); assert_eq!(
infered,
TaggedType::Concrete(Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat))
);
} }

@ -5,84 +5,10 @@ use std::fmt::Display;
mod exec; mod exec;
mod inference; mod inference;
mod parse; mod parse;
mod types;
mod vec_map;
pub use exec::DeBrujinAst;
pub use inference::infer_type; pub use inference::infer_type;
pub use parse::{ParseError, is_ident, parse, parse_type, sexpr::parse_string}; pub use parse::{Ast, ParseError, is_ident, parse, parse_type, sexpr::parse_string};
use types::{Ident, PrimitiveType, Type};
type Ident = String;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrimitiveType {
Nat,
Bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
Primitive(PrimitiveType), // Bool | Nat
Arrow(Box<Type>, Box<Type>), // 0 -> 1
}
impl From<PrimitiveType> for Type {
fn from(value: PrimitiveType) -> Self {
Type::Primitive(value)
}
}
impl Type {
fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Constant {
Nat(usize),
Bool(bool),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ast {
Abstraction(Ident, Type, Box<Ast>), // \0:1.2
Application(Box<Ast>, Box<Ast>), // 0 1
Variable(Ident), // x
Constant(Constant), // true | false | n
}
impl Display for Ast {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"),
Ast::Application(lhs, rhs) => write!(f, "({lhs} {rhs})"),
Ast::Variable(v) => write!(f, "{v}"),
Ast::Constant(constant) => write!(f, "{constant}"),
}
}
}
impl Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
}
}
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::Nat => write!(f, "Nat"),
PrimitiveType::Bool => write!(f, "Bool"),
}
}
}
impl Display for Constant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Constant::Nat(n) => write!(f, "{n}"),
Constant::Bool(b) => write!(f, "{b}"),
}
}
}

@ -4,7 +4,9 @@ use std::{
rc::Rc, rc::Rc,
}; };
use stlc_type_inference::{infer_type, is_ident, parse, parse_string, parse_type}; use stlc_type_inference::{
Ast, DeBrujinAst, infer_type, is_ident, parse, parse_string, parse_type,
};
macro_rules! repl_err { macro_rules! repl_err {
($err:expr) => {{ ($err:expr) => {{
@ -16,7 +18,7 @@ macro_rules! repl_err {
} }
fn main() { fn main() {
let mut gamma = Rc::new(HashMap::new()); let mut gamma = HashMap::new();
print!("> "); print!("> ");
stdout().flush().unwrap(); stdout().flush().unwrap();
for line in stdin().lines() { for line in stdin().lines() {
@ -25,7 +27,7 @@ fn main() {
if let Some((cmd, expr)) = tl.split_once(' ') { if let Some((cmd, expr)) = tl.split_once(' ') {
match cmd { match cmd {
"t" => match parse(expr) { "t" => match parse(expr) {
Ok(a) => match infer_type(gamma.clone(), a) { Ok(a) => match infer_type(&gamma, a.into()) {
Ok(t) => println!("{t}"), Ok(t) => println!("{t}"),
Err(e) => repl_err!("Could not infer type {e:?}"), Err(e) => repl_err!("Could not infer type {e:?}"),
}, },
@ -49,7 +51,7 @@ fn main() {
}; };
println!("Added {ident} with type {typ} to the context"); println!("Added {ident} with type {typ} to the context");
Rc::make_mut(&mut gamma).insert(ident, typ); gamma.insert(ident, typ.into());
} }
} }
c => println!("Unknown command {c}"), c => println!("Unknown command {c}"),
@ -67,11 +69,13 @@ fn main() {
Err(e) => repl_err!("Parse error {e:?}"), Err(e) => repl_err!("Parse error {e:?}"),
}; };
let typ = match infer_type(gamma.clone(), ast.clone()) { let ast: DeBrujinAst = ast.into();
let typ = match infer_type(&gamma, ast.clone()) {
Ok(t) => t, Ok(t) => t,
Err(e) => repl_err!("Could not infer type {e:?}"), Err(e) => repl_err!("Could not infer type {e:?}"),
}; };
let ast = ast.beta_reduce(); let ast = ast.beta_reduce();
let ast: Ast = ast.into();
println!("{ast} : {typ}") println!("{ast} : {typ}")
} }
print!("> "); print!("> ");

@ -1,6 +1,8 @@
use std::fmt::Display;
use sexpr::{Sexpr, parse_string}; use sexpr::{Sexpr, parse_string};
use crate::{Ast, Constant, PrimitiveType, Type}; use crate::{PrimitiveType, Type, types::Ident};
pub mod sexpr; pub mod sexpr;
#[cfg(test)] #[cfg(test)]
@ -27,6 +29,40 @@ pub enum ParseError {
ExpectedOneOf(Vec<String>, String), ExpectedOneOf(Vec<String>, String),
} }
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Constant {
Nat(usize),
Bool(bool),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Ast {
Abstraction(Ident, Type, Box<Ast>), // \0:1.2
Application(Box<Ast>, Box<Ast>), // 0 1
Variable(Ident), // x
Constant(Constant), // true | false | n
}
impl Display for Ast {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Ast::Abstraction(var, typ, ast) => write!(f, "(\\ ({var} {typ}) ({ast}))"),
Ast::Application(lhs, rhs) => write!(f, "({lhs} {rhs})"),
Ast::Variable(v) => write!(f, "{v}"),
Ast::Constant(constant) => write!(f, "{constant}"),
}
}
}
impl Display for Constant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Constant::Nat(n) => write!(f, "{n}"),
Constant::Bool(b) => write!(f, "{b}"),
}
}
}
fn expect_symbol(ast: Option<Sexpr>) -> Result<String, ParseError> { fn expect_symbol(ast: Option<Sexpr>) -> Result<String, ParseError> {
match ast { match ast {
Some(Sexpr::Symbol(s)) => Ok(s), Some(Sexpr::Symbol(s)) => Ok(s),

@ -1,6 +1,9 @@
use crate::{ use crate::{
Ast, Constant, PrimitiveType, Type, PrimitiveType, Type,
parse::sexpr::{Sexpr, parse_string}, parse::{
Ast, Constant,
sexpr::{Sexpr, parse_string},
},
}; };
use super::{parse, parse_type}; use super::{parse, parse_type};

@ -0,0 +1,211 @@
use std::{fmt::Display, str};
pub type Ident = String;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PrimitiveType {
Nat,
Bool,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeTag {
Num,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaggedType {
Tagged(TypeTag, Ident, Box<TaggedType>),
Concrete(Type),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
Generic(Ident), // a
Primitive(PrimitiveType), // Bool | Nat
Arrow(Box<Type>, Box<Type>), // 0 -> 1
}
impl From<Type> for TaggedType {
fn from(value: Type) -> Self {
Self::Concrete(value)
}
}
impl TaggedType {
pub fn r#type(&self) -> &Type {
match self {
TaggedType::Tagged(_, _, tagged_type) => tagged_type.r#type(),
TaggedType::Concrete(t) => t,
}
}
pub fn type_mut(&mut self) -> &mut Type {
match self {
TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_mut(),
TaggedType::Concrete(t) => t,
}
}
pub fn to_concrete(self) -> Option<Type> {
match self {
TaggedType::Tagged(type_tag, _, tagged_type) => None,
TaggedType::Concrete(t) => {
if t.is_concrete() {
Some(t)
} else {
None
}
}
}
}
pub fn matches(&self, typ: &Type) -> (bool, Option<Ident>) {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
if typ.has_tag(type_tag) {
(true, Some(ident.clone()))
} else {
tagged_type.matches(typ)
}
}
TaggedType::Concrete(c) => (c == typ, None),
}
}
pub fn arrow(self) -> Option<(TaggedType, TaggedType)> {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
let (lhs, rhs) = tagged_type.arrow()?;
Some((
TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs))
.clear_unused_names(),
TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(),
))
}
TaggedType::Concrete(Type::Arrow(lhs, rhs)) => {
Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs)))
}
TaggedType::Concrete(_) => None,
}
}
pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType {
match self {
TaggedType::Tagged(_, i, tagged_type) if i == ident => {
tagged_type.specialize(ident, typ)
}
TaggedType::Tagged(t, i, tagged_type) => {
TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ)))
}
TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)),
}
}
fn name_used(&self, ident: &str) -> bool {
match self {
TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident),
TaggedType::Concrete(c) => c.name_used(ident),
}
}
fn clear_unused_names(self) -> TaggedType {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
if tagged_type.name_used(&ident) {
TaggedType::Tagged(type_tag, ident, tagged_type)
} else {
*tagged_type
}
}
t => t,
}
}
}
impl From<PrimitiveType> for Type {
fn from(value: PrimitiveType) -> Self {
Self::Primitive(value)
}
}
impl Type {
pub fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
}
pub fn is_concrete(&self) -> bool {
match self {
Type::Generic(_) => false,
Type::Primitive(primitive_type) => true,
Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(),
}
}
pub fn has_tag(&self, tag: &TypeTag) -> bool {
match self {
Type::Generic(_) => false,
Type::Primitive(primitive_type) => match (primitive_type, tag) {
(PrimitiveType::Nat, TypeTag::Num) => true,
_ => false,
},
Type::Arrow(_, _) => false,
}
}
fn name_used(&self, ident: &str) -> bool {
match self {
Type::Generic(i) if *i == ident => true,
_ => false,
}
}
fn specialize(self, ident: &str, typ: &Type) -> Type {
match self {
Type::Generic(i) if i == ident => typ.clone(),
Type::Arrow(lhs, rhs) => Type::Arrow(
Box::new(lhs.specialize(ident, typ)),
Box::new(rhs.specialize(ident, typ)),
),
t => t,
}
}
}
impl Display for TypeTag {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TypeTag::Num => write!(f, "Num"),
}
}
}
impl Display for TaggedType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TaggedType::Tagged(type_tag, ident, tagged_type) => {
write!(f, "{type_tag} {ident} => {tagged_type}")
}
TaggedType::Concrete(t) => write!(f, "{t}"),
}
}
}
impl Display for Type {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Type::Generic(i) => write!(f, "{i}"),
Type::Primitive(primitive_type) => write!(f, "{primitive_type}"),
Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"),
}
}
}
impl Display for PrimitiveType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PrimitiveType::Nat => write!(f, "Nat"),
PrimitiveType::Bool => write!(f, "Bool"),
}
}
}

@ -0,0 +1,38 @@
use std::mem;
pub struct VecMap<K: Eq, V> {
map: Vec<(K, V)>,
}
impl<K: Eq, V> VecMap<K, V> {
pub fn new() -> Self {
Self { map: Vec::new() }
}
pub fn insert(&mut self, key: K, val: V) -> Option<(K, V)> {
if let Some(entry) = self.map.iter_mut().find(|(k, _)| *k == key) {
Some(mem::replace(entry, (key, val)))
} else {
self.map.push((key, val));
None
}
}
pub fn get(&self, key: &K) -> Option<&V> {
self.map
.iter()
.find_map(|(k, v)| if k == key { Some(v) } else { None })
}
pub fn map_keys<F: Fn(&mut K)>(&mut self, f: F) {
self.map.iter_mut().for_each(|(k, _)| f(k))
}
}
impl<K: Eq + Clone, V: Clone> Clone for VecMap<K, V> {
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
}
}
}
Loading…
Cancel
Save