main
Avery 1 month ago
parent 6a46c6ca52
commit 32f8f7c29a
Signed by: Avery
GPG Key ID: 4E53F4CB69B2CC8D

@ -0,0 +1,352 @@
use std::rc::Rc;
use crate::{
parse::Constant,
types::{PrimitiveType, TaggedType, Type, TypeTag},
};
use super::{
DeBrujinAst,
builtins::{Builtin, DeBrujinBuiltInAst},
};
pub struct AddOp;
struct AddOpNat(usize);
struct AddOpFloat(f64);
impl Builtin for AddOp {
fn name(&self) -> String {
"add".to_string()
}
fn r#type(&self) -> TaggedType {
TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow(
"a",
Type::arrow("a", "a"),
))),
)
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(AddOpNat(n))))
}
DeBrujinBuiltInAst::Constant(Constant::Float(n)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(AddOpFloat(n))))
}
_ => None,
}
}
}
impl Builtin for AddOpNat {
fn name(&self) -> String {
format!("add{}", self.0)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("add".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Nat(self.0))),
)
}
fn r#type(&self) -> TaggedType {
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => {
Some(DeBrujinBuiltInAst::Constant(Constant::Nat(n + self.0)))
}
_ => None,
}
}
}
impl Builtin for AddOpFloat {
fn name(&self) -> String {
format!("add{}", self.0)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("add".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Float(self.0))),
)
}
fn r#type(&self) -> TaggedType {
Type::arrow(PrimitiveType::Float, PrimitiveType::Float).into()
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Float(n)) => {
Some(DeBrujinBuiltInAst::Constant(Constant::Float(n + self.0)))
}
_ => None,
}
}
}
pub struct SubOp;
struct SubOpNat(usize);
struct SubOpFloat(f64);
impl Builtin for SubOp {
fn name(&self) -> String {
"sub".to_string()
}
fn r#type(&self) -> TaggedType {
TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow(
"a",
Type::arrow("a", "a"),
))),
)
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(SubOpNat(n))))
}
DeBrujinBuiltInAst::Constant(Constant::Float(n)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(SubOpFloat(n))))
}
_ => None,
}
}
}
impl Builtin for SubOpNat {
fn name(&self) -> String {
format!("sub{}", self.0)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("sub".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Nat(self.0))),
)
}
fn r#type(&self) -> TaggedType {
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => {
Some(DeBrujinBuiltInAst::Constant(Constant::Nat(self.0 - n)))
}
_ => None,
}
}
}
impl Builtin for SubOpFloat {
fn name(&self) -> String {
format!("sub{}", self.0)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("sub".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Float(self.0))),
)
}
fn r#type(&self) -> TaggedType {
Type::arrow(PrimitiveType::Float, PrimitiveType::Float).into()
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Float(n)) => {
Some(DeBrujinBuiltInAst::Constant(Constant::Float(self.0 - n)))
}
_ => None,
}
}
}
pub struct MulOp;
struct MulOpNat(usize);
struct MulOpFloat(f64);
impl Builtin for MulOp {
fn name(&self) -> String {
"mul".to_string()
}
fn r#type(&self) -> TaggedType {
TaggedType::Tagged(
TypeTag::Num,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow(
"a",
Type::arrow("a", "a"),
))),
)
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(MulOpNat(n))))
}
DeBrujinBuiltInAst::Constant(Constant::Float(n)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(MulOpFloat(n))))
}
_ => None,
}
}
}
impl Builtin for MulOpNat {
fn name(&self) -> String {
format!("mul{}", self.0)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("mul".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Nat(self.0))),
)
}
fn r#type(&self) -> TaggedType {
Type::arrow(PrimitiveType::Nat, PrimitiveType::Nat).into()
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Nat(n)) => {
Some(DeBrujinBuiltInAst::Constant(Constant::Nat(n * self.0)))
}
_ => None,
}
}
}
impl Builtin for MulOpFloat {
fn name(&self) -> String {
format!("mul{}", self.0)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("mul".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Float(self.0))),
)
}
fn r#type(&self) -> TaggedType {
Type::arrow(PrimitiveType::Float, PrimitiveType::Float).into()
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Float(n)) => {
Some(DeBrujinBuiltInAst::Constant(Constant::Float(n * self.0)))
}
_ => None,
}
}
}
pub struct OpCond;
struct OpCond1(bool);
struct OpCond2(bool, DeBrujinBuiltInAst);
impl Builtin for OpCond {
fn name(&self) -> String {
"if".to_string()
}
fn r#type(&self) -> TaggedType {
TaggedType::Tagged(
TypeTag::Any,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow(
PrimitiveType::Bool,
Type::arrow("a", Type::arrow("a", "a")),
))),
)
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
match rhs {
DeBrujinBuiltInAst::Constant(Constant::Bool(b)) => {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(OpCond1(b))))
}
_ => None,
}
}
}
impl Builtin for OpCond1 {
fn name(&self) -> String {
format!("if{}1", self.0)
}
fn r#type(&self) -> TaggedType {
TaggedType::Tagged(
TypeTag::Any,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow(
"a",
Type::arrow("a", "a"),
))),
)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("if".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Bool(self.0))),
)
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
Some(DeBrujinBuiltInAst::Builtin(Rc::new(OpCond2(self.0, rhs))))
}
}
impl Builtin for OpCond2 {
fn name(&self) -> String {
format!("if{}2", self.0)
}
fn r#type(&self) -> TaggedType {
TaggedType::Tagged(
TypeTag::Any,
"a".to_string(),
Box::new(TaggedType::Concrete(Type::arrow("a", "a"))),
)
}
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::Application(
Box::new(DeBrujinAst::Application(
Box::new(DeBrujinAst::FreeVariable("if".to_string())),
Box::new(DeBrujinAst::Constant(Constant::Bool(self.0))),
)),
Box::new(self.1.clone().into()),
)
}
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst> {
Some(if self.0 { self.1.clone() } else { rhs })
}
}

@ -0,0 +1,101 @@
use std::{collections::HashMap, rc::Rc, usize};
use crate::{
parse::Constant,
types::{Ident, PrimitiveType, TaggedType, Type, TypeTag},
};
use super::DeBrujinAst;
#[derive(Clone)]
pub enum DeBrujinBuiltInAst {
Abstraction(Ident, Type, Box<DeBrujinBuiltInAst>), // \:1.2
Application(Box<DeBrujinBuiltInAst>, Box<DeBrujinBuiltInAst>), // 0 1
FreeVariable(String), // x
BoundVariable(usize), // 1
Constant(Constant), // true | false | n
Builtin(Rc<dyn Builtin>),
}
impl DeBrujinAst {
pub fn resolve_builtins(
self,
builtins: &HashMap<String, Rc<dyn Builtin>>,
) -> DeBrujinBuiltInAst {
match self {
DeBrujinAst::Abstraction(i, t, ast) => {
DeBrujinBuiltInAst::Abstraction(i, t, Box::new(ast.resolve_builtins(builtins)))
}
DeBrujinAst::Application(lhs, rhs) => DeBrujinBuiltInAst::Application(
Box::new(lhs.resolve_builtins(builtins)),
Box::new(rhs.resolve_builtins(builtins)),
),
DeBrujinAst::FreeVariable(x) => {
if let Some(b) = builtins.get(&x) {
DeBrujinBuiltInAst::Builtin(b.clone())
} else {
DeBrujinBuiltInAst::FreeVariable(x)
}
}
DeBrujinAst::BoundVariable(b) => DeBrujinBuiltInAst::BoundVariable(b),
DeBrujinAst::Constant(c) => DeBrujinBuiltInAst::Constant(c),
}
}
}
pub trait Builtin {
fn name(&self) -> String;
fn to_ast(&self) -> DeBrujinAst {
DeBrujinAst::FreeVariable(self.name())
}
fn r#type(&self) -> TaggedType;
fn apply(&self, rhs: DeBrujinBuiltInAst) -> Option<DeBrujinBuiltInAst>;
}
impl DeBrujinAst {
pub fn reduce_builtins(self, builtins: &HashMap<String, Rc<dyn Builtin>>) -> DeBrujinAst {
self.resolve_builtins(builtins).reduce_builtins().into()
}
}
impl DeBrujinBuiltInAst {
fn reduce_builtins(self) -> DeBrujinBuiltInAst {
match self {
DeBrujinBuiltInAst::Abstraction(i, t, ast) => {
DeBrujinBuiltInAst::Abstraction(i, t, Box::new(ast.reduce_builtins()))
}
DeBrujinBuiltInAst::Application(lhs, rhs) => match *lhs {
DeBrujinBuiltInAst::Builtin(builtin) => builtin
.apply(*rhs)
.expect("the type checker should make sure we can apply")
.reduce_builtins(),
lhs => DeBrujinBuiltInAst::Application(
Box::new(lhs.reduce_builtins()),
Box::new(rhs.reduce_builtins()),
)
.reduce_builtins(),
},
a => a,
}
}
}
impl Into<DeBrujinAst> for DeBrujinBuiltInAst {
fn into(self) -> DeBrujinAst {
match self {
DeBrujinBuiltInAst::Abstraction(i, t, ast) => {
DeBrujinAst::Abstraction(i, t, Box::new((*ast).into()))
}
DeBrujinBuiltInAst::Application(lhs, rhs) => {
DeBrujinAst::Application(Box::new((*lhs).into()), Box::new((*rhs).into()))
}
DeBrujinBuiltInAst::FreeVariable(x) => DeBrujinAst::FreeVariable(x),
DeBrujinBuiltInAst::BoundVariable(i) => DeBrujinAst::BoundVariable(i),
DeBrujinBuiltInAst::Constant(constant) => DeBrujinAst::Constant(constant),
DeBrujinBuiltInAst::Builtin(builtin) => DeBrujinAst::FreeVariable(builtin.name()),
}
}
}

@ -1,6 +1,11 @@
pub mod builtin_definitions;
mod builtins;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub use builtin_definitions as builtin;
pub use builtins::Builtin;
use std::{collections::HashMap, rc::Rc}; use std::{collections::HashMap, rc::Rc};
use crate::{ use crate::{
@ -41,7 +46,7 @@ impl Ast {
} }
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq)]
pub enum DeBrujinAst { pub enum DeBrujinAst {
Abstraction(Ident, Type, Box<DeBrujinAst>), // \:1.2 Abstraction(Ident, Type, Box<DeBrujinAst>), // \:1.2
Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1 Application(Box<DeBrujinAst>, Box<DeBrujinAst>), // 0 1

@ -1,9 +1,13 @@
use std::{collections::HashMap, rc::Rc};
use crate::{ use crate::{
PrimitiveType, Type, PrimitiveType, Type,
exec::DeBrujinAst as DBAst, exec::{DeBrujinAst as DBAst, builtin_definitions::AddOp},
parse::{Ast, Constant}, parse::{Ast, Constant},
}; };
use super::builtins::Builtin;
#[test] #[test]
fn to_de_brujin_ast_simple() { fn to_de_brujin_ast_simple() {
let input = Ast::Abstraction( let input = Ast::Abstraction(
@ -71,3 +75,19 @@ fn to_and_from_de_brujin_is_id() {
let output: Ast = dbast.into(); let output: Ast = dbast.into();
assert_eq!(input, output); assert_eq!(input, output);
} }
#[test]
fn reduce_add() {
let input: DBAst = Ast::Application(
Box::new(Ast::Application(
Box::new(Ast::Variable("add".to_string())),
Box::new(Ast::Constant(Constant::Nat(5))),
)),
Box::new(Ast::Constant(Constant::Nat(5))),
)
.into();
let mut builtins: HashMap<String, Rc<dyn Builtin>> = HashMap::new();
builtins.insert("add".to_string(), Rc::new(AddOp));
let output = input.reduce_builtins(&builtins);
assert_eq!(output, DBAst::Constant(Constant::Nat(10)));
}

@ -8,7 +8,9 @@ mod parse;
mod types; mod types;
mod vec_map; mod vec_map;
pub use exec::Builtin;
pub use exec::DeBrujinAst; pub use exec::DeBrujinAst;
pub use exec::builtin;
pub use inference::infer_type; pub use inference::infer_type;
pub use parse::{Ast, ParseError, is_ident, parse, parse_type, sexpr::parse_string}; pub use parse::{Ast, ParseError, is_ident, parse, parse_type, sexpr::parse_string};
use types::{Ident, PrimitiveType, Type}; use types::{Ident, PrimitiveType, Type};

@ -5,7 +5,7 @@ use std::{
}; };
use stlc_type_inference::{ use stlc_type_inference::{
Ast, DeBrujinAst, infer_type, is_ident, parse, parse_string, parse_type, Ast, Builtin, DeBrujinAst, builtin, infer_type, is_ident, parse, parse_string, parse_type,
}; };
macro_rules! repl_err { macro_rules! repl_err {
@ -18,7 +18,15 @@ macro_rules! repl_err {
} }
fn main() { fn main() {
let mut builtins: HashMap<String, Rc<dyn Builtin>> = HashMap::new();
builtins.insert("add".to_string(), Rc::new(builtin::AddOp));
builtins.insert("sub".to_string(), Rc::new(builtin::SubOp));
builtins.insert("mul".to_string(), Rc::new(builtin::MulOp));
builtins.insert("if".to_string(), Rc::new(builtin::OpCond));
let mut gamma = HashMap::new(); let mut gamma = HashMap::new();
for (k, v) in &builtins {
gamma.insert(k.clone(), v.r#type());
}
print!("> "); print!("> ");
stdout().flush().unwrap(); stdout().flush().unwrap();
for line in stdin().lines() { for line in stdin().lines() {
@ -50,8 +58,12 @@ fn main() {
Err(e) => repl_err!("type could not be parsed {e:?}"), Err(e) => repl_err!("type could not be parsed {e:?}"),
}; };
if !gamma.contains_key(&ident) {
println!("Added {ident} with type {typ} to the context"); println!("Added {ident} with type {typ} to the context");
gamma.insert(ident, typ.into()); gamma.insert(ident, typ.into());
} else {
println!("Cannot override existing ctx");
}
} }
} }
c => println!("Unknown command {c}"), c => println!("Unknown command {c}"),
@ -75,6 +87,7 @@ fn main() {
Err(e) => repl_err!("Could not infer type {e:?}"), Err(e) => repl_err!("Could not infer type {e:?}"),
}; };
let ast = ast.beta_reduce(); let ast = ast.beta_reduce();
let ast = ast.reduce_builtins(&builtins);
let ast: Ast = ast.into(); let ast: Ast = ast.into();
println!("{ast} : {typ}") println!("{ast} : {typ}")
} }

@ -29,13 +29,14 @@ pub enum ParseError {
ExpectedOneOf(Vec<String>, String), ExpectedOneOf(Vec<String>, String),
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq)]
pub enum Constant { pub enum Constant {
Nat(usize), Nat(usize),
Float(f64),
Bool(bool), Bool(bool),
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq)]
pub enum Ast { pub enum Ast {
Abstraction(Ident, Type, Box<Ast>), // \0:1.2 Abstraction(Ident, Type, Box<Ast>), // \0:1.2
Application(Box<Ast>, Box<Ast>), // 0 1 Application(Box<Ast>, Box<Ast>), // 0 1
@ -59,6 +60,7 @@ impl Display for Constant {
match self { match self {
Constant::Nat(n) => write!(f, "{n}"), Constant::Nat(n) => write!(f, "{n}"),
Constant::Bool(b) => write!(f, "{b}"), Constant::Bool(b) => write!(f, "{b}"),
Constant::Float(fl) => write!(f, "{fl}"),
} }
} }
} }
@ -172,6 +174,8 @@ fn parse_intern(ast: Sexpr) -> Result<Ast, ParseError> {
fn parse_symbol(s: String) -> Result<Ast, ParseError> { fn parse_symbol(s: String) -> Result<Ast, ParseError> {
if let Ok(n) = s.parse::<usize>() { if let Ok(n) = s.parse::<usize>() {
Ok(Ast::Constant(Constant::Nat(n))) Ok(Ast::Constant(Constant::Nat(n)))
} else if let Ok(f) = s.parse::<f64>() {
Ok(Ast::Constant(Constant::Float(f)))
} else if let Ok(b) = s.parse::<bool>() { } else if let Ok(b) = s.parse::<bool>() {
Ok(Ast::Constant(Constant::Bool(b))) Ok(Ast::Constant(Constant::Bool(b)))
} else if is_ident(&s) { } else if is_ident(&s) {

@ -6,11 +6,13 @@ pub type Ident = String;
pub enum PrimitiveType { pub enum PrimitiveType {
Nat, Nat,
Bool, Bool,
Float,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeTag { pub enum TypeTag {
Num, Num,
Any,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
@ -129,6 +131,12 @@ impl From<PrimitiveType> for Type {
} }
} }
impl From<&str> for Type {
fn from(value: &str) -> Self {
Self::Generic(value.to_string())
}
}
impl Type { impl Type {
pub fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self { pub fn arrow<T1: Into<Type>, T2: Into<Type>>(t1: T1, t2: T2) -> Self {
Self::Arrow(Box::new(t1.into()), Box::new(t2.into())) Self::Arrow(Box::new(t1.into()), Box::new(t2.into()))
@ -146,6 +154,7 @@ impl Type {
match self { match self {
Type::Generic(_) => false, Type::Generic(_) => false,
Type::Primitive(primitive_type) => match (primitive_type, tag) { Type::Primitive(primitive_type) => match (primitive_type, tag) {
(_, TypeTag::Any) => true,
(PrimitiveType::Nat, TypeTag::Num) => true, (PrimitiveType::Nat, TypeTag::Num) => true,
_ => false, _ => false,
}, },
@ -156,6 +165,7 @@ impl Type {
fn name_used(&self, ident: &str) -> bool { fn name_used(&self, ident: &str) -> bool {
match self { match self {
Type::Generic(i) if *i == ident => true, Type::Generic(i) if *i == ident => true,
Type::Arrow(lhs, rhs) => lhs.name_used(ident) || rhs.name_used(ident),
_ => false, _ => false,
} }
} }
@ -176,6 +186,7 @@ impl Display for TypeTag {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
TypeTag::Num => write!(f, "Num"), TypeTag::Num => write!(f, "Num"),
TypeTag::Any => write!(f, "Any"),
} }
} }
} }
@ -206,6 +217,7 @@ impl Display for PrimitiveType {
match self { match self {
PrimitiveType::Nat => write!(f, "Nat"), PrimitiveType::Nat => write!(f, "Nat"),
PrimitiveType::Bool => write!(f, "Bool"), PrimitiveType::Bool => write!(f, "Bool"),
PrimitiveType::Float => write!(f, "Float"),
} }
} }
} }

Loading…
Cancel
Save