diff --git a/src/exec/builtins.rs b/src/exec/builtins.rs index 31a690b..1096306 100644 --- a/src/exec/builtins.rs +++ b/src/exec/builtins.rs @@ -1,8 +1,8 @@ -use std::{collections::HashMap, fmt::Debug, rc::Rc}; +use std::{collections::HashMap, rc::Rc}; use crate::{ - parse::Constant, - types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, + parse::{Constant, Ident}, + types::TaggedType, }; use super::DeBrujinAst; @@ -105,14 +105,7 @@ impl DeBrujinBuiltInAst { } fn is_value(&self) -> bool { - match self { - DeBrujinBuiltInAst::Abstraction(_, _, _) => true, - DeBrujinBuiltInAst::Application(_, _) => false, - DeBrujinBuiltInAst::FreeVariable(_) => true, - DeBrujinBuiltInAst::BoundVariable(_) => true, - DeBrujinBuiltInAst::Constant(constant) => true, - DeBrujinBuiltInAst::Builtin(builtin) => true, - } + !matches!(self, DeBrujinBuiltInAst::Application(_, _)) } } diff --git a/src/exec/mod.rs b/src/exec/mod.rs index a6c5303..a5bbe43 100644 --- a/src/exec/mod.rs +++ b/src/exec/mod.rs @@ -9,8 +9,8 @@ pub use builtins::Builtin; use std::{collections::HashMap, rc::Rc}; use crate::{ - parse::{Ast, Constant}, - types::{Ident, TaggedType}, + parse::{Ast, Constant, Ident}, + types::TaggedType, vec_map::VecMap, }; diff --git a/src/inference/mod.rs b/src/inference/mod.rs index 5dc838c..668656d 100644 --- a/src/inference/mod.rs +++ b/src/inference/mod.rs @@ -1,5 +1,3 @@ -use std::{collections::HashMap, convert::Infallible, error::Error, rc::Rc}; - use crate::types::{TaggedType, TypeTag}; pub mod recursive; @@ -13,9 +11,7 @@ pub enum InferError { MismatchedType, NotInContext, ExpectedConreteType, - ExpectedTypeWithTag, DoesNotFitTag(TypeTag, TaggedType), - ConfilictingBind, ConfilictingTags, UnusableConstraint, } diff --git a/src/inference/recursive.rs b/src/inference/recursive.rs index a70be18..8ec513a 100644 --- a/src/inference/recursive.rs +++ b/src/inference/recursive.rs @@ -2,13 +2,14 @@ use std::{collections::HashMap, rc::Rc}; use crate::{ DeBrujinAst, - parse::Constant, - types::{Ident, PrimitiveType, TaggedType, Type}, + parse::{Constant, Ident}, + types::{PrimitiveType, TaggedType, Type}, vec_map::VecMap, }; use super::InferError; +/// Old algorithm use unification instead pub fn infer_type( gamma: &HashMap, ast: DeBrujinAst, @@ -54,6 +55,8 @@ fn infer_type_debrujin_int( DeBrujinAst::BoundVariable(n) => Ok(gamma_bound.get(&n).cloned().unwrap()), DeBrujinAst::Constant(Constant::Nat(_)) => Ok(Type::Primitive(PrimitiveType::Nat).into()), DeBrujinAst::Constant(Constant::Bool(_)) => Ok(Type::Primitive(PrimitiveType::Bool).into()), - DeBrujinAst::Constant(_) => unreachable!(), + DeBrujinAst::Constant(Constant::Float(_)) => { + Ok(Type::Primitive(PrimitiveType::Float).into()) + } } } diff --git a/src/inference/test.rs b/src/inference/test.rs index 30eacd7..a96670a 100644 --- a/src/inference/test.rs +++ b/src/inference/test.rs @@ -3,16 +3,14 @@ use std::{cell::RefCell, collections::HashMap, rc::Rc}; use crate::{ DeBrujinAst, inference::{ - recursive::infer_type as infer_type_rec, unification::Constraints, - unification::infer_type as infer_type_uni, + recursive::infer_type as infer_type_rec, unification::infer_type as infer_type_uni, }, multi_map::MultiMap, parse::{Ast, Constant}, types::{PrimitiveType, TaggedType, Type, TypeTag}, - vec_map::VecMap, }; -use super::unification::{TypeVarAst, TypeVarCtx}; +use super::unification::TypeVarCtx; #[test] fn infer_add_nat_rec() { @@ -87,7 +85,7 @@ fn subtype_constraints() { Box::new(Type::arrow("a", "a").into()), ); - let mut st_constraints = Rc::new(RefCell::new(MultiMap::new())); + let st_constraints = Rc::new(RefCell::new(MultiMap::new())); let ctx = TypeVarCtx::new(); let res = typ.make_constraints(st_constraints.clone(), &ctx); @@ -167,7 +165,7 @@ fn infer_id_uni() { ) .into(); - let mut gamma = HashMap::new(); + let gamma = HashMap::new(); let typ = infer_type_uni(&gamma, ast).unwrap(); assert_eq!( @@ -182,69 +180,11 @@ fn infer_id_uni() { #[test] fn subst_type_var() { - let typ = TaggedType::Tagged( - TypeTag::Num, - "a".to_string(), - Box::new(TaggedType::Concrete(Type::arrow("a", "?typ_4"))), - ); - - let subst = TaggedType::Tagged( - TypeTag::Any, - "b".to_string(), - Box::new(TaggedType::Concrete(Type::arrow("b", "b"))), - ); - - let typ = typ.subst_typevar("?typ_4", 0, subst).unwrap(); - - assert_eq!( - typ, - TaggedType::Tagged( - TypeTag::Num, - "a".to_string(), - Box::new(TaggedType::Tagged( - TypeTag::Any, - "b".to_string(), - Box::new(TaggedType::Concrete(Type::arrow( - "a", - Type::arrow("b", "b") - ))) - )) - ) - ); + let typ = Type::arrow("a", "?typ_4"); - let typ = TaggedType::Tagged( - TypeTag::Num, - "a".to_string(), - Box::new(TaggedType::Concrete(Type::arrow("a", "?typ_4"))), - ); + let subst = Type::arrow("b", "b"); - let subst = TaggedType::Tagged( - TypeTag::Any, - "a".to_string(), - Box::new(TaggedType::Concrete(Type::arrow("a", "a"))), - ); + let typ = typ.subst_typevar("?typ_4", &subst); - let typ = typ.subst_typevar("?typ_4", 0, subst).unwrap(); - - assert!(match typ { - TaggedType::Tagged(TypeTag::Num, a, tagged_type) => match *tagged_type { - TaggedType::Tagged(TypeTag::Any, b, tagged_type) => match *tagged_type { - TaggedType::Concrete(typ) => match typ { - Type::Arrow(lhs, rhs) => match (*lhs, *rhs) { - (Type::Variable(a1), Type::Arrow(b1, b2)) if a1 == a => { - match (*b1, *b2) { - (Type::Variable(b1), Type::Variable(b2)) => (b1 == b) && (b2 == b), - _ => false, - } - } - _ => false, - }, - _ => false, - }, - _ => false, - }, - _ => false, - }, - _ => false, - }); + assert_eq!(typ, (Type::arrow("a", Type::arrow("b", "b")))); } diff --git a/src/inference/unification.rs b/src/inference/unification.rs index 9f58e68..60eb38b 100644 --- a/src/inference/unification.rs +++ b/src/inference/unification.rs @@ -1,16 +1,15 @@ use std::{ cell::RefCell, - clone, collections::HashMap, fmt::{Debug, DebugMap}, rc::Rc, }; use crate::{ - Ast, DeBrujinAst, + DeBrujinAst, multi_map::MultiMap, - parse::Constant, - types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, + parse::{Constant, Ident}, + types::{PrimitiveType, TaggedType, Type, TypeTag}, vec_map::VecMap, }; @@ -91,8 +90,6 @@ pub(super) fn step_1( )) } DeBrujinAst::Abstraction(i, None, ast) => { - let var = ctx.get_var(); - let typ = ctx.get_var(); let gamma_ref = Rc::make_mut(&mut gamma_bound); gamma_ref.map_keys(|i| *i += 1); @@ -187,16 +184,14 @@ pub(super) fn step_2(mut eq_constraints: Constraints) -> Result) -> Result Result + fn subst(self, var: &str, subst: &Type) -> Self where Self: Sized; } impl Subst for TypeVarAst { - fn subst(self, var: &str, subst: &Type) -> Result { + fn subst(self, var: &str, subst: &Type) -> TypeVarAst { match self { - TypeVarAst::Abstraction(type1, ident, type2, ast) => Ok(TypeVarAst::Abstraction( - type1.subst_typevar(var, subst)?, + TypeVarAst::Abstraction(type1, ident, type2, ast) => TypeVarAst::Abstraction( + type1.subst_typevar(var, subst), ident, - type2.subst_typevar(var, subst)?, - Box::new(ast.subst(var, subst)?), - )), - TypeVarAst::Application(typ, lhs, rhs) => Ok(TypeVarAst::Application( - typ.subst_typevar(var, subst)?, - Box::new(lhs.subst(var, subst)?), - Box::new(rhs.subst(var, subst)?), - )), + type2.subst_typevar(var, subst), + Box::new(ast.subst(var, subst)), + ), + TypeVarAst::Application(typ, lhs, rhs) => TypeVarAst::Application( + typ.subst_typevar(var, subst), + Box::new(lhs.subst(var, subst)), + Box::new(rhs.subst(var, subst)), + ), TypeVarAst::FreeVariable(typ, x) => { - Ok(TypeVarAst::FreeVariable(typ.subst_typevar(var, subst)?, x)) + TypeVarAst::FreeVariable(typ.subst_typevar(var, subst), x) } TypeVarAst::BoundVariable(typ, i) => { - Ok(TypeVarAst::BoundVariable(typ.subst_typevar(var, subst)?, i)) + TypeVarAst::BoundVariable(typ.subst_typevar(var, subst), i) + } + TypeVarAst::Constant(typ, constant) => { + TypeVarAst::Constant(typ.subst_typevar(var, subst), constant) } - TypeVarAst::Constant(typ, constant) => Ok(TypeVarAst::Constant( - typ.subst_typevar(var, subst)?, - constant, - )), } } } impl Subst for Constraints { - fn subst(mut self, var: &str, subst: &Type) -> Result, InferError> { + fn subst(self, var: &str, subst: &Type) -> Constraints { self.into_iter() - .map::, _>(|(k, v)| { - Ok((k.subst_typevar(var, subst)?, v)) - }) + .map(|(k, v)| (k.subst_typevar(var, subst), v)) .collect() } } @@ -281,10 +273,10 @@ impl SubstFn { } } - pub fn apply(&self, target: T) -> Result { + pub fn apply(&self, target: T) -> T { let target = target.subst(&self.var, &self.subst); if let Some(s) = &self.then { - s.apply(target?) + s.apply(target) } else { target } @@ -317,21 +309,21 @@ pub fn infer_type( )?; let res = step_2(eq_constraints.take())?.unwrap(); - let ast = res.apply(ast)?; + let ast = res.apply(ast); fn get_type(ast: TypeVarAst) -> Type { match ast { TypeVarAst::Abstraction(_, _, typ, ast) => Type::arrow(typ, get_type(*ast)), - TypeVarAst::Application(typ, _, _) => typ, - TypeVarAst::FreeVariable(typ, _) => typ, - TypeVarAst::BoundVariable(typ, _) => typ, - TypeVarAst::Constant(typ, constant) => typ, + TypeVarAst::Application(typ, _, _) + | TypeVarAst::FreeVariable(typ, _) + | TypeVarAst::BoundVariable(typ, _) + | TypeVarAst::Constant(typ, _) => typ, } } let typ = get_type(ast); - let st_constraints = res.apply(st_constraints.take())?; + let st_constraints = res.apply(st_constraints.take()); let st_constraints = st_constraints .into_iter() .map(|(typ, tag)| { @@ -374,7 +366,7 @@ impl TaggedType { tagged_type.map_name(|n| { if *n == ident { - *n = name.clone() + n.clone_from(&name); } }); @@ -385,67 +377,17 @@ impl TaggedType { TaggedType::Concrete(c) => c, } } - - pub fn subst_typevar( - self, - var: &str, - diversifier: usize, - mut subst: TaggedType, - ) -> Result { - match self { - TaggedType::Tagged(t, i, tagged_type) => { - if subst.name_used(&i) { - subst.map_name(|f| { - if *f == i { - *f = format!("{i}+{diversifier}") - } - }); - } - Ok(TaggedType::Tagged( - t, - i, - Box::new(tagged_type.subst_typevar(var, diversifier + 1, subst)?), - )) - } - TaggedType::Concrete(t) => subst.subst_into(var, diversifier, t), - } - } - - fn subst_into( - self, - var: &str, - diversifier: usize, - mut target: Type, - ) -> Result { - match self { - TaggedType::Tagged(t, i, tagged_type) => { - if target.name_used(&i) { - target.map_name(&|f| { - if *f == i { - *f = format!("{i}+{diversifier}") - } - }); - } - Ok(TaggedType::Tagged( - t, - i, - Box::new(tagged_type.subst_into(var, diversifier + 1, target)?), - )) - } - TaggedType::Concrete(c) => Ok(TaggedType::Concrete(target.subst_typevar(var, &c)?)), - } - } } impl Type { - pub fn subst_typevar(self, var: &str, subst: &Type) -> Result { + pub fn subst_typevar(self, var: &str, subst: &Type) -> Type { match self { - Type::Variable(v) if v == var => Ok(subst.clone()), - Type::Arrow(lhs, rhs) => Ok(Type::Arrow( - Box::new(lhs.subst_typevar(var, subst)?), - Box::new(rhs.subst_typevar(var, subst)?), - )), - t => Ok(t), + Type::Variable(v) if v == var => subst.clone(), + Type::Arrow(lhs, rhs) => Type::Arrow( + Box::new(lhs.subst_typevar(var, subst)), + Box::new(rhs.subst_typevar(var, subst)), + ), + t => t, } } } diff --git a/src/lib.rs b/src/lib.rs index ef41667..5c0f9c9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,3 @@ -#![allow(unused)] - -use std::fmt::Display; - mod exec; mod inference; mod multi_map; @@ -12,5 +8,4 @@ mod vec_map; pub use exec::{Builtin, DeBrujinAst, builtin}; pub use inference::recursive::infer_type as infer_type_rec; pub use inference::unification::infer_type as infer_type_uni; -use lalrpop_util::lalrpop_mod; pub use parse::{Ast, parse_ast_str, parse_type_str}; diff --git a/src/main.rs b/src/main.rs index 9f26080..c590c98 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,7 +78,7 @@ fn main() { }; let ast = ast.reduce(&builtins); let ast: Ast = ast.into(); - println!("{ast} : {typ}") + println!("{ast} : {typ}"); } print!("> "); stdout().flush().unwrap(); @@ -86,5 +86,5 @@ fn main() { } fn is_ident(s: &str) -> bool { - s.starts_with(|c: char| c.is_alphabetic()) && s.chars().all(|c| c.is_alphanumeric()) + s.starts_with(char::is_alphabetic) && s.chars().all(char::is_alphabetic) } diff --git a/src/multi_map.rs b/src/multi_map.rs index 40d1508..099d8f1 100644 --- a/src/multi_map.rs +++ b/src/multi_map.rs @@ -16,16 +16,18 @@ impl MultiMap { pub fn get(&self, key: &K) -> Vec<&V> { self.map .iter() - .filter_map(|(k, v)| if k == key { Some(v) } else { None }) + .filter_map(|(k, v)| (k == key).then_some(v)) .collect() } pub fn map_keys(&mut self, f: F) { - self.map.iter_mut().for_each(|(k, _)| f(k)) + self.map.iter_mut().for_each(|(k, _)| f(k)); } - pub fn map(&mut self, f: F) { - self.map.iter_mut().for_each(f); + pub fn map (K, V)>(&mut self, f: F) { + let vec = mem::take(&mut self.map); + let vec = vec.into_iter().map(f).collect::>(); + self.map = vec; } pub fn try_map Result<(K, V), E>>(&mut self, f: F) -> Result<(), E> { @@ -44,7 +46,7 @@ impl MultiMap { .map .iter() .enumerate() - .find_map(|(n, ref kv)| if f(kv) { Some(n) } else { None })?; + .find_map(|(n, ref kv)| f(kv).then_some(n))?; Some(self.map.swap_remove(idx)) } diff --git a/src/parse/mod.rs b/src/parse/mod.rs index 0e93bed..0a12991 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -4,17 +4,17 @@ use std::{ str::ParseBoolError, }; -use crate::types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}; +use crate::types::{PrimitiveType, TaggedType, Type, TypeTag}; use lalrpop_util::{ParseError as LALRPopError, lalrpop_mod}; use tokenize::Lexer; #[cfg(test)] mod test; - mod tokenize; - lalrpop_mod!(parser); +pub type Ident = String; + #[derive(Default, Debug, Clone, PartialEq)] pub enum ParseError { UnknownTypeTag(String), @@ -71,8 +71,8 @@ impl TryFrom for TypeTag { } impl Type { - pub fn from_name(name: String) -> Self { - match name.as_str() { + pub fn from_name(name: &str) -> Self { + match name { "Nat" => Self::Primitive(PrimitiveType::Nat), "Float" => Self::Primitive(PrimitiveType::Float), "Bool" => Self::Primitive(PrimitiveType::Bool), diff --git a/src/parse/test.rs b/src/parse/test.rs index c01ed64..3b0bf9e 100644 --- a/src/parse/test.rs +++ b/src/parse/test.rs @@ -1,8 +1,5 @@ use crate::{ - parse::{ - Ast, Constant, - tokenize::{self, Lexer}, - }, + parse::{Ast, Constant, tokenize::Lexer}, types::{PrimitiveType, TaggedType, Type, TypeTag}, }; diff --git a/src/parse/tokenize.rs b/src/parse/tokenize.rs index c502f80..1c34b28 100644 --- a/src/parse/tokenize.rs +++ b/src/parse/tokenize.rs @@ -85,9 +85,9 @@ pub enum Token { ParenClose, } -pub type Spanned = Result<(Loc, Tok, Loc), Error>; +pub(super) type Spanned = Result<(Loc, Tok, Loc), Error>; -pub struct Lexer<'input> { +pub(super) struct Lexer<'input> { // instead of an iterator over characters, we have a token iterator token_stream: SpannedIter<'input, Token>, } diff --git a/src/parser.lalrpop b/src/parser.lalrpop index 0645e2b..d3f370b 100644 --- a/src/parser.lalrpop +++ b/src/parser.lalrpop @@ -1,6 +1,6 @@ use crate::{ parse::{Ast, Constant, tokenize as lexer, ParseError as PError}, - types::{TaggedType, TypeTag, Type, PrimitiveType}, + types::{TaggedType, TypeTag, Type}, }; use lalrpop_util::ParseError; @@ -76,7 +76,7 @@ TypeTag: TypeTag = Ident =>? TypeTag::try_from(<>).map_err(|e| ParseError::User{ Type: Type = { BasicType, ArrowType }; -BasicType: Type = Ident => Type::from_name(<>); +BasicType: Type = Ident => Type::from_name(&<>); ArrowType: Type = { #[precedence(level="0")] diff --git a/src/types/concrete.rs b/src/types/concrete.rs new file mode 100644 index 0000000..29b7ed1 --- /dev/null +++ b/src/types/concrete.rs @@ -0,0 +1,138 @@ +use std::{collections::HashSet, fmt::Display}; + +use crate::parse::Ident; + +use super::tagged::TypeTag; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PrimitiveType { + Nat, + Bool, + Float, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Type { + Variable(Ident), // a + Primitive(PrimitiveType), // Bool | Nat + Arrow(Box, Box), // 0 -> 1 +} + +impl From for Type { + fn from(value: PrimitiveType) -> Self { + Self::Primitive(value) + } +} + +impl From<&str> for Type { + fn from(value: &str) -> Self { + Self::Variable(value.to_string()) + } +} + +impl Type { + pub fn arrow, T2: Into>(t1: T1, t2: T2) -> Self { + Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) + } + + pub fn split_arrow(self) -> Option<(Type, Type)> { + match self { + Type::Arrow(lhs, rhs) => Some((*lhs, *rhs)), + _ => None, + } + } + + pub fn is_concrete(&self) -> bool { + match self { + Type::Variable(_) => false, + Type::Primitive(_) => true, + Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(), + } + } + + pub fn has_tag(&self, tag: &TypeTag) -> bool { + match self { + Type::Variable(_) => true, + Type::Primitive(primitive_type) => matches!( + (primitive_type, tag), + (_, TypeTag::Any) | (PrimitiveType::Nat | PrimitiveType::Float, TypeTag::Num) + ), + Type::Arrow(_, _) => false, + } + } + + pub fn map_name(&mut self, f: &F) { + match self { + Type::Variable(v) => f(v), + Type::Arrow(lhs, rhs) => { + lhs.map_name(f); + rhs.map_name(f); + } + Type::Primitive(_) => {} + } + } + + pub fn name_used(&self, ident: &str) -> bool { + match self { + Type::Variable(i) if *i == ident => true, + Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident), + _ => false, + } + } + + pub fn type_var(&self) -> Option { + match self { + Type::Variable(v) => Some(v.clone()), + Type::Primitive(_) | Type::Arrow(_, _) => None, + } + } + + pub fn type_vars(&self) -> HashSet { + match self { + Type::Variable(v) => { + let mut set = HashSet::new(); + set.insert(v.to_string()); + set + } + Type::Primitive(_) => HashSet::new(), + Type::Arrow(lhs, rhs) => { + let mut vars = lhs.type_vars(); + for var in rhs.type_vars() { + vars.insert(var); + } + vars + } + } + } + + pub fn specialize(self, ident: &str, typ: &Type) -> Type { + match self { + Type::Variable(i) if i == ident => typ.clone(), + Type::Arrow(lhs, rhs) => Type::Arrow( + Box::new(lhs.specialize(ident, typ)), + Box::new(rhs.specialize(ident, typ)), + ), + t => t, + } + } +} + +impl Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Type::Variable(i) => write!(f, "{i}"), + Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), + Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), + } + } +} + +impl Display for PrimitiveType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PrimitiveType::Nat => write!(f, "Nat"), + PrimitiveType::Bool => write!(f, "Bool"), + PrimitiveType::Float => write!(f, "Float"), + } + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 73a4075..2596db6 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,363 +1,5 @@ -use std::{ - collections::{HashMap, HashSet}, - fmt::Display, - mem, - rc::Rc, - str, -}; +mod concrete; +mod tagged; -pub type Ident = String; - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum PrimitiveType { - Nat, - Bool, - Float, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum TypeTag { - Num, - Any, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum TaggedType { - Tagged(TypeTag, Ident, Box), - Concrete(Type), -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Type { - Variable(Ident), // a - Primitive(PrimitiveType), // Bool | Nat - Arrow(Box, Box), // 0 -> 1 -} - -impl From for TaggedType { - fn from(value: Type) -> Self { - Self::Concrete(value) - } -} - -impl TypeTag { - // If one tag tightens the other return that tag, otherwise none, - pub fn tightens(self, other: Self) -> Option { - match (self, other) { - (TypeTag::Num, TypeTag::Num) => Some(Self::Num), - (TypeTag::Num, TypeTag::Any) => Some(Self::Num), - (TypeTag::Any, TypeTag::Num) => Some(Self::Num), - (TypeTag::Any, TypeTag::Any) => Some(Self::Any), - } - } -} - -impl TaggedType { - pub fn r#type(&self) -> &Type { - match self { - TaggedType::Tagged(_, _, tagged_type) => tagged_type.r#type(), - TaggedType::Concrete(t) => t, - } - } - - pub fn type_mut(&mut self) -> &mut Type { - match self { - TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_mut(), - TaggedType::Concrete(t) => t, - } - } - - pub fn into_concrete(self) -> Option { - match self { - TaggedType::Tagged(type_tag, _, tagged_type) => None, - TaggedType::Concrete(t) => Some(t), - } - } - - pub fn matches(&self, typ: &Type) -> (bool, Option) { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - if typ.has_tag(type_tag) { - (true, Some(ident.clone())) - } else { - tagged_type.matches(typ) - } - } - TaggedType::Concrete(c) => (c == typ, None), - } - } - - pub fn split_arrow(self) -> Option<(TaggedType, TaggedType)> { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - let (lhs, rhs) = tagged_type.split_arrow()?; - Some(( - TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs)) - .clear_unused_names(), - TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(), - )) - } - TaggedType::Concrete(Type::Arrow(lhs, rhs)) => { - Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs))) - } - TaggedType::Concrete(_) => None, - } - } - - pub fn make_arrow(self, rhs: TaggedType) -> Self { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow(rhs))) - } - TaggedType::Concrete(t) => rhs.make_arrow_lhs(t), - } - } - - fn make_arrow_lhs(self, lhs: Type) -> Self { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow_lhs(lhs))) - } - TaggedType::Concrete(t) => Type::arrow(lhs, t).into(), - } - } - - pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType { - match self { - TaggedType::Tagged(_, i, tagged_type) if i == ident => { - tagged_type.specialize(ident, typ) - } - TaggedType::Tagged(t, i, tagged_type) => { - TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ))) - } - TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)), - } - } - - pub fn map_name(&mut self, f: F) { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - f(ident); - tagged_type.map_name(f); - } - TaggedType::Concrete(t) => t.map_name(&f), - } - } - - pub fn type_var(&self) -> Option { - match self { - TaggedType::Tagged(type_tag, _, tagged_type) => tagged_type.type_var(), - TaggedType::Concrete(c) => c.type_var(), - } - } - - pub fn free_vars(&self) -> HashSet { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - let mut vars = tagged_type.free_vars(); - vars.retain(|v| v != ident); - vars - } - TaggedType::Concrete(c) => c.type_vars(), - } - } - - pub fn name_used(&self, ident: &str) -> bool { - match self { - TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident), - TaggedType::Concrete(c) => c.name_used(ident), - } - } - - pub fn is_concrete(&self) -> bool { - match self { - TaggedType::Tagged(type_tag, _, tagged_type) => tagged_type.is_concrete(), - TaggedType::Concrete(c) => c.is_concrete(), - } - } - - pub fn normalise(self) -> TaggedType { - let only_used = self.clear_unused_names(); - fn dedup(this: TaggedType, mut used: Rc>) -> TaggedType { - match this { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - if !used.contains_key(&ident) { - Rc::make_mut(&mut used).insert(ident, type_tag); - dedup(*tagged_type, used) - } else { - let Some(tag) = TypeTag::tightens( - Rc::make_mut(&mut used).remove(&ident).unwrap(), - type_tag, - ) else { - todo!() - }; - Rc::make_mut(&mut used).insert(ident.clone(), tag.clone()); - dedup(*tagged_type, used) - } - } - TaggedType::Concrete(c) => mem::take(Rc::make_mut(&mut used)) - .into_iter() - .fold(TaggedType::Concrete(c), |acc, (i, t)| { - TaggedType::Tagged(t, i, Box::new(acc)) - }), - } - } - - dedup(only_used, Rc::new(HashMap::new())) - } - - fn clear_unused_names(self) -> TaggedType { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - if tagged_type.name_used(&ident) { - TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.clear_unused_names())) - } else { - tagged_type.clear_unused_names() - } - } - t => t, - } - } -} - -impl From for Type { - fn from(value: PrimitiveType) -> Self { - Self::Primitive(value) - } -} - -impl From<&str> for Type { - fn from(value: &str) -> Self { - Self::Variable(value.to_string()) - } -} - -impl Type { - pub fn arrow, T2: Into>(t1: T1, t2: T2) -> Self { - Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) - } - - pub fn split_arrow(self) -> Option<(Type, Type)> { - match self { - Type::Arrow(lhs, rhs) => Some((*lhs, *rhs)), - _ => None, - } - } - - pub fn is_concrete(&self) -> bool { - match self { - Type::Variable(_) => false, - Type::Primitive(primitive_type) => true, - Type::Arrow(t1, t2) => t1.is_concrete() && t2.is_concrete(), - } - } - - pub fn has_tag(&self, tag: &TypeTag) -> bool { - match self { - Type::Variable(_) => true, - Type::Primitive(primitive_type) => matches!( - (primitive_type, tag), - (_, TypeTag::Any) - | (PrimitiveType::Nat, TypeTag::Num) - | (PrimitiveType::Float, TypeTag::Num) - ), - Type::Arrow(_, _) => false, - } - } - - pub fn map_name(&mut self, f: &F) { - match self { - Type::Variable(v) => f(v), - Type::Arrow(lhs, rhs) => { - lhs.map_name(f); - rhs.map_name(f); - } - _ => {} - } - } - - pub fn name_used(&self, ident: &str) -> bool { - match self { - Type::Variable(i) if *i == ident => true, - Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident), - _ => false, - } - } - - pub fn type_var(&self) -> Option { - match self { - Type::Variable(v) => Some(v.clone()), - Type::Primitive(primitive_type) => None, - Type::Arrow(_, _) => None, - } - } - - pub fn type_vars(&self) -> HashSet { - match self { - Type::Variable(v) => { - let mut set = HashSet::new(); - set.insert(v.to_string()); - set - } - Type::Primitive(primitive_type) => HashSet::new(), - Type::Arrow(lhs, rhs) => { - let mut vars = lhs.type_vars(); - for var in rhs.type_vars().into_iter() { - vars.insert(var); - } - vars - } - } - } - - fn specialize(self, ident: &str, typ: &Type) -> Type { - match self { - Type::Variable(i) if i == ident => typ.clone(), - Type::Arrow(lhs, rhs) => Type::Arrow( - Box::new(lhs.specialize(ident, typ)), - Box::new(rhs.specialize(ident, typ)), - ), - t => t, - } - } -} - -impl Display for TypeTag { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TypeTag::Num => write!(f, "Num"), - TypeTag::Any => write!(f, "Any"), - } - } -} - -impl Display for TaggedType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TaggedType::Tagged(type_tag, ident, tagged_type) => { - write!(f, "{type_tag} {ident} => {tagged_type}") - } - TaggedType::Concrete(t) => write!(f, "{t}"), - } - } -} - -impl Display for Type { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Type::Variable(i) => write!(f, "{i}"), - Type::Primitive(primitive_type) => write!(f, "{primitive_type}"), - Type::Arrow(t1, t2) => write!(f, "({t1} -> {t2})"), - } - } -} - -impl Display for PrimitiveType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - PrimitiveType::Nat => write!(f, "Nat"), - PrimitiveType::Bool => write!(f, "Bool"), - PrimitiveType::Float => write!(f, "Float"), - } - } -} +pub use concrete::{PrimitiveType, Type}; +pub use tagged::{TaggedType, TypeTag}; diff --git a/src/types/tagged.rs b/src/types/tagged.rs new file mode 100644 index 0000000..505f5ae --- /dev/null +++ b/src/types/tagged.rs @@ -0,0 +1,180 @@ +use std::{collections::HashSet, fmt::Display}; + +use crate::parse::Ident; + +use super::concrete::Type; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum TypeTag { + Num, + Any, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum TaggedType { + Tagged(TypeTag, Ident, Box), + Concrete(Type), +} + +impl From for TaggedType { + fn from(value: Type) -> Self { + Self::Concrete(value) + } +} + +impl TypeTag { + // If one tag tightens the other return that tag, otherwise none, + pub fn tightens(self, other: Self) -> Option { + #[allow(unreachable_patterns)] // #[non_exhaustive] doesn't apply with tuples + match (self, other) { + (TypeTag::Any, t) | (t, TypeTag::Any) => Some(t), + (TypeTag::Num, TypeTag::Num) => Some(Self::Num), + _ => None, + } + } +} + +impl TaggedType { + pub fn into_concrete(self) -> Option { + match self { + TaggedType::Tagged(_, _, _) => None, + TaggedType::Concrete(t) => Some(t), + } + } + + pub fn matches(&self, typ: &Type) -> (bool, Option) { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + if typ.has_tag(type_tag) { + (true, Some(ident.clone())) + } else { + tagged_type.matches(typ) + } + } + TaggedType::Concrete(c) => (c == typ, None), + } + } + + pub fn split_arrow(self) -> Option<(TaggedType, TaggedType)> { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + let (lhs, rhs) = tagged_type.split_arrow()?; + Some(( + TaggedType::Tagged(type_tag.clone(), ident.clone(), Box::new(lhs)) + .clear_unused_names(), + TaggedType::Tagged(type_tag, ident, Box::new(rhs)).clear_unused_names(), + )) + } + TaggedType::Concrete(Type::Arrow(lhs, rhs)) => { + Some((TaggedType::Concrete(*lhs), TaggedType::Concrete(*rhs))) + } + TaggedType::Concrete(_) => None, + } + } + + pub fn make_arrow(self, rhs: TaggedType) -> Self { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow(rhs))) + } + TaggedType::Concrete(t) => rhs.make_arrow_lhs(t), + } + } + + fn make_arrow_lhs(self, lhs: Type) -> Self { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.make_arrow_lhs(lhs))) + } + TaggedType::Concrete(t) => Type::arrow(lhs, t).into(), + } + } + + pub fn specialize(self, ident: &str, typ: &Type) -> TaggedType { + match self { + TaggedType::Tagged(_, i, tagged_type) if i == ident => { + tagged_type.specialize(ident, typ) + } + TaggedType::Tagged(t, i, tagged_type) => { + TaggedType::Tagged(t, i, Box::new(tagged_type.specialize(ident, typ))) + } + TaggedType::Concrete(c) => TaggedType::Concrete(c.specialize(ident, typ)), + } + } + + pub fn map_name(&mut self, f: F) { + match self { + TaggedType::Tagged(_, ident, tagged_type) => { + f(ident); + tagged_type.map_name(f); + } + TaggedType::Concrete(t) => t.map_name(&f), + } + } + + pub fn type_var(&self) -> Option { + match self { + TaggedType::Tagged(_, _, tagged_type) => tagged_type.type_var(), + TaggedType::Concrete(c) => c.type_var(), + } + } + + pub fn free_vars(&self) -> HashSet { + match self { + TaggedType::Tagged(_, ident, tagged_type) => { + let mut vars = tagged_type.free_vars(); + vars.retain(|v| v != ident); + vars + } + TaggedType::Concrete(c) => c.type_vars(), + } + } + + pub fn name_used(&self, ident: &str) -> bool { + match self { + TaggedType::Tagged(_, _, tagged_type) => tagged_type.name_used(ident), + TaggedType::Concrete(c) => c.name_used(ident), + } + } + + pub fn is_concrete(&self) -> bool { + match self { + TaggedType::Tagged(_, _, tagged_type) => tagged_type.is_concrete(), + TaggedType::Concrete(c) => c.is_concrete(), + } + } + + fn clear_unused_names(self) -> TaggedType { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + if tagged_type.name_used(&ident) { + TaggedType::Tagged(type_tag, ident, Box::new(tagged_type.clear_unused_names())) + } else { + tagged_type.clear_unused_names() + } + } + t @ TaggedType::Concrete(_) => t, + } + } +} + +impl Display for TypeTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeTag::Num => write!(f, "Num"), + TypeTag::Any => write!(f, "Any"), + } + } +} + +impl Display for TaggedType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TaggedType::Tagged(type_tag, ident, tagged_type) => { + write!(f, "{type_tag} {ident} => {tagged_type}") + } + TaggedType::Concrete(t) => write!(f, "{t}"), + } + } +} diff --git a/src/vec_map.rs b/src/vec_map.rs index 646b24c..8343cd3 100644 --- a/src/vec_map.rs +++ b/src/vec_map.rs @@ -19,13 +19,11 @@ impl VecMap { } pub fn get(&self, key: &K) -> Option<&V> { - self.map - .iter() - .find_map(|(k, v)| if k == key { Some(v) } else { None }) + self.map.iter().find_map(|(k, v)| (k == key).then_some(v)) } pub fn map_keys(&mut self, f: F) { - self.map.iter_mut().for_each(|(k, _)| f(k)) + self.map.iter_mut().for_each(|(k, _)| f(k)); } }