diff --git a/src/exec/builtin_definitions.rs b/src/exec/builtin_definitions.rs new file mode 100644 index 0000000..2188661 --- /dev/null +++ b/src/exec/builtin_definitions.rs @@ -0,0 +1,352 @@ +use std::rc::Rc; + +use crate::{ + parse::Constant, + types::{PrimitiveType, TaggedType, Type, TypeTag}, +}; + +use super::{ + DeBrujinAst, + builtins::{Builtin, DeBrujinBuiltInAst}, +}; + +pub struct AddOp; +struct AddOpNat(usize); +struct AddOpFloat(f64); + +impl Builtin for AddOp { + fn name(&self) -> String { + "add".to_string() + } + + fn r#type(&self) -> TaggedType { + TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + "a", + Type::arrow("a", "a"), + ))), + ) + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(AddOpNat(n)))) + } + DeBrujinBuiltInAst::Constant(Constant::Float(n)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(AddOpFloat(n)))) + } + _ => None, + } + } +} + +impl Builtin for AddOpNat { + fn name(&self) -> String { + format!("add{}", self.0) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("add".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Nat(self.0))), + ) + } + + fn r#type(&self) -> TaggedType { + Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into() + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => { + Some(DeBrujinBuiltInAst::Constant(Constant::Nat(n + self.0))) + } + _ => None, + } + } +} + +impl Builtin for AddOpFloat { + fn name(&self) -> String { + format!("add{}", self.0) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("add".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Float(self.0))), + ) + } + + fn r#type(&self) -> TaggedType { + Type::arrow(PrimitiveType::Float, PrimitiveType::Float).into() + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Float(n)) => { + Some(DeBrujinBuiltInAst::Constant(Constant::Float(n + self.0))) + } + _ => None, + } + } +} + +pub struct SubOp; +struct SubOpNat(usize); +struct SubOpFloat(f64); + +impl Builtin for SubOp { + fn name(&self) -> String { + "sub".to_string() + } + + fn r#type(&self) -> TaggedType { + TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + "a", + Type::arrow("a", "a"), + ))), + ) + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(SubOpNat(n)))) + } + DeBrujinBuiltInAst::Constant(Constant::Float(n)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(SubOpFloat(n)))) + } + _ => None, + } + } +} + +impl Builtin for SubOpNat { + fn name(&self) -> String { + format!("sub{}", self.0) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("sub".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Nat(self.0))), + ) + } + + fn r#type(&self) -> TaggedType { + Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into() + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => { + Some(DeBrujinBuiltInAst::Constant(Constant::Nat(self.0 - n))) + } + _ => None, + } + } +} + +impl Builtin for SubOpFloat { + fn name(&self) -> String { + format!("sub{}", self.0) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("sub".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Float(self.0))), + ) + } + + fn r#type(&self) -> TaggedType { + Type::arrow(PrimitiveType::Float, PrimitiveType::Float).into() + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Float(n)) => { + Some(DeBrujinBuiltInAst::Constant(Constant::Float(self.0 - n))) + } + _ => None, + } + } +} + +pub struct MulOp; +struct MulOpNat(usize); +struct MulOpFloat(f64); + +impl Builtin for MulOp { + fn name(&self) -> String { + "mul".to_string() + } + + fn r#type(&self) -> TaggedType { + TaggedType::Tagged( + TypeTag::Num, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + "a", + Type::arrow("a", "a"), + ))), + ) + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(MulOpNat(n)))) + } + DeBrujinBuiltInAst::Constant(Constant::Float(n)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(MulOpFloat(n)))) + } + _ => None, + } + } +} + +impl Builtin for MulOpNat { + fn name(&self) -> String { + format!("mul{}", self.0) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("mul".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Nat(self.0))), + ) + } + + fn r#type(&self) -> TaggedType { + Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into() + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => { + Some(DeBrujinBuiltInAst::Constant(Constant::Nat(n * self.0))) + } + _ => None, + } + } +} + +impl Builtin for MulOpFloat { + fn name(&self) -> String { + format!("mul{}", self.0) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("mul".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Float(self.0))), + ) + } + + fn r#type(&self) -> TaggedType { + Type::arrow(PrimitiveType::Float, PrimitiveType::Float).into() + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Float(n)) => { + Some(DeBrujinBuiltInAst::Constant(Constant::Float(n * self.0))) + } + _ => None, + } + } +} + +pub struct OpCond; +struct OpCond1(bool); +struct OpCond2(bool, DeBrujinBuiltInAst); + +impl Builtin for OpCond { + fn name(&self) -> String { + "if".to_string() + } + + fn r#type(&self) -> TaggedType { + TaggedType::Tagged( + TypeTag::Any, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + PrimitiveType::Bool, + Type::arrow("a", Type::arrow("a", "a")), + ))), + ) + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + match rhs { + DeBrujinBuiltInAst::Constant(Constant::Bool(b)) => { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(OpCond1(b)))) + } + _ => None, + } + } +} + +impl Builtin for OpCond1 { + fn name(&self) -> String { + format!("if{}1", self.0) + } + + fn r#type(&self) -> TaggedType { + TaggedType::Tagged( + TypeTag::Any, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow( + "a", + Type::arrow("a", "a"), + ))), + ) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("if".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Bool(self.0))), + ) + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + Some(DeBrujinBuiltInAst::Builtin(Rc::new(OpCond2(self.0, rhs)))) + } +} + +impl Builtin for OpCond2 { + fn name(&self) -> String { + format!("if{}2", self.0) + } + + fn r#type(&self) -> TaggedType { + TaggedType::Tagged( + TypeTag::Any, + "a".to_string(), + Box::new(TaggedType::Concrete(Type::arrow("a", "a"))), + ) + } + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::Application( + Box::new(DeBrujinAst::Application( + Box::new(DeBrujinAst::FreeVariable("if".to_string())), + Box::new(DeBrujinAst::Constant(Constant::Bool(self.0))), + )), + Box::new(self.1.clone().into()), + ) + } + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option { + Some(if self.0 { self.1.clone() } else { rhs }) + } +} diff --git a/src/exec/builtins.rs b/src/exec/builtins.rs new file mode 100644 index 0000000..08400e7 --- /dev/null +++ b/src/exec/builtins.rs @@ -0,0 +1,101 @@ +use std::{collections::HashMap, rc::Rc, usize}; + +use crate::{ + parse::Constant, + types::{Ident, PrimitiveType, TaggedType, Type, TypeTag}, +}; + +use super::DeBrujinAst; + +#[derive(Clone)] +pub enum DeBrujinBuiltInAst { + Abstraction(Ident, Type, Box), // \:1.2 + Application(Box, Box), // 0 1 + FreeVariable(String), // x + BoundVariable(usize), // 1 + Constant(Constant), // true | false | n + Builtin(Rc), +} + +impl DeBrujinAst { + pub fn resolve_builtins( + self, + builtins: &HashMap>, + ) -> DeBrujinBuiltInAst { + match self { + DeBrujinAst::Abstraction(i, t, ast) => { + DeBrujinBuiltInAst::Abstraction(i, t, Box::new(ast.resolve_builtins(builtins))) + } + DeBrujinAst::Application(lhs, rhs) => DeBrujinBuiltInAst::Application( + Box::new(lhs.resolve_builtins(builtins)), + Box::new(rhs.resolve_builtins(builtins)), + ), + DeBrujinAst::FreeVariable(x) => { + if let Some(b) = builtins.get(&x) { + DeBrujinBuiltInAst::Builtin(b.clone()) + } else { + DeBrujinBuiltInAst::FreeVariable(x) + } + } + DeBrujinAst::BoundVariable(b) => DeBrujinBuiltInAst::BoundVariable(b), + DeBrujinAst::Constant(c) => DeBrujinBuiltInAst::Constant(c), + } + } +} + +pub trait Builtin { + fn name(&self) -> String; + + fn to_ast(&self) -> DeBrujinAst { + DeBrujinAst::FreeVariable(self.name()) + } + + fn r#type(&self) -> TaggedType; + + fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option; +} + +impl DeBrujinAst { + pub fn reduce_builtins(self, builtins: &HashMap>) -> DeBrujinAst { + self.resolve_builtins(builtins).reduce_builtins().into() + } +} + +impl DeBrujinBuiltInAst { + fn reduce_builtins(self) -> DeBrujinBuiltInAst { + match self { + DeBrujinBuiltInAst::Abstraction(i, t, ast) => { + DeBrujinBuiltInAst::Abstraction(i, t, Box::new(ast.reduce_builtins())) + } + DeBrujinBuiltInAst::Application(lhs, rhs) => match *lhs { + DeBrujinBuiltInAst::Builtin(builtin) => builtin + .apply(*rhs) + .expect("the type checker should make sure we can apply") + .reduce_builtins(), + lhs => DeBrujinBuiltInAst::Application( + Box::new(lhs.reduce_builtins()), + Box::new(rhs.reduce_builtins()), + ) + .reduce_builtins(), + }, + a => a, + } + } +} + +impl Into for DeBrujinBuiltInAst { + fn into(self) -> DeBrujinAst { + match self { + DeBrujinBuiltInAst::Abstraction(i, t, ast) => { + DeBrujinAst::Abstraction(i, t, Box::new((*ast).into())) + } + DeBrujinBuiltInAst::Application(lhs, rhs) => { + DeBrujinAst::Application(Box::new((*lhs).into()), Box::new((*rhs).into())) + } + DeBrujinBuiltInAst::FreeVariable(x) => DeBrujinAst::FreeVariable(x), + DeBrujinBuiltInAst::BoundVariable(i) => DeBrujinAst::BoundVariable(i), + DeBrujinBuiltInAst::Constant(constant) => DeBrujinAst::Constant(constant), + DeBrujinBuiltInAst::Builtin(builtin) => DeBrujinAst::FreeVariable(builtin.name()), + } + } +} diff --git a/src/exec/mod.rs b/src/exec/mod.rs index 2e30ec8..b7d6c9a 100644 --- a/src/exec/mod.rs +++ b/src/exec/mod.rs @@ -1,6 +1,11 @@ +pub mod builtin_definitions; +mod builtins; #[cfg(test)] mod test; +pub use builtin_definitions as builtin; +pub use builtins::Builtin; + use std::{collections::HashMap, rc::Rc}; use crate::{ @@ -41,7 +46,7 @@ impl Ast { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] pub enum DeBrujinAst { Abstraction(Ident, Type, Box), // \:1.2 Application(Box, Box), // 0 1 diff --git a/src/exec/test.rs b/src/exec/test.rs index f638bbf..f11a310 100644 --- a/src/exec/test.rs +++ b/src/exec/test.rs @@ -1,9 +1,13 @@ +use std::{collections::HashMap, rc::Rc}; + use crate::{ PrimitiveType, Type, - exec::DeBrujinAst as DBAst, + exec::{DeBrujinAst as DBAst, builtin_definitions::AddOp}, parse::{Ast, Constant}, }; +use super::builtins::Builtin; + #[test] fn to_de_brujin_ast_simple() { let input = Ast::Abstraction( @@ -71,3 +75,19 @@ fn to_and_from_de_brujin_is_id() { let output: Ast = dbast.into(); assert_eq!(input, output); } + +#[test] +fn reduce_add() { + let input: DBAst = Ast::Application( + Box::new(Ast::Application( + Box::new(Ast::Variable("add".to_string())), + Box::new(Ast::Constant(Constant::Nat(5))), + )), + Box::new(Ast::Constant(Constant::Nat(5))), + ) + .into(); + let mut builtins: HashMap> = HashMap::new(); + builtins.insert("add".to_string(), Rc::new(AddOp)); + let output = input.reduce_builtins(&builtins); + assert_eq!(output, DBAst::Constant(Constant::Nat(10))); +} diff --git a/src/lib.rs b/src/lib.rs index 3153f3d..f1dc7e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,9 @@ mod parse; mod types; mod vec_map; +pub use exec::Builtin; pub use exec::DeBrujinAst; +pub use exec::builtin; pub use inference::infer_type; pub use parse::{Ast, ParseError, is_ident, parse, parse_type, sexpr::parse_string}; use types::{Ident, PrimitiveType, Type}; diff --git a/src/main.rs b/src/main.rs index a7cfaa3..eb3ddb1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use std::{ }; use stlc_type_inference::{ - Ast, DeBrujinAst, infer_type, is_ident, parse, parse_string, parse_type, + Ast, Builtin, DeBrujinAst, builtin, infer_type, is_ident, parse, parse_string, parse_type, }; macro_rules! repl_err { @@ -18,7 +18,15 @@ macro_rules! repl_err { } fn main() { + let mut builtins: HashMap> = HashMap::new(); + builtins.insert("add".to_string(), Rc::new(builtin::AddOp)); + builtins.insert("sub".to_string(), Rc::new(builtin::SubOp)); + builtins.insert("mul".to_string(), Rc::new(builtin::MulOp)); + builtins.insert("if".to_string(), Rc::new(builtin::OpCond)); let mut gamma = HashMap::new(); + for (k, v) in &builtins { + gamma.insert(k.clone(), v.r#type()); + } print!("> "); stdout().flush().unwrap(); for line in stdin().lines() { @@ -50,8 +58,12 @@ fn main() { Err(e) => repl_err!("type could not be parsed {e:?}"), }; - println!("Added {ident} with type {typ} to the context"); - gamma.insert(ident, typ.into()); + if !gamma.contains_key(&ident) { + println!("Added {ident} with type {typ} to the context"); + gamma.insert(ident, typ.into()); + } else { + println!("Cannot override existing ctx"); + } } } c => println!("Unknown command {c}"), @@ -75,6 +87,7 @@ fn main() { Err(e) => repl_err!("Could not infer type {e:?}"), }; let ast = ast.beta_reduce(); + let ast = ast.reduce_builtins(&builtins); let ast: Ast = ast.into(); println!("{ast} : {typ}") } diff --git a/src/parse/mod.rs b/src/parse/mod.rs index f29703b..6141b0b 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -29,13 +29,14 @@ pub enum ParseError { ExpectedOneOf(Vec, String), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] pub enum Constant { Nat(usize), + Float(f64), Bool(bool), } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq)] pub enum Ast { Abstraction(Ident, Type, Box), // \0:1.2 Application(Box, Box), // 0 1 @@ -59,6 +60,7 @@ impl Display for Constant { match self { Constant::Nat(n) => write!(f, "{n}"), Constant::Bool(b) => write!(f, "{b}"), + Constant::Float(fl) => write!(f, "{fl}"), } } } @@ -172,6 +174,8 @@ fn parse_intern(ast: Sexpr) -> Result { fn parse_symbol(s: String) -> Result { if let Ok(n) = s.parse::() { Ok(Ast::Constant(Constant::Nat(n))) + } else if let Ok(f) = s.parse::() { + Ok(Ast::Constant(Constant::Float(f))) } else if let Ok(b) = s.parse::() { Ok(Ast::Constant(Constant::Bool(b))) } else if is_ident(&s) { diff --git a/src/types/mod.rs b/src/types/mod.rs index a274778..a4aaafc 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -6,11 +6,13 @@ pub type Ident = String; pub enum PrimitiveType { Nat, Bool, + Float, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeTag { Num, + Any, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -129,6 +131,12 @@ impl From for Type { } } +impl From<&str> for Type { + fn from(value: &str) -> Self { + Self::Generic(value.to_string()) + } +} + impl Type { pub fn arrow, T2: Into>(t1: T1, t2: T2) -> Self { Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) @@ -146,6 +154,7 @@ impl Type { match self { Type::Generic(_) => false, Type::Primitive(primitive_type) => match (primitive_type, tag) { + (_, TypeTag::Any) => true, (PrimitiveType::Nat, TypeTag::Num) => true, _ => false, }, @@ -156,6 +165,7 @@ impl Type { fn name_used(&self, ident: &str) -> bool { match self { Type::Generic(i) if *i == ident => true, + Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident), _ => false, } } @@ -176,6 +186,7 @@ impl Display for TypeTag { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TypeTag::Num => write!(f, "Num"), + TypeTag::Any => write!(f, "Any"), } } } @@ -206,6 +217,7 @@ impl Display for PrimitiveType { match self { PrimitiveType::Nat => write!(f, "Nat"), PrimitiveType::Bool => write!(f, "Bool"), + PrimitiveType::Float => write!(f, "Float"), } } }