diff --git a/Cargo.lock b/Cargo.lock index 303b1b78116c921043f2240dd71725ecc777fa33..49630436a0f0b90d8252824046c29f0e18b78af2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -867,6 +867,7 @@ name = "hercules_ir" version = "0.1.0" dependencies = [ "bitvec", + "either", "nom", "ordered-float", "rand", @@ -1133,6 +1134,16 @@ dependencies = [ "with_builtin_macros", ] +[[package]] +name = "juno_patterns" +version = "0.1.0" +dependencies = [ + "async-std", + "hercules_rt", + "juno_build", + "with_builtin_macros", +] + [[package]] name = "juno_schedule_test" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index c7aa9428bc8818960d432d87a78fb3ba94930263..ced011a96c96793891228876314debeabcb561ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ members = [ "hercules_samples/ccp", "juno_samples/simple3", + "juno_samples/patterns", "juno_samples/matmul", "juno_samples/casts_and_intrinsics", "juno_samples/nested_ccp", diff --git a/hercules_cg/src/cpu.rs b/hercules_cg/src/cpu.rs index ea326f8a0310fa082c240b5f52000f9c79e0be57..344554b65280f5ff5dc2979034048cd20e86f3bb 100644 --- a/hercules_cg/src/cpu.rs +++ b/hercules_cg/src/cpu.rs @@ -589,12 +589,12 @@ impl<'a> CPUContext<'a> { ) -> Result<(), Error> { let body = &mut block.body; for dc in dynamic_constants_bottom_up(&self.dynamic_constants) { - match self.dynamic_constants[dc.idx()] { + match &self.dynamic_constants[dc.idx()] { DynamicConstant::Constant(val) => { write!(body, " %dc{} = bitcast i64 {} to i64\n", dc.idx(), val)? } DynamicConstant::Parameter(idx) => { - if idx < num_dc_params as usize { + if *idx < num_dc_params as usize { write!( body, " %dc{} = bitcast i64 %dc_p{} to i64\n", @@ -605,13 +605,31 @@ impl<'a> CPUContext<'a> { write!(body, " %dc{} = bitcast i64 0 to i64\n", dc.idx())? } } - DynamicConstant::Add(left, right) => write!( - body, - " %dc{} = add i64%dc{},%dc{}\n", - dc.idx(), - left.idx(), - right.idx() - )?, + DynamicConstant::Add(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = add i64{},%dc{}\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } DynamicConstant::Sub(left, right) => write!( body, " %dc{} = sub i64%dc{},%dc{}\n", @@ -619,13 +637,31 @@ impl<'a> CPUContext<'a> { left.idx(), right.idx() )?, - DynamicConstant::Mul(left, right) => write!( - body, - " %dc{} = mul i64%dc{},%dc{}\n", - dc.idx(), - left.idx(), - right.idx() - )?, + DynamicConstant::Mul(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = mul i64{},%dc{}\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } DynamicConstant::Div(left, right) => write!( body, " %dc{} = udiv i64%dc{},%dc{}\n", @@ -640,20 +676,56 @@ impl<'a> CPUContext<'a> { left.idx(), right.idx() )?, - DynamicConstant::Min(left, right) => write!( - body, - " %dc{} = call i64 @llvm.umin.i64(i64%dc{},i64%dc{})\n", - dc.idx(), - left.idx(), - right.idx() - )?, - DynamicConstant::Max(left, right) => write!( - body, - " %dc{} = call i64 @llvm.umax.i64(i64%dc{},i64%dc{})\n", - dc.idx(), - left.idx(), - right.idx() - )?, + DynamicConstant::Min(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = call i64 @llvm.umin.i64(i64{},i64%dc{}))\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } + DynamicConstant::Max(xs) => { + let mut xs = xs.iter().peekable(); + let mut cur_value = format!("%dc{}", xs.next().unwrap().idx()); + let mut idx = 0; + while let Some(x) = xs.next() { + let new_val = format!( + "%dc{}{}", + dc.idx(), + if xs.peek().is_some() { + format!(".{}", idx) + } else { + "".to_string() + } + ); + write!( + body, + " {} = call i64 @llvm.umax.i64(i64{},i64%dc{}))\n", + new_val, + cur_value, + x.idx() + )?; + cur_value = new_val; + idx += 1; + } + } } } Ok(()) diff --git a/hercules_cg/src/rt.rs b/hercules_cg/src/rt.rs index f97180ea24d3bf1b69810d6e79cf68fcb292e9fb..916d6520ae211aa0fe7a7f71f055f8acd0b01066 100644 --- a/hercules_cg/src/rt.rs +++ b/hercules_cg/src/rt.rs @@ -443,57 +443,89 @@ impl<'a> RTContext<'a> { id: DynamicConstantID, w: &mut W, ) -> Result<(), Error> { - match self.module.dynamic_constants[id.idx()] { + match &self.module.dynamic_constants[id.idx()] { DynamicConstant::Constant(val) => write!(w, "{}", val)?, DynamicConstant::Parameter(idx) => write!(w, "dc_p{}", idx)?, - DynamicConstant::Add(left, right) => { + DynamicConstant::Add(xs) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, "+")?; - self.codegen_dynamic_constant(right, w)?; + let mut xs = xs.iter(); + self.codegen_dynamic_constant(*xs.next().unwrap(), w)?; + for x in xs { + write!(w, "+")?; + self.codegen_dynamic_constant(*x, w)?; + } write!(w, ")")?; } DynamicConstant::Sub(left, right) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; + self.codegen_dynamic_constant(*left, w)?; write!(w, "-")?; - self.codegen_dynamic_constant(right, w)?; + self.codegen_dynamic_constant(*right, w)?; write!(w, ")")?; } - DynamicConstant::Mul(left, right) => { + DynamicConstant::Mul(xs) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, "*")?; - self.codegen_dynamic_constant(right, w)?; + let mut xs = xs.iter(); + self.codegen_dynamic_constant(*xs.next().unwrap(), w)?; + for x in xs { + write!(w, "*")?; + self.codegen_dynamic_constant(*x, w)?; + } write!(w, ")")?; } DynamicConstant::Div(left, right) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; + self.codegen_dynamic_constant(*left, w)?; write!(w, "/")?; - self.codegen_dynamic_constant(right, w)?; + self.codegen_dynamic_constant(*right, w)?; write!(w, ")")?; } DynamicConstant::Rem(left, right) => { write!(w, "(")?; - self.codegen_dynamic_constant(left, w)?; + self.codegen_dynamic_constant(*left, w)?; write!(w, "%")?; - self.codegen_dynamic_constant(right, w)?; + self.codegen_dynamic_constant(*right, w)?; write!(w, ")")?; } - DynamicConstant::Min(left, right) => { - write!(w, "::core::cmp::min(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, ",")?; - self.codegen_dynamic_constant(right, w)?; - write!(w, ")")?; + DynamicConstant::Min(xs) => { + let mut xs = xs.iter().peekable(); + + // Track the number of parentheses we open that need to be closed later + let mut opens = 0; + while let Some(x) = xs.next() { + if xs.peek().is_none() { + // For the last element, we just print it + self.codegen_dynamic_constant(*x, w)?; + } else { + // Otherwise, we create a new call to min and print the element as the + // first argument + write!(w, "::core::cmp::min(")?; + self.codegen_dynamic_constant(*x, w)?; + write!(w, ",")?; + opens += 1; + } + } + for _ in 0..opens { + write!(w, ")")?; + } } - DynamicConstant::Max(left, right) => { - write!(w, "::core::cmp::max(")?; - self.codegen_dynamic_constant(left, w)?; - write!(w, ",")?; - self.codegen_dynamic_constant(right, w)?; - write!(w, ")")?; + DynamicConstant::Max(xs) => { + let mut xs = xs.iter().peekable(); + + let mut opens = 0; + while let Some(x) = xs.next() { + if xs.peek().is_none() { + self.codegen_dynamic_constant(*x, w)?; + } else { + write!(w, "::core::cmp::max(")?; + self.codegen_dynamic_constant(*x, w)?; + write!(w, ",")?; + opens += 1; + } + } + for _ in 0..opens { + write!(w, ")")?; + } } } Ok(()) diff --git a/hercules_ir/Cargo.toml b/hercules_ir/Cargo.toml index deda9cc58758f6cc834aadcc8e4ec66625fefb4b..26950d4b7700d19326e6ea61aa2488b4c5d5df59 100644 --- a/hercules_ir/Cargo.toml +++ b/hercules_ir/Cargo.toml @@ -10,3 +10,4 @@ nom = "*" ordered-float = { version = "*", features = ["serde"] } bitvec = "*" serde = { version = "*", features = ["derive"] } +either = "*" diff --git a/hercules_ir/src/build.rs b/hercules_ir/src/build.rs index 1dd326c3ad1abf24e7bfa4aa1f28dfb8255af0e9..b804404524bb26e1e52c8f751bc416b7df84040d 100644 --- a/hercules_ir/src/build.rs +++ b/hercules_ir/src/build.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use crate::*; @@ -25,6 +26,23 @@ pub struct Builder<'a> { module: Module, } +impl<'a> DynamicConstantView for Builder<'a> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + &self.module.dynamic_constants[id.idx()] + } + + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.interned_dynamic_constants.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(self.module.dynamic_constants.len()); + self.module.dynamic_constants.push(dc.clone()); + self.interned_dynamic_constants.insert(dc, id); + id + } + } +} + /* * Since the builder doesn't provide string names for nodes, we need a different * mechanism for allowing one to allocate node IDs before actually creating the @@ -70,17 +88,6 @@ impl<'a> Builder<'a> { } } - fn intern_dynamic_constant(&mut self, dyn_cons: DynamicConstant) -> DynamicConstantID { - if let Some(id) = self.interned_dynamic_constants.get(&dyn_cons) { - *id - } else { - let id = DynamicConstantID::new(self.interned_dynamic_constants.len()); - self.interned_dynamic_constants.insert(dyn_cons.clone(), id); - self.module.dynamic_constants.push(dyn_cons); - id - } - } - pub fn add_label(&mut self, label: &String) -> LabelID { if let Some(id) = self.interned_labels.get(label) { *id @@ -406,11 +413,11 @@ impl<'a> Builder<'a> { } pub fn create_dynamic_constant_constant(&mut self, val: usize) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Constant(val)) + self.dc_const(val) } - pub fn create_dynamic_constant_parameter(&mut self, val: usize) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Parameter(val)) + pub fn create_dynamic_constant_parameter(&mut self, idx: usize) -> DynamicConstantID { + self.dc_param(idx) } pub fn create_dynamic_constant_add( @@ -418,7 +425,14 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Add(x, y)) + self.dc_add(vec![x, y]) + } + + pub fn create_dynamic_constant_add_many( + &mut self, + xs: Vec<DynamicConstantID>, + ) -> DynamicConstantID { + self.dc_add(xs) } pub fn create_dynamic_constant_sub( @@ -426,7 +440,7 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Sub(x, y)) + self.dc_sub(x, y) } pub fn create_dynamic_constant_mul( @@ -434,7 +448,14 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Mul(x, y)) + self.dc_mul(vec![x, y]) + } + + pub fn create_dynamic_constant_mul_many( + &mut self, + xs: Vec<DynamicConstantID>, + ) -> DynamicConstantID { + self.dc_mul(xs) } pub fn create_dynamic_constant_div( @@ -442,7 +463,7 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Div(x, y)) + self.dc_div(x, y) } pub fn create_dynamic_constant_rem( @@ -450,7 +471,7 @@ impl<'a> Builder<'a> { x: DynamicConstantID, y: DynamicConstantID, ) -> DynamicConstantID { - self.intern_dynamic_constant(DynamicConstant::Rem(x, y)) + self.dc_rem(x, y) } pub fn create_field_index(&self, idx: usize) -> Index { diff --git a/hercules_ir/src/dc_normalization.rs b/hercules_ir/src/dc_normalization.rs new file mode 100644 index 0000000000000000000000000000000000000000..e9f8f23aa7c05d47601c4a88087775ad82257ec0 --- /dev/null +++ b/hercules_ir/src/dc_normalization.rs @@ -0,0 +1,206 @@ +use crate::*; + +use std::cmp::{max, min}; +use std::collections::BTreeSet; +use std::ops::Deref; + +use either::Either; + +pub trait DynamicConstantView { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_; + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID; + + fn dc_const(&mut self, val: usize) -> DynamicConstantID { + self.add_dynconst(DynamicConstant::Constant(val)) + } + + fn dc_param(&mut self, index: usize) -> DynamicConstantID { + self.add_dynconst(DynamicConstant::Parameter(index)) + } + + fn dc_add(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val = 0; + let mut fields = vec![]; + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => constant_val += x, + DynamicConstant::Add(xs) => fields.extend_from_slice(xs), + _ => fields.push(dc), + } + } + + // If either there are no fields or the constant is non-zero, add it + if constant_val != 0 || fields.len() == 0 { + fields.push(self.add_dynconst(DynamicConstant::Constant(constant_val))); + } + + if fields.len() <= 1 { + // If there is only one term to add, just return it + fields[0] + } else { + fields.sort(); + self.add_dynconst(DynamicConstant::Add(fields)) + } + } + + fn dc_mul(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val = 1; + let mut fields = vec![]; + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => constant_val *= x, + DynamicConstant::Mul(xs) => fields.extend_from_slice(xs), + _ => fields.push(dc), + } + } + + if constant_val == 0 { + return self.add_dynconst(DynamicConstant::Constant(0)); + } + + if constant_val != 1 || fields.len() == 0 { + fields.push(self.add_dynconst(DynamicConstant::Constant(constant_val))); + } + + if fields.len() <= 1 { + fields[0] + } else { + fields.sort(); + self.add_dynconst(DynamicConstant::Mul(fields)) + } + } + + fn dc_min(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val: Option<usize> = None; + // For min and max we track the fields via a set during normalization as this removes + // duplicates (and we use a BTreeSet as it can produce its elements in sorted order) + let mut fields = BTreeSet::new(); + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => { + if let Some(cur_min) = constant_val { + constant_val = Some(min(cur_min, *x)); + } else { + constant_val = Some(*x); + } + } + DynamicConstant::Min(xs) => fields.extend(xs), + _ => { + fields.insert(dc); + } + } + } + + if let Some(const_val) = constant_val { + // Since dynamic constants are non-negative, ignore the constant if it is 0 + if const_val != 0 { + fields.insert(self.add_dynconst(DynamicConstant::Constant(const_val))); + } + } + + if fields.len() == 0 { + // The minimum of 0 dynamic constants is 0 since dynamic constants are non-negative + self.add_dynconst(DynamicConstant::Constant(0)) + } else if fields.len() <= 1 { + *fields.first().unwrap() + } else { + self.add_dynconst(DynamicConstant::Min(fields.into_iter().collect())) + } + } + + fn dc_max(&mut self, dcs: Vec<DynamicConstantID>) -> DynamicConstantID { + let mut constant_val: Option<usize> = None; + let mut fields = BTreeSet::new(); + + for dc in dcs { + match self.get_dynconst(dc).deref() { + DynamicConstant::Constant(x) => { + if let Some(cur_max) = constant_val { + constant_val = Some(max(cur_max, *x)); + } else { + constant_val = Some(*x); + } + } + DynamicConstant::Max(xs) => fields.extend(xs), + _ => { + fields.insert(dc); + } + } + } + + if let Some(const_val) = constant_val { + fields.insert(self.add_dynconst(DynamicConstant::Constant(const_val))); + } + + assert!( + fields.len() > 0, + "Max of 0 dynamic constant expressions is undefined" + ); + + if fields.len() <= 1 { + *fields.first().unwrap() + } else { + self.add_dynconst(DynamicConstant::Max(fields.into_iter().collect())) + } + } + + fn dc_sub(&mut self, x: DynamicConstantID, y: DynamicConstantID) -> DynamicConstantID { + let dc = match (self.get_dynconst(x).deref(), self.get_dynconst(y).deref()) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => { + Either::Left(DynamicConstant::Constant(x - y)) + } + (_, DynamicConstant::Constant(0)) => Either::Right(x), + _ => Either::Left(DynamicConstant::Sub(x, y)), + }; + + match dc { + Either::Left(dc) => self.add_dynconst(dc), + Either::Right(id) => id, + } + } + + fn dc_div(&mut self, x: DynamicConstantID, y: DynamicConstantID) -> DynamicConstantID { + let dc = match (self.get_dynconst(x).deref(), self.get_dynconst(y).deref()) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => { + Either::Left(DynamicConstant::Constant(x / y)) + } + (_, DynamicConstant::Constant(1)) => Either::Right(x), + _ => Either::Left(DynamicConstant::Div(x, y)), + }; + + match dc { + Either::Left(dc) => self.add_dynconst(dc), + Either::Right(id) => id, + } + } + + fn dc_rem(&mut self, x: DynamicConstantID, y: DynamicConstantID) -> DynamicConstantID { + let dc = match (self.get_dynconst(x).deref(), self.get_dynconst(y).deref()) { + (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => { + Either::Left(DynamicConstant::Constant(x % y)) + } + _ => Either::Left(DynamicConstant::Rem(x, y)), + }; + + match dc { + Either::Left(dc) => self.add_dynconst(dc), + Either::Right(id) => id, + } + } + + fn dc_normalize(&mut self, dc: DynamicConstant) -> DynamicConstantID { + match dc { + DynamicConstant::Add(xs) => self.dc_add(xs), + DynamicConstant::Mul(xs) => self.dc_mul(xs), + DynamicConstant::Min(xs) => self.dc_min(xs), + DynamicConstant::Max(xs) => self.dc_max(xs), + DynamicConstant::Sub(x, y) => self.dc_sub(x, y), + DynamicConstant::Div(x, y) => self.dc_div(x, y), + DynamicConstant::Rem(x, y) => self.dc_rem(x, y), + _ => self.add_dynconst(dc), + } + } +} diff --git a/hercules_ir/src/ir.rs b/hercules_ir/src/ir.rs index 04e699ea1b9f6106c9d538de583d1c9793110fa0..187b3f986b6ae7a62ec69e388b1627761880959f 100644 --- a/hercules_ir/src/ir.rs +++ b/hercules_ir/src/ir.rs @@ -1,4 +1,3 @@ -use std::cmp::{max, min}; use std::collections::HashSet; use std::fmt::Write; use std::ops::Coroutine; @@ -120,13 +119,14 @@ pub enum DynamicConstant { // function is this). Parameter(usize), // Supported integer operations on dynamic constants. - Add(DynamicConstantID, DynamicConstantID), + Add(Vec<DynamicConstantID>), + Mul(Vec<DynamicConstantID>), + Min(Vec<DynamicConstantID>), + Max(Vec<DynamicConstantID>), + Sub(DynamicConstantID, DynamicConstantID), - Mul(DynamicConstantID, DynamicConstantID), Div(DynamicConstantID, DynamicConstantID), Rem(DynamicConstantID, DynamicConstantID), - Min(DynamicConstantID, DynamicConstantID), - Max(DynamicConstantID, DynamicConstantID), } /* @@ -464,21 +464,31 @@ impl Module { match &self.dynamic_constants[dc_id.idx()] { DynamicConstant::Constant(cons) => write!(w, "{}", cons), DynamicConstant::Parameter(param) => write!(w, "#{}", param), - DynamicConstant::Add(x, y) - | DynamicConstant::Sub(x, y) - | DynamicConstant::Mul(x, y) + DynamicConstant::Add(xs) + | DynamicConstant::Mul(xs) + | DynamicConstant::Min(xs) + | DynamicConstant::Max(xs) => { + match &self.dynamic_constants[dc_id.idx()] { + DynamicConstant::Add(_) => write!(w, "+")?, + DynamicConstant::Mul(_) => write!(w, "*")?, + DynamicConstant::Min(_) => write!(w, "min")?, + DynamicConstant::Max(_) => write!(w, "max")?, + _ => (), + } + write!(w, "(")?; + for arg in xs { + self.write_dynamic_constant(*arg, w)?; + write!(w, ",")?; + } + write!(w, ")") + } + DynamicConstant::Sub(x, y) | DynamicConstant::Div(x, y) - | DynamicConstant::Rem(x, y) - | DynamicConstant::Min(x, y) - | DynamicConstant::Max(x, y) => { + | DynamicConstant::Rem(x, y) => { match &self.dynamic_constants[dc_id.idx()] { - DynamicConstant::Add(_, _) => write!(w, "+")?, DynamicConstant::Sub(_, _) => write!(w, "-")?, - DynamicConstant::Mul(_, _) => write!(w, "*")?, DynamicConstant::Div(_, _) => write!(w, "/")?, DynamicConstant::Rem(_, _) => write!(w, "%")?, - DynamicConstant::Min(_, _) => write!(w, "min")?, - DynamicConstant::Max(_, _) => write!(w, "max")?, _ => (), } write!(w, "(")?; @@ -639,15 +649,37 @@ pub fn dynamic_constants_bottom_up( if visited[id.idx()] { continue; } - match dynamic_constants[id.idx()] { - DynamicConstant::Add(left, right) - | DynamicConstant::Sub(left, right) - | DynamicConstant::Mul(left, right) - | DynamicConstant::Div(left, right) - | DynamicConstant::Rem(left, right) => { + match &dynamic_constants[id.idx()] { + DynamicConstant::Add(args) + | DynamicConstant::Mul(args) + | DynamicConstant::Min(args) + | DynamicConstant::Max(args) => { // We have to yield the children of this node before // this node itself. We keep track of which nodes have // yielded using visited. + if args + .iter() + .any(|i| i.idx() >= visited.len() || invalid[i.idx()]) + { + // This is an invalid dynamic constant and should be skipped + invalid.set(id.idx(), true); + continue; + } + + if args.iter().all(|i| visited[i.idx()]) { + // Since all children have been yielded, we yield ourself + visited.set(id.idx(), true); + yield id; + } else { + // Otherwise push self onto stack so that the children will get popped + // first + stack.push(id); + stack.extend(args.clone()); + } + } + DynamicConstant::Sub(left, right) + | DynamicConstant::Div(left, right) + | DynamicConstant::Rem(left, right) => { if left.idx() >= visited.len() || right.idx() >= visited.len() || invalid[left.idx()] @@ -664,8 +696,8 @@ pub fn dynamic_constants_bottom_up( // Push ourselves, then children, so that children // get popped first. stack.push(id); - stack.push(left); - stack.push(right); + stack.push(*left); + stack.push(*right); } } _ => { @@ -999,6 +1031,34 @@ impl Constant { } impl DynamicConstant { + pub fn add(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Add(vec![x, y]) + } + + pub fn sub(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Sub(x, y) + } + + pub fn mul(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Mul(vec![x, y]) + } + + pub fn div(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Div(x, y) + } + + pub fn rem(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Rem(x, y) + } + + pub fn min(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Min(vec![x, y]) + } + + pub fn max(x: DynamicConstantID, y: DynamicConstantID) -> Self { + Self::Max(vec![x, y]) + } + pub fn is_parameter(&self) -> bool { if let DynamicConstant::Parameter(_) = self { true @@ -1036,33 +1096,12 @@ pub fn evaluate_dynamic_constant( cons: DynamicConstantID, dcs: &Vec<DynamicConstant>, ) -> Option<usize> { - match dcs[cons.idx()] { - DynamicConstant::Constant(cons) => Some(cons), - DynamicConstant::Parameter(_) => None, - DynamicConstant::Add(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? + evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Sub(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? - evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Mul(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? * evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Div(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? / evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Rem(left, right) => { - Some(evaluate_dynamic_constant(left, dcs)? % evaluate_dynamic_constant(right, dcs)?) - } - DynamicConstant::Min(left, right) => Some(min( - evaluate_dynamic_constant(left, dcs)?, - evaluate_dynamic_constant(right, dcs)?, - )), - DynamicConstant::Max(left, right) => Some(max( - evaluate_dynamic_constant(left, dcs)?, - evaluate_dynamic_constant(right, dcs)?, - )), - } + // Because of normalization, if a dynamic constant can be expressed as a constant it must be a + // constant + let DynamicConstant::Constant(cons) = dcs[cons.idx()] else { + return None; + }; + Some(cons) } /* diff --git a/hercules_ir/src/lib.rs b/hercules_ir/src/lib.rs index 185a28886595330dc6e05e6f9286397d1f5a30d5..fc59a74c4b453d24deddc7141456bc0b21bb6e5d 100644 --- a/hercules_ir/src/lib.rs +++ b/hercules_ir/src/lib.rs @@ -11,6 +11,7 @@ pub mod build; pub mod callgraph; pub mod collections; pub mod dataflow; +pub mod dc_normalization; pub mod def_use; pub mod device; pub mod dom; @@ -28,6 +29,7 @@ pub use crate::build::*; pub use crate::callgraph::*; pub use crate::collections::*; pub use crate::dataflow::*; +pub use crate::dc_normalization::*; pub use crate::def_use::*; pub use crate::device::*; pub use crate::dom::*; diff --git a/hercules_ir/src/parse.rs b/hercules_ir/src/parse.rs index cdad54f935afd7b79eaf258b8ec0a83415ecb5ef..257dd4d998341feb2ad6e326a1e4b9e58b31c1e3 100644 --- a/hercules_ir/src/parse.rs +++ b/hercules_ir/src/parse.rs @@ -1,5 +1,6 @@ use std::cell::RefCell; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::str::FromStr; use crate::*; @@ -87,16 +88,7 @@ impl<'a> Context<'a> { } fn get_dynamic_constant_id(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID { - if let Some(id) = self.interned_dynamic_constants.get(&dynamic_constant) { - *id - } else { - let id = DynamicConstantID::new(self.interned_dynamic_constants.len()); - self.interned_dynamic_constants - .insert(dynamic_constant.clone(), id); - self.reverse_dynamic_constant_map - .insert(id, dynamic_constant); - id - } + self.dc_normalize(dynamic_constant) } fn get_label_id(&mut self, label: String) -> LabelID { @@ -110,6 +102,23 @@ impl<'a> Context<'a> { } } +impl<'a> DynamicConstantView for Context<'a> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + &self.reverse_dynamic_constant_map[&id] + } + + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.interned_dynamic_constants.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(self.reverse_dynamic_constant_map.len()); + self.interned_dynamic_constants.insert(dc.clone(), id); + self.reverse_dynamic_constant_map.insert(id, dc); + id + } + } +} + /* * A module is just a file with a list of functions. */ @@ -946,9 +955,9 @@ fn parse_dynamic_constant<'a>( ), )), |(op, (x, y))| match op { - '+' => DynamicConstant::Add(x, y), + '+' => DynamicConstant::Add(vec![x, y]), '-' => DynamicConstant::Sub(x, y), - '*' => DynamicConstant::Mul(x, y), + '*' => DynamicConstant::Mul(vec![x, y]), '/' => DynamicConstant::Div(x, y), '%' => DynamicConstant::Rem(x, y), _ => panic!("Invalid parse"), diff --git a/hercules_ir/src/typecheck.rs b/hercules_ir/src/typecheck.rs index a80dd422128bd3ba2ab6436272943ff1b2deb82f..f7ea397e49355029c7b5cbc0fa534494bc747c6d 100644 --- a/hercules_ir/src/typecheck.rs +++ b/hercules_ir/src/typecheck.rs @@ -1,6 +1,6 @@ -use std::cmp::{max, min}; use std::collections::HashMap; use std::iter::zip; +use std::ops::Deref; use crate::*; @@ -184,18 +184,20 @@ fn typeflow( dynamic_constants: &Vec<DynamicConstant>, num_parameters: u32, ) -> bool { - match dynamic_constants[root.idx()] { + match &dynamic_constants[root.idx()] { DynamicConstant::Constant(_) => true, - DynamicConstant::Parameter(idx) => idx < num_parameters as usize, - DynamicConstant::Add(x, y) - | DynamicConstant::Sub(x, y) - | DynamicConstant::Mul(x, y) + DynamicConstant::Parameter(idx) => *idx < num_parameters as usize, + DynamicConstant::Add(xs) + | DynamicConstant::Mul(xs) + | DynamicConstant::Min(xs) + | DynamicConstant::Max(xs) => xs + .iter() + .all(|dc| check_dynamic_constants(*dc, dynamic_constants, num_parameters)), + DynamicConstant::Sub(x, y) | DynamicConstant::Div(x, y) - | DynamicConstant::Rem(x, y) - | DynamicConstant::Min(x, y) - | DynamicConstant::Max(x, y) => { - check_dynamic_constants(x, dynamic_constants, num_parameters) - && check_dynamic_constants(y, dynamic_constants, num_parameters) + | DynamicConstant::Rem(x, y) => { + check_dynamic_constants(*x, dynamic_constants, num_parameters) + && check_dynamic_constants(*y, dynamic_constants, num_parameters) } } } @@ -733,10 +735,20 @@ fn typeflow( } } + // Construct the substitution object + let mut subst = DCSubst::new( + types, + reverse_type_map, + dynamic_constants, + reverse_dynamic_constant_map, + dc_args, + ); + // Check argument types. for (input, param_ty) in zip(inputs.iter().skip(1), callee.param_types.iter()) { + let param_ty = subst.type_subst(*param_ty); if let Concrete(input_id) = input { - if !types_match(types, dynamic_constants, dc_args, *param_ty, *input_id) { + if param_ty != *input_id { return Error(String::from( "Call node mismatches argument types with callee function.", )); @@ -747,14 +759,7 @@ fn typeflow( } } - Concrete(type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - callee.return_type, - )) + Concrete(subst.type_subst(callee.return_type)) } Node::IntrinsicCall { intrinsic, args: _ } => { let num_params = match intrinsic { @@ -1071,307 +1076,154 @@ pub fn cast_compatible(src_ty: &Type, dst_ty: &Type) -> bool { } /* - * Determine if the given type matches the parameter type when the provided - * dynamic constants are substituted in for the dynamic constants used in the - * parameter type. + * Data structures and methods for substituting given dynamic constant arguments into the provided + * types and dynamic constants */ -fn types_match( - types: &Vec<Type>, - dynamic_constants: &Vec<DynamicConstant>, - dc_args: &Box<[DynamicConstantID]>, - param: TypeID, - input: TypeID, -) -> bool { - // Note that we can't just check whether the type ids are equal since them - // being equal does not mean they match when we properly substitute in the - // dynamic constant arguments - - match (&types[param.idx()], &types[input.idx()]) { - (Type::Control, Type::Control) - | (Type::Boolean, Type::Boolean) - | (Type::Integer8, Type::Integer8) - | (Type::Integer16, Type::Integer16) - | (Type::Integer32, Type::Integer32) - | (Type::Integer64, Type::Integer64) - | (Type::UnsignedInteger8, Type::UnsignedInteger8) - | (Type::UnsignedInteger16, Type::UnsignedInteger16) - | (Type::UnsignedInteger32, Type::UnsignedInteger32) - | (Type::UnsignedInteger64, Type::UnsignedInteger64) - | (Type::Float32, Type::Float32) - | (Type::Float64, Type::Float64) => true, - (Type::Product(ps), Type::Product(is)) | (Type::Summation(ps), Type::Summation(is)) => { - ps.len() == is.len() - && ps - .iter() - .zip(is.iter()) - .all(|(p, i)| types_match(types, dynamic_constants, dc_args, *p, *i)) - } - (Type::Array(p, pds), Type::Array(i, ids)) => { - types_match(types, dynamic_constants, dc_args, *p, *i) - && pds.len() == ids.len() - && pds - .iter() - .zip(ids.iter()) - .all(|(pd, id)| dyn_consts_match(dynamic_constants, dc_args, *pd, *id)) - } - (_, _) => false, - } +struct DCSubst<'a> { + types: &'a mut Vec<Type>, + reverse_type_map: &'a mut HashMap<Type, TypeID>, + dynamic_constants: &'a mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &'a mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &'a [DynamicConstantID], } -/* - * Determine if the given dynamic constant matches the parameter's dynamic - * constants when the provided dynamic constants are substituted in for the - * dynamic constants used in the parameter's dynamic constant. Implement dynamic - * constant normalization here as well - i.e., 1 * 2 * 3 = 6. - */ -fn dyn_consts_match( - dynamic_constants: &Vec<DynamicConstant>, - dc_args: &Box<[DynamicConstantID]>, - left: DynamicConstantID, - right: DynamicConstantID, -) -> bool { - // First, try evaluating the DCs and seeing if they're the same value. - if let (Some(cons1), Some(cons2)) = ( - evaluate_dynamic_constant(left, dynamic_constants), - evaluate_dynamic_constant(right, dynamic_constants), - ) { - return cons1 == cons2; +impl<'a> DynamicConstantView for DCSubst<'a> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + &self.dynamic_constants[id.idx()] } - match ( - &dynamic_constants[left.idx()], - &dynamic_constants[right.idx()], - ) { - (DynamicConstant::Constant(x), DynamicConstant::Constant(y)) => x == y, - (DynamicConstant::Parameter(l), DynamicConstant::Parameter(r)) => l == r, - (DynamicConstant::Parameter(i), _) => dyn_consts_match( - dynamic_constants, - dc_args, - min(right, dc_args[*i]), - max(right, dc_args[*i]), - ), - (_, DynamicConstant::Parameter(i)) => dyn_consts_match( + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + if let Some(id) = self.reverse_dynamic_constant_map.get(&dc) { + *id + } else { + let id = DynamicConstantID::new(self.dynamic_constants.len()); + self.reverse_dynamic_constant_map.insert(dc.clone(), id); + self.dynamic_constants.push(dc); + id + } + } +} + +impl<'a> DCSubst<'a> { + fn new( + types: &'a mut Vec<Type>, + reverse_type_map: &'a mut HashMap<Type, TypeID>, + dynamic_constants: &'a mut Vec<DynamicConstant>, + reverse_dynamic_constant_map: &'a mut HashMap<DynamicConstant, DynamicConstantID>, + dc_args: &'a [DynamicConstantID], + ) -> Self { + Self { + types, + reverse_type_map, dynamic_constants, + reverse_dynamic_constant_map, dc_args, - min(left, dc_args[*i]), - max(left, dc_args[*i]), - ), - (DynamicConstant::Add(ll, lr), DynamicConstant::Add(rl, rr)) - | (DynamicConstant::Mul(ll, lr), DynamicConstant::Mul(rl, rr)) - | (DynamicConstant::Min(ll, lr), DynamicConstant::Min(rl, rr)) - | (DynamicConstant::Max(ll, lr), DynamicConstant::Max(rl, rr)) => { - // Normalize for associative ops by always looking at smaller DC ID - // as left arm and larger DC ID as right arm. - dyn_consts_match(dynamic_constants, dc_args, min(*ll, *lr), min(*rl, *rr)) - && dyn_consts_match(dynamic_constants, dc_args, max(*ll, *lr), max(*rl, *rr)) - } - (DynamicConstant::Sub(ll, lr), DynamicConstant::Sub(rl, rr)) - | (DynamicConstant::Div(ll, lr), DynamicConstant::Div(rl, rr)) - | (DynamicConstant::Rem(ll, lr), DynamicConstant::Rem(rl, rr)) => { - dyn_consts_match(dynamic_constants, dc_args, *ll, *rl) - && dyn_consts_match(dynamic_constants, dc_args, *lr, *rr) } - (_, _) => false, } -} -/* - * Substitutes the given dynamic constant arguments into the provided type and - * returns the appropriate typeID (potentially creating new types and dynamic - * constants in the process) - */ -fn type_subst( - types: &mut Vec<Type>, - dynamic_constants: &mut Vec<DynamicConstant>, - reverse_type_map: &mut HashMap<Type, TypeID>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - dc_args: &Box<[DynamicConstantID]>, - typ: TypeID, -) -> TypeID { - fn intern_type( - ty: Type, - types: &mut Vec<Type>, - reverse_type_map: &mut HashMap<Type, TypeID>, - ) -> TypeID { - if let Some(id) = reverse_type_map.get(&ty) { + fn intern_type(&mut self, ty: Type) -> TypeID { + if let Some(id) = self.reverse_type_map.get(&ty) { *id } else { - let id = TypeID::new(types.len()); - reverse_type_map.insert(ty.clone(), id); - types.push(ty); + let id = TypeID::new(self.types.len()); + self.reverse_type_map.insert(ty.clone(), id); + self.types.push(ty); id } } - match &types[typ.idx()] { - Type::Control - | Type::Boolean - | Type::Integer8 - | Type::Integer16 - | Type::Integer32 - | Type::Integer64 - | Type::UnsignedInteger8 - | Type::UnsignedInteger16 - | Type::UnsignedInteger32 - | Type::UnsignedInteger64 - | Type::Float32 - | Type::Float64 => typ, - Type::Product(ts) => { - let mut new_ts = vec![]; - for t in ts.clone().iter() { - new_ts.push(type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - *t, - )); + fn type_subst(&mut self, typ: TypeID) -> TypeID { + match &self.types[typ.idx()] { + Type::Control + | Type::Boolean + | Type::Integer8 + | Type::Integer16 + | Type::Integer32 + | Type::Integer64 + | Type::UnsignedInteger8 + | Type::UnsignedInteger16 + | Type::UnsignedInteger32 + | Type::UnsignedInteger64 + | Type::Float32 + | Type::Float64 => typ, + Type::Product(ts) => { + let new_ts = ts.clone().iter().map(|t| self.type_subst(*t)).collect(); + self.intern_type(Type::Product(new_ts)) } - intern_type(Type::Product(new_ts.into()), types, reverse_type_map) - } - Type::Summation(ts) => { - let mut new_ts = vec![]; - for t in ts.clone().iter() { - new_ts.push(type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - *t, - )); + Type::Summation(ts) => { + let new_ts = ts.clone().iter().map(|t| self.type_subst(*t)).collect(); + self.intern_type(Type::Summation(new_ts)) } - intern_type(Type::Summation(new_ts.into()), types, reverse_type_map) - } - Type::Array(elem, dims) => { - let ds = dims.clone(); - let new_elem = type_subst( - types, - dynamic_constants, - reverse_type_map, - reverse_dynamic_constant_map, - dc_args, - *elem, - ); - let mut new_dims = vec![]; - for d in ds.iter() { - new_dims.push(dyn_const_subst( - dynamic_constants, - reverse_dynamic_constant_map, - dc_args, - *d, - )); + Type::Array(elem, dims) => { + let elem = *elem; + let new_dims = dims + .clone() + .iter() + .map(|d| self.dyn_const_subst(*d)) + .collect(); + let new_elem = self.type_subst(elem); + self.intern_type(Type::Array(new_elem, new_dims)) } - intern_type( - Type::Array(new_elem, new_dims.into()), - types, - reverse_type_map, - ) - } - } -} - -fn dyn_const_subst( - dynamic_constants: &mut Vec<DynamicConstant>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - dc_args: &Box<[DynamicConstantID]>, - dyn_const: DynamicConstantID, -) -> DynamicConstantID { - fn intern_dyn_const( - dc: DynamicConstant, - dynamic_constants: &mut Vec<DynamicConstant>, - reverse_dynamic_constant_map: &mut HashMap<DynamicConstant, DynamicConstantID>, - ) -> DynamicConstantID { - if let Some(id) = reverse_dynamic_constant_map.get(&dc) { - *id - } else { - let id = DynamicConstantID::new(dynamic_constants.len()); - reverse_dynamic_constant_map.insert(dc.clone(), id); - dynamic_constants.push(dc); - id } } - match &dynamic_constants[dyn_const.idx()] { - DynamicConstant::Constant(_) => dyn_const, - DynamicConstant::Parameter(i) => dc_args[*i], - DynamicConstant::Add(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Add(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Sub(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Sub(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Mul(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Mul(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Div(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Div(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Rem(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Rem(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Min(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Min(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) - } - DynamicConstant::Max(l, r) => { - let x = *l; - let y = *r; - let sx = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, x); - let sy = dyn_const_subst(dynamic_constants, reverse_dynamic_constant_map, dc_args, y); - intern_dyn_const( - DynamicConstant::Max(sx, sy), - dynamic_constants, - reverse_dynamic_constant_map, - ) + fn dyn_const_subst(&mut self, dyn_const: DynamicConstantID) -> DynamicConstantID { + match &self.dynamic_constants[dyn_const.idx()] { + DynamicConstant::Constant(_) => dyn_const, + DynamicConstant::Parameter(i) => self.dc_args[*i], + DynamicConstant::Add(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_add(sxs) + } + DynamicConstant::Sub(l, r) => { + let x = *l; + let y = *r; + let sx = self.dyn_const_subst(x); + let sy = self.dyn_const_subst(y); + self.dc_sub(sx, sy) + } + DynamicConstant::Mul(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_mul(sxs) + } + DynamicConstant::Div(l, r) => { + let x = *l; + let y = *r; + let sx = self.dyn_const_subst(x); + let sy = self.dyn_const_subst(y); + self.dc_div(sx, sy) + } + DynamicConstant::Rem(l, r) => { + let x = *l; + let y = *r; + let sx = self.dyn_const_subst(x); + let sy = self.dyn_const_subst(y); + self.dc_rem(sx, sy) + } + DynamicConstant::Min(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_min(sxs) + } + DynamicConstant::Max(xs) => { + let sxs = xs + .clone() + .into_iter() + .map(|dc| self.dyn_const_subst(dc)) + .collect(); + self.dc_max(sxs) + } } } } diff --git a/hercules_opt/src/editor.rs b/hercules_opt/src/editor.rs index 39f1184cc947a35418641a817a86321343f101fc..1d5860574f4384887a334b78d060d2d4fd53e010 100644 --- a/hercules_opt/src/editor.rs +++ b/hercules_opt/src/editor.rs @@ -9,6 +9,7 @@ use either::Either; use hercules_ir::def_use::*; use hercules_ir::ir::*; +use hercules_ir::DynamicConstantView; /* * Helper object for editing Hercules functions in a trackable manner. Edits @@ -743,22 +744,7 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } pub fn add_dynamic_constant(&mut self, dynamic_constant: DynamicConstant) -> DynamicConstantID { - let pos = self - .editor - .dynamic_constants - .borrow() - .iter() - .chain(self.added_dynamic_constants.iter()) - .position(|c| *c == dynamic_constant); - if let Some(idx) = pos { - DynamicConstantID::new(idx) - } else { - let id = DynamicConstantID::new( - self.editor.dynamic_constants.borrow().len() + self.added_dynamic_constants.len(), - ); - self.added_dynamic_constants.push(dynamic_constant); - id - } + self.dc_normalize(dynamic_constant) } pub fn get_dynamic_constant( @@ -788,6 +774,31 @@ impl<'a, 'b> FunctionEdit<'a, 'b> { } } +impl<'a, 'b> DynamicConstantView for FunctionEdit<'a, 'b> { + fn get_dynconst(&self, id: DynamicConstantID) -> impl Deref<Target = DynamicConstant> + '_ { + self.get_dynamic_constant(id) + } + + fn add_dynconst(&mut self, dc: DynamicConstant) -> DynamicConstantID { + let pos = self + .editor + .dynamic_constants + .borrow() + .iter() + .chain(self.added_dynamic_constants.iter()) + .position(|c| *c == dc); + if let Some(idx) = pos { + DynamicConstantID::new(idx) + } else { + let id = DynamicConstantID::new( + self.editor.dynamic_constants.borrow().len() + self.added_dynamic_constants.len(), + ); + self.added_dynamic_constants.push(dc); + id + } + } +} + #[cfg(test)] mod editor_tests { #[allow(unused_imports)] diff --git a/hercules_opt/src/fork_guard_elim.rs b/hercules_opt/src/fork_guard_elim.rs index 1abb89672ae1d5c4f0f34578ca9d8eb2d69a2bc0..052fd0e493327fceb3bd1b1918d4d4aafc93bf79 100644 --- a/hercules_opt/src/fork_guard_elim.rs +++ b/hercules_opt/src/fork_guard_elim.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use hercules_ir::*; @@ -70,22 +71,24 @@ fn guarded_fork( }; let mut factors = factors.iter().enumerate().map(|(idx, dc)| { - let DynamicConstant::Max(l, r) = *editor.get_dynamic_constant(*dc) else { + let factor = editor.get_dynamic_constant(*dc); + let DynamicConstant::Max(xs) = factor.deref() else { return Factor::Normal(*dc); }; - // There really needs to be a better way to work w/ associativity. - let binding = [(l, r), (r, l)]; - let id = binding.iter().find_map(|(a, b)| { - let DynamicConstant::Constant(1) = *editor.get_dynamic_constant(*a) else { - return None; - }; - Some(b) - }); - - match id { - Some(v) => Factor::Max(idx, *v), - None => Factor::Normal(*dc), + // Filter out any terms which are just 1s + let non_ones = xs.iter().filter(|i| { + if let DynamicConstant::Constant(1) = editor.get_dynamic_constant(**i).deref() { + false + } else { + true + } + }).collect::<Vec<_>>(); + // If we're left with just one term x, we had max { 1, x } + if non_ones.len() == 1 { + Factor::Max(idx, *non_ones[0]) + } else { + Factor::Normal(*dc) } }); diff --git a/hercules_opt/src/forkify.rs b/hercules_opt/src/forkify.rs index ce9ac1412f1253bff6589ec668db63725183ca6c..ec4e9fbcc22d9f1c8a53652173706b40c5b12e65 100644 --- a/hercules_opt/src/forkify.rs +++ b/hercules_opt/src/forkify.rs @@ -265,9 +265,8 @@ pub fn forkify_loop( let bound_dc_id = { let mut max_id = DynamicConstantID::new(0); editor.edit(|mut edit| { - // FIXME: Maybe add_dynamic_constant should intern? let one_id = edit.add_dynamic_constant(DynamicConstant::Constant(1)); - max_id = edit.add_dynamic_constant(DynamicConstant::Max(one_id, bound_dc_id)); + max_id = edit.add_dynamic_constant(DynamicConstant::max(one_id, bound_dc_id)); Ok(edit) }); max_id diff --git a/hercules_opt/src/gcm.rs b/hercules_opt/src/gcm.rs index 6d36e8ac2cd6c8fdff46f4e46b25a95e5b15db51..271bfaf1da55f6c8ab342d06853532eb5ce99fff 100644 --- a/hercules_opt/src/gcm.rs +++ b/hercules_opt/src/gcm.rs @@ -1064,9 +1064,9 @@ fn align(edit: &mut FunctionEdit, mut acc: DynamicConstantID, align: usize) -> D if align != 1 { let align_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align)); let align_m1_dc = edit.add_dynamic_constant(DynamicConstant::Constant(align - 1)); - acc = edit.add_dynamic_constant(DynamicConstant::Add(acc, align_m1_dc)); - acc = edit.add_dynamic_constant(DynamicConstant::Div(acc, align_dc)); - acc = edit.add_dynamic_constant(DynamicConstant::Mul(acc, align_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::add(acc, align_m1_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::div(acc, align_dc)); + acc = edit.add_dynamic_constant(DynamicConstant::mul(acc, align_dc)); } acc } @@ -1098,7 +1098,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> // the field. let field_size = type_size(edit, field, alignments); acc_size = align(edit, acc_size, alignments[field.idx()]); - acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, field_size)); + acc_size = edit.add_dynamic_constant(DynamicConstant::add(acc_size, field_size)); } // Finally, round up to the alignment of the whole product, since // the size needs to be a multiple of the alignment. @@ -1112,11 +1112,11 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> // Pick the size of the largest variant, since that's the most // memory we would need. let variant_size = type_size(edit, variant, alignments); - acc_size = edit.add_dynamic_constant(DynamicConstant::Max(acc_size, variant_size)); + acc_size = edit.add_dynamic_constant(DynamicConstant::max(acc_size, variant_size)); } // Add one byte for the discriminant and align the whole summation. let one = edit.add_dynamic_constant(DynamicConstant::Constant(1)); - acc_size = edit.add_dynamic_constant(DynamicConstant::Add(acc_size, one)); + acc_size = edit.add_dynamic_constant(DynamicConstant::add(acc_size, one)); acc_size = align(edit, acc_size, alignments[ty_id.idx()]); acc_size } @@ -1124,7 +1124,7 @@ fn type_size(edit: &mut FunctionEdit, ty_id: TypeID, alignments: &Vec<usize>) -> // The layout of an array is row-major linear in memory. let mut acc_size = type_size(edit, elem, alignments); for bound in bounds { - acc_size = edit.add_dynamic_constant(DynamicConstant::Mul(acc_size, bound)); + acc_size = edit.add_dynamic_constant(DynamicConstant::mul(acc_size, bound)); } acc_size } @@ -1160,7 +1160,7 @@ fn object_allocation( *total = align(&mut edit, *total, alignments[typing[id.idx()].idx()]); offsets.insert(id, *total); let type_size = type_size(&mut edit, typing[id.idx()], alignments); - *total = edit.add_dynamic_constant(DynamicConstant::Add(*total, type_size)); + *total = edit.add_dynamic_constant(DynamicConstant::add(*total, type_size)); } } Node::Call { @@ -1169,7 +1169,13 @@ fn object_allocation( ref dynamic_constants, args: _, } => { - let dynamic_constants = dynamic_constants.clone(); + let dynamic_constants = dynamic_constants.to_vec(); + let dc_args = (0..dynamic_constants.len()) + .map(|i| edit.add_dynamic_constant(DynamicConstant::Parameter(i))); + let substs = dc_args + .zip(dynamic_constants.into_iter()) + .collect::<HashMap<_, _>>(); + for device in BACKED_DEVICES { if let Some(mut callee_backing_size) = backing_allocations[&callee] .get(&device) @@ -1183,26 +1189,12 @@ fn object_allocation( offsets.insert(id, *total); // Substitute the dynamic constant parameters in the // callee's backing size. - let first_dc = edit.num_dynamic_constants() + 10000; - for (p_idx, dc_n) in zip(0..dynamic_constants.len(), first_dc..) { - let dc_a = - edit.add_dynamic_constant(DynamicConstant::Parameter(p_idx)); - callee_backing_size = substitute_dynamic_constants( - dc_a, - DynamicConstantID::new(dc_n), - callee_backing_size, - &mut edit, - ); - } - for (dc_n, dc_b) in zip(first_dc.., dynamic_constants.iter()) { - callee_backing_size = substitute_dynamic_constants( - DynamicConstantID::new(dc_n), - *dc_b, - callee_backing_size, - &mut edit, - ); - } - *total = edit.add_dynamic_constant(DynamicConstant::Add( + callee_backing_size = substitute_dynamic_constants( + &substs, + callee_backing_size, + &mut edit, + ); + *total = edit.add_dynamic_constant(DynamicConstant::add( *total, callee_backing_size, )); diff --git a/hercules_opt/src/inline.rs b/hercules_opt/src/inline.rs index 1d2bac97d848ace910d614a42743c6ea5fe3aa9e..848d957f1e25c37b448aea7b35b0f5c5100c6d69 100644 --- a/hercules_opt/src/inline.rs +++ b/hercules_opt/src/inline.rs @@ -119,7 +119,12 @@ fn inline_func( // Assemble all the info we'll need to do the edit. let dcs_a = &dc_param_idx_to_dc_id[..dynamic_constants.len()]; - let dcs_b = dynamic_constants.clone(); + let dcs_b = dynamic_constants.to_vec(); + let substs = dcs_a + .iter() + .map(|i| *i) + .zip(dcs_b.into_iter()) + .collect::<HashMap<_, _>>(); let args = args.clone(); let old_num_nodes = editor.func().nodes.len(); let old_id_to_new_id = |old_id: NodeID| NodeID::new(old_id.idx() + old_num_nodes); @@ -163,39 +168,7 @@ fn inline_func( || node.is_dynamic_constant() || node.is_call() { - // We have to perform the subsitution in two steps. First, - // we map every dynamic constant A to a non-sense dynamic - // constant ID. Second, we map each non-sense dynamic - // constant ID to the appropriate dynamic constant B. Why - // not just do this in one step from A to B? We update - // dynamic constants one at a time, so imagine the following - // A -> B mappings: - // ID 0 -> ID 1 - // ID 1 -> ID 0 - // First, we apply the first mapping. This changes all - // references to dynamic constant 0 to dynamic constant 1. - // Then, we apply the second mapping. This updates all - // already present references to dynamic constant 1, as well - // as the new references we just made in the first step. We - // actually want to institute all the updates - // *simultaneously*, hence the two step maneuver. - let first_dc = edit.num_dynamic_constants() + 10000; - for (dc_a, dc_n) in zip(dcs_a, first_dc..) { - substitute_dynamic_constants_in_node( - *dc_a, - DynamicConstantID::new(dc_n), - &mut node, - &mut edit, - ); - } - for (dc_n, dc_b) in zip(first_dc.., dcs_b.iter()) { - substitute_dynamic_constants_in_node( - DynamicConstantID::new(dc_n), - *dc_b, - &mut node, - &mut edit, - ); - } + substitute_dynamic_constants_in_node(&substs, &mut node, &mut edit); } let mut uses = get_uses_mut(&mut node); for u in uses.as_mut() { diff --git a/hercules_opt/src/interprocedural_sroa.rs b/hercules_opt/src/interprocedural_sroa.rs index f597cd80347d94a7c927d6fe085d80f843e280eb..f22c1fe8410bdb008dbffe4acb90fa1679f9e44e 100644 --- a/hercules_opt/src/interprocedural_sroa.rs +++ b/hercules_opt/src/interprocedural_sroa.rs @@ -313,31 +313,22 @@ fn compress_return_products(editors: &mut Vec<FunctionEditor>, all_callsites_edi // If this becomes a common pattern, it would be worth creating // a better abstraction around bulk replacement. - let new_dcs = (*dynamic_constants).clone(); + let new_dcs = (*dynamic_constants).to_vec(); + let old_dcs = dc_param_idx_to_dc_id[..new_dcs.len()].to_vec(); + assert_eq!(old_dcs.len(), new_dcs.len()); + let substs = old_dcs + .into_iter() + .zip(new_dcs.into_iter()) + .collect::<HashMap<_, _>>(); let edit_successful = editor.edit(|mut edit| { - let old_dcs = dc_param_idx_to_dc_id[..new_dcs.len()].to_vec().clone(); let mut substituted = old_return_type_ids[function_id.idx()]; - assert_eq!(old_dcs.len(), new_dcs.len()); - let first_dc = edit.num_dynamic_constants() + 10000; - for (dc_a, dc_n) in zip(old_dcs, first_dc..) { - substituted = substitute_dynamic_constants_in_type( - dc_a, - DynamicConstantID::new(dc_n), - substituted, - &mut edit, - ); - } - - for (dc_n, dc_b) in zip(first_dc.., new_dcs.iter()) { - substituted = substitute_dynamic_constants_in_type( - DynamicConstantID::new(dc_n), - *dc_b, - substituted, - &mut edit, - ); - } + let substituted = substitute_dynamic_constants_in_type( + &substs, + old_return_type_ids[function_id.idx()], + &mut edit, + ); let (expanded_product, readers) = uncompress_product(&mut edit, &call_node_id, &substituted); @@ -419,34 +410,26 @@ fn remove_return_singletons(editors: &mut Vec<FunctionEditor>, all_callsites_edi for call_node_id in call_node_ids { let (_, function, dc_args, _) = editor.func().nodes[call_node_id.idx()].try_call().unwrap(); - let dc_args = dc_args.clone(); + + let dc_args = dc_args.to_vec(); if singleton_removed[function.idx()] { let edit_successful = editor.edit(|mut edit| { - let mut substituted = old_return_type_ids[function.idx()]; - let first_dc = edit.num_dynamic_constants() + 10000; - let dc_params: Vec<_> = (0..dc_args.len()) + let dc_params = (0..dc_args.len()) .map(|param_idx| { edit.add_dynamic_constant(DynamicConstant::Parameter(param_idx)) }) - .collect(); - for (dc_a, dc_n) in zip(dc_params, first_dc..) { - substituted = substitute_dynamic_constants_in_type( - dc_a, - DynamicConstantID::new(dc_n), - substituted, - &mut edit, - ); - } - - for (dc_n, dc_b) in zip(first_dc.., dc_args.iter()) { - substituted = substitute_dynamic_constants_in_type( - DynamicConstantID::new(dc_n), - *dc_b, - substituted, - &mut edit, - ); - } + .collect::<Vec<_>>(); + let substs = dc_params + .into_iter() + .zip(dc_args.into_iter()) + .collect::<HashMap<_, _>>(); + + let substituted = substitute_dynamic_constants_in_type( + &substs, + old_return_type_ids[function.idx()], + &mut edit, + ); let empty_constant_id = edit.add_zero_constant(substituted); let empty_node_id = edit.add_node(Node::Constant { id: empty_constant_id, diff --git a/hercules_opt/src/lift_dc_math.rs b/hercules_opt/src/lift_dc_math.rs index afdb212064d84a0191f87ce366d67b7ea6728fa8..8256c889085a9b2902c6d4d5c8fd5a9fa2e77429 100644 --- a/hercules_opt/src/lift_dc_math.rs +++ b/hercules_opt/src/lift_dc_math.rs @@ -41,11 +41,11 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) { continue; }; match op { - BinaryOperator::Add => DynamicConstant::Add(left, right), - BinaryOperator::Sub => DynamicConstant::Sub(left, right), - BinaryOperator::Mul => DynamicConstant::Mul(left, right), - BinaryOperator::Div => DynamicConstant::Div(left, right), - BinaryOperator::Rem => DynamicConstant::Rem(left, right), + BinaryOperator::Add => DynamicConstant::add(left, right), + BinaryOperator::Sub => DynamicConstant::sub(left, right), + BinaryOperator::Mul => DynamicConstant::mul(left, right), + BinaryOperator::Div => DynamicConstant::div(left, right), + BinaryOperator::Rem => DynamicConstant::rem(left, right), _ => { continue; } @@ -64,8 +64,8 @@ pub fn lift_dc_math(editor: &mut FunctionEditor) { continue; }; match intrinsic { - Intrinsic::Min => DynamicConstant::Min(left, right), - Intrinsic::Max => DynamicConstant::Max(left, right), + Intrinsic::Min => DynamicConstant::min(left, right), + Intrinsic::Max => DynamicConstant::max(left, right), _ => { continue; } diff --git a/hercules_opt/src/utils.rs b/hercules_opt/src/utils.rs index 7ad48c1c09cc8542d1b521e3d8e12fe271ef1d39..2ab4e094a47f2ce1805b924560a51f30d12951d6 100644 --- a/hercules_opt/src/utils.rs +++ b/hercules_opt/src/utils.rs @@ -1,5 +1,4 @@ -use std::collections::HashMap; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::iter::zip; use hercules_ir::def_use::*; @@ -9,12 +8,11 @@ use nestify::nest; use crate::*; /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * type. Return the substituted version of the type, once memozied. + * Substitute all uses of dynamic constants in a type that are keys in the substs map with the + * dynamic constant value for that key. Return the substituted version of the type, once memoized. */ pub(crate) fn substitute_dynamic_constants_in_type( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, ty: TypeID, edit: &mut FunctionEdit, ) -> TypeID { @@ -24,7 +22,7 @@ pub(crate) fn substitute_dynamic_constants_in_type( Type::Product(ref fields) => { let new_fields = fields .into_iter() - .map(|field_id| substitute_dynamic_constants_in_type(dc_a, dc_b, *field_id, edit)) + .map(|field_id| substitute_dynamic_constants_in_type(substs, *field_id, edit)) .collect(); if new_fields != *fields { edit.add_type(Type::Product(new_fields)) @@ -35,9 +33,7 @@ pub(crate) fn substitute_dynamic_constants_in_type( Type::Summation(ref variants) => { let new_variants = variants .into_iter() - .map(|variant_id| { - substitute_dynamic_constants_in_type(dc_a, dc_b, *variant_id, edit) - }) + .map(|variant_id| substitute_dynamic_constants_in_type(substs, *variant_id, edit)) .collect(); if new_variants != *variants { edit.add_type(Type::Summation(new_variants)) @@ -46,10 +42,10 @@ pub(crate) fn substitute_dynamic_constants_in_type( } } Type::Array(elem_ty, ref dims) => { - let new_elem_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, elem_ty, edit); + let new_elem_ty = substitute_dynamic_constants_in_type(substs, elem_ty, edit); let new_dims = dims .into_iter() - .map(|dim_id| substitute_dynamic_constants(dc_a, dc_b, *dim_id, edit)) + .map(|dim_id| substitute_dynamic_constants(substs, *dim_id, edit)) .collect(); if new_elem_ty != elem_ty || new_dims != *dims { edit.add_type(Type::Array(new_elem_ty, new_dims)) @@ -62,107 +58,105 @@ pub(crate) fn substitute_dynamic_constants_in_type( } /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * dynamic constant C. Return the substituted version of C, once memoized. Takes - * a mutable edit instead of an editor since this may create new dynamic - * constants, which can only be done inside an edit. + * Substitute all uses of dynamic constants in a dynamic constant dc that are keys in the + * substs map and replace them with their appropriate replacement values. Return the substituted + * version of dc, once memoized. Takes a mutable edit instead of an editor since this may create + * new dynamic constants, which can only be done inside an edit. */ pub(crate) fn substitute_dynamic_constants( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, - dc_c: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, + dc: DynamicConstantID, edit: &mut FunctionEdit, ) -> DynamicConstantID { - // If C is just A, then just replace all of C with B. - if dc_a == dc_c { - return dc_b; + // If this dynamic constant should be substituted, just return the substitution + if let Some(subst) = substs.get(&dc) { + return *subst; } - // Since we substitute non-sense dynamic constant IDs earlier, we explicitly - // check that the provided ID to replace inside of is valid. Otherwise, - // ignore. - if dc_c.idx() >= edit.num_dynamic_constants() { - return dc_c; - } - - // If C is not just A, look inside of it to possibly substitute a child DC. - let dc_clone = edit.get_dynamic_constant(dc_c).clone(); + // Look inside the dynamic constant to perform substitution in its children + let dc_clone = edit.get_dynamic_constant(dc).clone(); match dc_clone { - DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc_c, - // This is a certified Rust moment. - DynamicConstant::Add(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Add(new_left, new_right)) + DynamicConstant::Constant(_) | DynamicConstant::Parameter(_) => dc, + DynamicConstant::Add(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Add(new_xs)) } else { - dc_c + dc } } DynamicConstant::Sub(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); + let new_left = substitute_dynamic_constants(substs, left, edit); + let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Sub(new_left, new_right)) } else { - dc_c + dc } } - DynamicConstant::Mul(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Mul(new_left, new_right)) + DynamicConstant::Mul(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Mul(new_xs)) } else { - dc_c + dc } } DynamicConstant::Div(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); + let new_left = substitute_dynamic_constants(substs, left, edit); + let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Div(new_left, new_right)) } else { - dc_c + dc } } DynamicConstant::Rem(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); + let new_left = substitute_dynamic_constants(substs, left, edit); + let new_right = substitute_dynamic_constants(substs, right, edit); if new_left != left || new_right != right { edit.add_dynamic_constant(DynamicConstant::Rem(new_left, new_right)) } else { - dc_c + dc } } - DynamicConstant::Min(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Min(new_left, new_right)) + DynamicConstant::Min(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Min(new_xs)) } else { - dc_c + dc } } - DynamicConstant::Max(left, right) => { - let new_left = substitute_dynamic_constants(dc_a, dc_b, left, edit); - let new_right = substitute_dynamic_constants(dc_a, dc_b, right, edit); - if new_left != left || new_right != right { - edit.add_dynamic_constant(DynamicConstant::Max(new_left, new_right)) + DynamicConstant::Max(xs) => { + let new_xs = xs + .iter() + .map(|x| substitute_dynamic_constants(substs, *x, edit)) + .collect::<Vec<_>>(); + if new_xs != xs { + edit.add_dynamic_constant(DynamicConstant::Max(new_xs)) } else { - dc_c + dc } } } } /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * constant. Return the substituted version of the constant, once memozied. + * Substitute all uses of the dynamic constants specified by the subst map in a constant. Return + * the substituted version of the constant, once memozied. */ pub(crate) fn substitute_dynamic_constants_in_constant( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, cons: ConstantID, edit: &mut FunctionEdit, ) -> ConstantID { @@ -170,12 +164,10 @@ pub(crate) fn substitute_dynamic_constants_in_constant( let cons_clone = edit.get_constant(cons).clone(); match cons_clone { Constant::Product(ty, fields) => { - let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); + let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); let new_fields = fields .iter() - .map(|field_id| { - substitute_dynamic_constants_in_constant(dc_a, dc_b, *field_id, edit) - }) + .map(|field_id| substitute_dynamic_constants_in_constant(substs, *field_id, edit)) .collect(); if new_ty != ty || new_fields != fields { edit.add_constant(Constant::Product(new_ty, new_fields)) @@ -184,8 +176,8 @@ pub(crate) fn substitute_dynamic_constants_in_constant( } } Constant::Summation(ty, idx, variant) => { - let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); - let new_variant = substitute_dynamic_constants_in_constant(dc_a, dc_b, variant, edit); + let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); + let new_variant = substitute_dynamic_constants_in_constant(substs, variant, edit); if new_ty != ty || new_variant != variant { edit.add_constant(Constant::Summation(new_ty, idx, new_variant)) } else { @@ -193,7 +185,7 @@ pub(crate) fn substitute_dynamic_constants_in_constant( } } Constant::Array(ty) => { - let new_ty = substitute_dynamic_constants_in_type(dc_a, dc_b, ty, edit); + let new_ty = substitute_dynamic_constants_in_type(substs, ty, edit); if new_ty != ty { edit.add_constant(Constant::Array(new_ty)) } else { @@ -205,12 +197,10 @@ pub(crate) fn substitute_dynamic_constants_in_constant( } /* - * Substitute all uses of a dynamic constant A with dynamic constant B in a - * node. + * Substitute all uses of the dynamic constants specified by the subst map in a node. */ pub(crate) fn substitute_dynamic_constants_in_node( - dc_a: DynamicConstantID, - dc_b: DynamicConstantID, + substs: &HashMap<DynamicConstantID, DynamicConstantID>, node: &mut Node, edit: &mut FunctionEdit, ) { @@ -220,14 +210,14 @@ pub(crate) fn substitute_dynamic_constants_in_node( factors, } => { for factor in factors.into_iter() { - *factor = substitute_dynamic_constants(dc_a, dc_b, *factor, edit); + *factor = substitute_dynamic_constants(substs, *factor, edit); } } Node::Constant { id } => { - *id = substitute_dynamic_constants_in_constant(dc_a, dc_b, *id, edit); + *id = substitute_dynamic_constants_in_constant(substs, *id, edit); } Node::DynamicConstant { id } => { - *id = substitute_dynamic_constants(dc_a, dc_b, *id, edit); + *id = substitute_dynamic_constants(substs, *id, edit); } Node::Call { control: _, @@ -236,7 +226,7 @@ pub(crate) fn substitute_dynamic_constants_in_node( args: _, } => { for dc_arg in dynamic_constants.into_iter() { - *dc_arg = substitute_dynamic_constants(dc_a, dc_b, *dc_arg, edit); + *dc_arg = substitute_dynamic_constants(substs, *dc_arg, edit); } } _ => {} diff --git a/hercules_test/hercules_interpreter/src/interpreter.rs b/hercules_test/hercules_interpreter/src/interpreter.rs index a78330e4f08075be053593b41dba0f412687f5f1..871e304a2f8fb285cc9d8c64d4aa62ec5eef3a1d 100644 --- a/hercules_test/hercules_interpreter/src/interpreter.rs +++ b/hercules_test/hercules_interpreter/src/interpreter.rs @@ -69,17 +69,17 @@ pub fn dyn_const_value( match dc { DynamicConstant::Constant(v) => *v, DynamicConstant::Parameter(v) => dyn_const_params[*v], - DynamicConstant::Add(a, b) => { - dyn_const_value(a, dyn_const_values, dyn_const_params) - + dyn_const_value(b, dyn_const_values, dyn_const_params) + DynamicConstant::Add(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(0, |s, v| s + v) } DynamicConstant::Sub(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) - dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Mul(a, b) => { - dyn_const_value(a, dyn_const_values, dyn_const_params) - * dyn_const_value(b, dyn_const_values, dyn_const_params) + DynamicConstant::Mul(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(1, |p, v| p * v) } DynamicConstant::Div(a, b) => { dyn_const_value(a, dyn_const_values, dyn_const_params) @@ -89,14 +89,28 @@ pub fn dyn_const_value( dyn_const_value(a, dyn_const_values, dyn_const_params) % dyn_const_value(b, dyn_const_values, dyn_const_params) } - DynamicConstant::Max(a, b) => max( - dyn_const_value(a, dyn_const_values, dyn_const_params), - dyn_const_value(b, dyn_const_values, dyn_const_params), - ), - DynamicConstant::Min(a, b) => min( - dyn_const_value(a, dyn_const_values, dyn_const_params), - dyn_const_value(b, dyn_const_values, dyn_const_params), - ), + DynamicConstant::Max(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(max(m, v)) + } else { + Some(v) + } + }) + .unwrap() + } + DynamicConstant::Min(xs) => { + xs.iter().map(|x| dyn_const_value(x, dyn_const_values, dyn_const_params)) + .fold(None, |m, v| { + if let Some(m) = m { + Some(min(m, v)) + } else { + Some(v) + } + }) + .unwrap() + } } } diff --git a/juno_frontend/src/dynconst.rs b/juno_frontend/src/dynconst.rs index defab822d19ebb99191ad5a7d387916247dcd7df..511dfa341e4bcb8100f5f6761e068f9a803257e3 100644 --- a/juno_frontend/src/dynconst.rs +++ b/juno_frontend/src/dynconst.rs @@ -291,16 +291,16 @@ impl DynConst { .map(|(d, c)| self.build_mono(builder, d, c)) .partition(|(_, neg)| !*neg); - let pos_sum = pos - .into_iter() - .map(|(t, _)| t) - .reduce(|x, y| builder.create_dynamic_constant_add(x, y)) - .unwrap_or_else(|| builder.create_dynamic_constant_constant(0)); + let pos_sum = + builder.create_dynamic_constant_add_many(pos.into_iter().map(|(t, _)| t).collect()); - let neg_sum = neg - .into_iter() - .map(|(t, _)| t) - .reduce(|x, y| builder.create_dynamic_constant_add(x, y)); + let neg_sum = if neg.is_empty() { + None + } else { + Some( + builder.create_dynamic_constant_add_many(neg.into_iter().map(|(t, _)| t).collect()), + ) + }; match neg_sum { None => pos_sum, @@ -317,72 +317,61 @@ impl DynConst { term: &Vec<i64>, coeff: &Ratio<i64>, ) -> (DynamicConstantID, bool) { - let term_id = term + let (pos, neg): (Vec<_>, Vec<_>) = term .iter() .enumerate() .filter(|(_, p)| **p != 0) .map(|(v, p)| self.build_power(builder, v, *p)) - .collect::<Vec<_>>() - .into_iter() - .reduce(|x, y| builder.create_dynamic_constant_mul(x, y)); - - match term_id { - None => { - // This means all powers of the term are 0, so we just - // output the coefficient - if !coeff.is_integer() { - panic!("Dynamic constant is a non-integer constant") - } else { - let val: i64 = coeff.to_integer(); - ( - builder.create_dynamic_constant_constant(val.abs() as usize), - val < 0, - ) - } - } - Some(term) => { - if coeff.is_one() { - (term, false) - } else { - let numer: i64 = coeff.numer().abs(); - let denom: i64 = *coeff.denom(); // > 0 - - let with_numer = if numer == 1 { - term - } else { - let numer_id = builder.create_dynamic_constant_constant(numer as usize); - builder.create_dynamic_constant_mul(numer_id, term) - }; - let with_denom = if denom == 1 { - with_numer - } else { - let denom_id = builder.create_dynamic_constant_constant(denom as usize); - builder.create_dynamic_constant_div(with_numer, denom_id) - }; - - (with_denom, numer < 0) - } + .partition(|(_, neg)| !*neg); + let mut pos: Vec<_> = pos.into_iter().map(|(t, _)| t).collect(); + let mut neg: Vec<_> = neg.into_iter().map(|(t, _)| t).collect(); + + let numerator = { + let numer: i64 = coeff.numer().abs(); + let numer_dc = builder.create_dynamic_constant_constant(numer as usize); + pos.push(numer_dc); + builder.create_dynamic_constant_mul_many(pos) + }; + + let denominator = { + let denom: i64 = *coeff.denom(); + assert!(denom > 0); + + if neg.is_empty() && denom == 1 { + None + } else { + let denom_dc = builder.create_dynamic_constant_constant(denom as usize); + neg.push(denom_dc); + Some(builder.create_dynamic_constant_mul_many(neg)) } + }; + + if let Some(denominator) = denominator { + ( + builder.create_dynamic_constant_div(numerator, denominator), + *coeff.numer() < 0, + ) + } else { + (numerator, *coeff.numer() < 0) } } // Build's a dynamic constant that is a certain power of a specific variable - fn build_power(&self, builder: &mut Builder, v: usize, power: i64) -> DynamicConstantID { + // Returns the dynamic constant id of variable raised to the absolute value of the power and a + // boolean indicating whether the power is actually negative + fn build_power( + &self, + builder: &mut Builder, + v: usize, + power: i64, + ) -> (DynamicConstantID, bool) { assert!(power != 0); let power_pos = power.abs() as usize; let var_id = builder.create_dynamic_constant_parameter(v); - let power_id = iter::repeat(var_id) - .take(power_pos) - .map(|_| var_id) - .reduce(|x, y| builder.create_dynamic_constant_mul(x, y)) - .expect("Power is non-zero"); + let power_id = + builder.create_dynamic_constant_mul_many((0..power_pos).map(|_| var_id).collect()); - if power > 0 { - power_id - } else { - let one_id = builder.create_dynamic_constant_constant(1); - builder.create_dynamic_constant_div(one_id, power_id) - } + (power_id, power < 0) } } diff --git a/juno_frontend/src/lang.l b/juno_frontend/src/lang.l index d54a54d773f1094c5e1355f3622ee861c10e64ac..94a12002131d8fb439b8e8cabc4cae4245d00ead 100644 --- a/juno_frontend/src/lang.l +++ b/juno_frontend/src/lang.l @@ -107,6 +107,7 @@ void "void" : ":" , "," +\.\. ".." \. "." ; ";" ~ "~" diff --git a/juno_frontend/src/lang.y b/juno_frontend/src/lang.y index e980773f18e816b1fe3b01995e71b00016c2f88b..b47186ff24ab91f4650adf6ffaed10b2ba0dd251 100644 --- a/juno_frontend/src/lang.y +++ b/juno_frontend/src/lang.y @@ -197,8 +197,10 @@ Pattern -> Result<Pattern, ()> Ok(Pattern::IntLit { span : span, base : base }) } | PackageName { Ok(Pattern::Variable { span : $span, name : $1? }) } | '(' PatternsComma ')' { Ok(Pattern::TuplePattern { span : $span, pats : $2? }) } - | PackageName '{' NamePatterns '}' - { Ok(Pattern::StructPattern { span : $span, name : $1?, pats : $3? }) } + | PackageName '{' StructPatterns '}' + { let (pats, ignore_other) = $3?; + let pats = pats.into_iter().collect(); + Ok(Pattern::StructPattern { span : $span, name : $1?, pats, ignore_other }) } | PackageName '(' PatternsComma ')' { Ok(Pattern::UnionPattern { span : $span, name : $1?, pats : $3? }) } ; @@ -211,13 +213,23 @@ PatternsCommaS -> Result<Vec<Pattern>, ()> : Pattern { Ok(vec![$1?]) } | PatternsCommaS ',' Pattern { flatten($1, $3) } ; -NamePatterns -> Result<Vec<(Id, Pattern)>, ()> - : 'ID' ':' Pattern { Ok(vec![(span_of_tok($1)?, $3?)]) } - | NamePatternsS ',' 'ID' ':' Pattern { flatten($1, res_pair(span_of_tok($3), $5)) } - ; -NamePatternsS -> Result<Vec<(Id, Pattern)>, ()> - : 'ID' ':' Pattern { Ok(vec![(span_of_tok($1)?, $3?)]) } - | NamePatternsS ',' 'ID' ':' Pattern { flatten($1, res_pair(span_of_tok($3), $5)) } +StructPatterns -> Result<(VecDeque<(Id, Pattern)>, bool), ()> + : { Ok((VecDeque::new(), false)) } + | '..' { Ok((VecDeque::new(), true)) } + | 'ID' { let span = span_of_tok($1)?; + let pattern = Pattern::Variable { span, name: vec![span] }; + Ok((VecDeque::from([(span, pattern)]), false)) } + | 'ID' ':' Pattern { let span = span_of_tok($1)?; + Ok((VecDeque::from([(span, $3?)]), false)) } + | 'ID' ',' StructPatterns { let span = span_of_tok($1)?; + let pattern = Pattern::Variable { span, name: vec![span] }; + let (mut fields, ignore_rest) = $3?; + fields.push_front((span, pattern)); + Ok((fields, ignore_rest)) } + | 'ID' ':' Pattern ',' StructPatterns { let span = span_of_tok($1)?; + let (mut fields, ignore_rest) = $5?; + fields.push_front((span, $3?)); + Ok((fields, ignore_rest)) } ; Stmt -> Result<Stmt, ()> @@ -683,7 +695,8 @@ pub enum Pattern { IntLit { span : Span, base : IntBase }, Variable { span : Span, name : PackageName }, TuplePattern { span : Span, pats : Vec<Pattern> }, - StructPattern { span : Span, name : PackageName, pats : Vec<(Id, Pattern)> }, + // Ignore other indicates the pattern ended with .. and so there may be other fields that were not listed + StructPattern { span : Span, name : PackageName, pats : Vec<(Id, Pattern)>, ignore_other: bool }, UnionPattern { span : Span, name : PackageName, pats : Vec<Pattern> }, } diff --git a/juno_frontend/src/semant.rs b/juno_frontend/src/semant.rs index 2fe4bf88278c2478749a0ed67451d5371c32d847..e133e3c20b7590eb372756cfdbce1f732d57d4f6 100644 --- a/juno_frontend/src/semant.rs +++ b/juno_frontend/src/semant.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, LinkedList}; +use std::collections::{HashMap, HashSet, LinkedList}; use std::fmt; use std::fs::File; use std::io::Read; @@ -720,52 +720,49 @@ fn analyze_program( } // Process arguments - let mut arg_types: Vec<(usize, Type, bool)> = vec![]; // list of name, type, and - // whether is inout - let mut inout_args = vec![]; // list of indices into args + // We collect the list of the argument types, whether they are inout, and their + // unique variable number + let mut arg_info: Vec<(Type, bool, usize)> = vec![]; + // We collect the list of the types and variable numbers of the inout arguments + let mut inouts: Vec<(Type, usize)> = vec![]; // A collection of errors we encounter processing the arguments let mut errors = LinkedList::new(); + // Any statements that need to go at the beginning of the function to handle + // patterns in the arguments + let mut stmts = vec![]; for (inout, VarBind { span, pattern, typ }) in args { let typ = typ.unwrap_or(parser::Type::WildType { span }); - match pattern { - Pattern::Variable { span, name } => { - if name.len() != 1 { - errors.push_back( - ErrorMessage::SemanticError( - span_to_loc(span, lexer), - "Bound variables must be local names, without a package separator".to_string())); - continue; + match process_type( + typ, + num_dyn_const, + lexer, + &mut stringtab, + &env, + &mut types, + true, + ) { + Ok(ty) => { + let var = env.uniq(); + + if inout.is_some() { + inouts.push((ty, var)); } + arg_info.push((ty, inout.is_some(), var)); - let nm = intern_package_name(&name, lexer, &mut stringtab)[0]; - match process_type( - typ, - num_dyn_const, - lexer, - &mut stringtab, - &env, - &mut types, - true, - ) { - Ok(ty) => { - if inout.is_some() { - inout_args.push(arg_types.len()); - } - arg_types.push((nm, ty, inout.is_some())); + match process_irrefutable_pattern(pattern, false, var, ty, lexer, &mut stringtab, &mut env, &mut types) { + Ok(prep) => { + stmts.extend(prep); } Err(mut errs) => { errors.append(&mut errs); } } } - _ => { - errors.push_back(ErrorMessage::NotImplemented( - span_to_loc(span, lexer), - "patterns in arguments".to_string(), - )); + Err(mut errs) => { + errors.append(&mut errs); } } } @@ -798,34 +795,12 @@ fn analyze_program( } // Compute the proper type accounting for the inouts (which become returns) - let mut inout_types = vec![]; - for arg_idx in &inout_args { - inout_types.push(arg_types[*arg_idx].1.clone()); - } + let mut inout_types = inouts.iter().map(|(t, _)| *t).collect::<Vec<_>>(); let inout_tuple = types.new_tuple(inout_types.clone()); let pure_return_type = types.new_tuple(vec![return_type, inout_tuple]); - // Add the arguments to the environment and assign each a unique variable number - // Also track the variable numbers of the inout arguments for generating returns - let mut arg_variables = vec![]; - let mut inout_variables = vec![]; - for (nm, ty, is_inout) in arg_types.iter() { - let variable = env.uniq(); - env.insert( - *nm, - Entity::Variable { - variable: variable, - typ: *ty, - is_const: false, - }, - ); - arg_variables.push(variable); - - if *is_inout { - inout_variables.push(variable); - } - } + let inout_variables = inouts.iter().map(|(_, v)| *v).collect::<Vec<_>>(); // Finally, we have a properly built environment and we can // start processing the body @@ -871,6 +846,10 @@ fn analyze_program( } } + // Add the code for initializing arguments + stmts.push(body); + body = Stmt::BlockStmt { body: stmts }; + env.close_scope(); // Add the function to the global environment @@ -880,9 +859,9 @@ fn analyze_program( Entity::Function { index: res.len(), type_args: type_kinds, - args: arg_types + args: arg_info .iter() - .map(|(_, ty, is)| (*ty, *is)) + .map(|(ty, is, _)| (*ty, *is)) .collect::<Vec<_>>(), return_type: return_type, }, @@ -893,10 +872,9 @@ fn analyze_program( name: lexer.span_str(name).to_string(), num_dyn_consts: num_dyn_const, num_type_args: num_type_var, - arguments: arg_types + arguments: arg_info .iter() - .zip(arg_variables.iter()) - .map(|(v, n)| (*n, v.1)) + .map(|(t, _, v)| (*v, *t)) .collect::<Vec<_>>(), return_type: pure_return_type, body: body, @@ -1626,7 +1604,7 @@ fn process_stmt( ) -> Result<(Stmt, bool), ErrorMessages> { match stmt { parser::Stmt::LetStmt { - span: _, + span, var: VarBind { span: v_span, @@ -1634,82 +1612,66 @@ fn process_stmt( typ, }, init, - } => match pattern { - Pattern::Variable { span, name } => { - if typ.is_none() && init.is_none() { - Err(singleton_error(ErrorMessage::SemanticError( - span_to_loc(span, lexer), - "Must specify either type or initial value".to_string(), - )))? - } - if name.len() != 1 { - Err(singleton_error(ErrorMessage::SemanticError( - span_to_loc(span, lexer), - "Bound variables must be local names, without a package separator" - .to_string(), - )))? - } - - let nm = intern_package_name(&name, lexer, stringtab)[0]; - let ty = match typ { - None => None, - Some(t) => Some(process_type( - t, - num_dyn_const, - lexer, - stringtab, - env, - types, - true, - )?), - }; + } => { + if typ.is_none() && init.is_none() { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Must specify either type or initial value".to_string(), + ))); + } - let var = env.uniq(); + let ty = match typ { + None => None, + Some(t) => Some(process_type( + t, + num_dyn_const, + lexer, + stringtab, + env, + types, + true, + )?), + }; - let (val, exp_loc) = match init { - Some(exp) => { - let loc = span_to_loc(exp.span(), lexer); - ( - process_expr(exp, num_dyn_const, lexer, stringtab, env, types)?, - loc, - ) - } - None => ( - Expr::Zero { - typ: ty.expect("From Above"), - }, - Location::fake(), - ), - }; - let typ = val.get_type(); + let var = env.uniq(); - env.insert( - nm, - Entity::Variable { - variable: var, - typ: typ, - is_const: false, + let (val, exp_loc) = match init { + Some(exp) => { + let loc = span_to_loc(exp.span(), lexer); + ( + process_expr(exp, num_dyn_const, lexer, stringtab, env, types)?, + loc, + ) + } + None => ( + Expr::Zero { + typ: ty.expect("From Above"), }, - ); + Location::fake(), + ), + }; + let typ = val.get_type(); - match ty { - Some(ty) if !types.unify(ty, typ) => { - Err(singleton_error(ErrorMessage::TypeError( - exp_loc, - unparse_type(types, ty, stringtab), - unparse_type(types, typ, stringtab), - )))? - } - _ => Ok((Stmt::AssignStmt { var: var, val: val }, true)), + if let Some(ty) = ty { + if !types.unify(ty, typ) { + return Err(singleton_error(ErrorMessage::TypeError( + exp_loc, + unparse_type(types, ty, stringtab), + unparse_type(types, typ, stringtab), + ))); } } - _ => Err(singleton_error(ErrorMessage::NotImplemented( - span_to_loc(v_span, lexer), - "non-variable bindings".to_string(), - ))), - }, + + let mut res = vec![]; + res.push(Stmt::AssignStmt { var, val }); + res.extend(process_irrefutable_pattern( + pattern, false, var, typ, lexer, stringtab, env, types, + )?); + + Ok((Stmt::BlockStmt { body: res }, true)) + } parser::Stmt::ConstStmt { - span: _, + span, var: VarBind { span: v_span, @@ -1717,80 +1679,64 @@ fn process_stmt( typ, }, init, - } => match pattern { - Pattern::Variable { span, name } => { - if typ.is_none() && init.is_none() { - Err(singleton_error(ErrorMessage::SemanticError( - span_to_loc(span, lexer), - "Must specify either type or initial value".to_string(), - )))? - } - if name.len() != 1 { - Err(singleton_error(ErrorMessage::SemanticError( - span_to_loc(span, lexer), - "Bound variables must be local names, without a package separator" - .to_string(), - )))? - } - - let nm = intern_package_name(&name, lexer, stringtab)[0]; - let ty = match typ { - None => None, - Some(t) => Some(process_type( - t, - num_dyn_const, - lexer, - stringtab, - env, - types, - true, - )?), - }; + } => { + if typ.is_none() && init.is_none() { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Must specify either type or initial value".to_string(), + ))); + } - let var = env.uniq(); + let ty = match typ { + None => None, + Some(t) => Some(process_type( + t, + num_dyn_const, + lexer, + stringtab, + env, + types, + true, + )?), + }; - let (val, exp_loc) = match init { - Some(exp) => { - let loc = span_to_loc(exp.span(), lexer); - ( - process_expr(exp, num_dyn_const, lexer, stringtab, env, types)?, - loc, - ) - } - None => ( - Expr::Zero { - typ: ty.expect("From Above"), - }, - Location::fake(), - ), - }; - let typ = val.get_type(); + let var = env.uniq(); - env.insert( - nm, - Entity::Variable { - variable: var, - typ: typ, - is_const: true, + let (val, exp_loc) = match init { + Some(exp) => { + let loc = span_to_loc(exp.span(), lexer); + ( + process_expr(exp, num_dyn_const, lexer, stringtab, env, types)?, + loc, + ) + } + None => ( + Expr::Zero { + typ: ty.expect("From Above"), }, - ); + Location::fake(), + ), + }; + let typ = val.get_type(); - match ty { - Some(ty) if !types.unify(ty, typ) => { - Err(singleton_error(ErrorMessage::TypeError( - exp_loc, - unparse_type(types, ty, stringtab), - unparse_type(types, typ, stringtab), - ))) - } - _ => Ok((Stmt::AssignStmt { var: var, val: val }, true)), + if let Some(ty) = ty { + if !types.unify(ty, typ) { + return Err(singleton_error(ErrorMessage::TypeError( + exp_loc, + unparse_type(types, ty, stringtab), + unparse_type(types, typ, stringtab), + ))); } } - _ => Err(singleton_error(ErrorMessage::NotImplemented( - span_to_loc(v_span, lexer), - "non-variable bindings".to_string(), - ))), - }, + + let mut res = vec![]; + res.push(Stmt::AssignStmt { var, val }); + res.extend(process_irrefutable_pattern( + pattern, false, var, typ, lexer, stringtab, env, types, + )?); + + Ok((Stmt::BlockStmt { body: res }, true)) + } parser::Stmt::AssignStmt { span: _, lhs, @@ -2070,9 +2016,9 @@ fn process_stmt( (var, nm, var_type) } _ => { - return Err(singleton_error(ErrorMessage::NotImplemented( + return Err(singleton_error(ErrorMessage::SemanticError( span_to_loc(v_span, lexer), - "patterns in for loop arguments".to_string(), + "for loop index must be a variable".to_string(), ))); } }; @@ -5107,3 +5053,243 @@ fn convert_primitive(prim: parser::Primitive) -> types::Primitive { parser::Primitive::Void => types::Primitive::Unit, } } + +// Processes an irrefutable pattern by extracting the pieces from the given variable which has the +// given type. Adds any variables in that pattern to the environment and returns a list of +// statements that handle the pattern +fn process_irrefutable_pattern( + pat: parser::Pattern, + is_const: bool, + var: usize, + typ: Type, + lexer: &dyn NonStreamingLexer<DefaultLexerTypes<u32>>, + stringtab: &mut StringTable, + env: &mut Env<usize, Entity>, + types: &mut TypeSolver, +) -> Result<Vec<Stmt>, ErrorMessages> { + match pat { + Pattern::Wildcard { .. } => Ok(vec![]), + Pattern::Variable { span, name } => { + if name.len() != 1 { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Bound variables must be local names, without a package separator".to_string(), + ))); + } + + let nm = intern_package_name(&name, lexer, stringtab)[0]; + let variable = env.uniq(); + env.insert( + nm, + Entity::Variable { + variable, + typ, + is_const, + }, + ); + + Ok(vec![Stmt::AssignStmt { + var: variable, + val: Expr::Variable { var, typ }, + }]) + } + Pattern::TuplePattern { span, pats } => { + let Some(fields) = types.get_fields(typ) else { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!( + "Type {} is not a tuple", + unparse_type(types, typ, stringtab), + ), + ))); + }; + let fields = fields.clone(); + + if fields.len() != pats.len() { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!( + "Expected {} fields, pattern has {}", + fields.len(), + pats.len() + ), + ))); + } + + let mut res = vec![]; + let mut errors = LinkedList::new(); + + for (idx, (pat, field)) in pats.into_iter().zip(fields.into_iter()).enumerate() { + // Extract this field from the current value + let variable = env.uniq(); + res.push(Stmt::AssignStmt { + var: variable, + val: Expr::Read { + index: vec![Index::Field(idx)], + val: Box::new(Expr::Variable { var, typ }), + typ: field, + }, + }); + + match process_irrefutable_pattern( + pat, is_const, variable, field, lexer, stringtab, env, types, + ) { + Ok(stmts) => res.extend(stmts), + Err(errs) => errors.extend(errs), + } + } + + if errors.is_empty() { + Ok(res) + } else { + Err(errors) + } + } + Pattern::StructPattern { + span, + name, + pats, + ignore_other, + } => { + if name.len() != 1 { + return Err(singleton_error(ErrorMessage::NotImplemented( + span_to_loc(span, lexer), + "packages".to_string(), + ))); + } + + let struct_nm = intern_package_name(&name, lexer, stringtab)[0]; + match env.lookup(&struct_nm) { + Some(Entity::Variable { .. }) => Err(singleton_error(ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "variable".to_string(), + ))), + Some(Entity::DynConst { .. }) => Err(singleton_error(ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "dynamic constant".to_string(), + ))), + Some(Entity::Constant { .. }) => Err(singleton_error(ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "constant".to_string(), + ))), + Some(Entity::Function { .. }) => Err(singleton_error(ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "function".to_string(), + ))), + None => Err(singleton_error(ErrorMessage::UndefinedVariable( + span_to_loc(span, lexer), + stringtab.lookup_id(struct_nm).unwrap(), + ))), + Some(Entity::Type { + type_args: _, + value: struct_typ, + }) => { + let struct_typ = *struct_typ; + + if !types.is_struct(struct_typ) { + return Err(singleton_error(ErrorMessage::KindError( + span_to_loc(span, lexer), + "struct name".to_string(), + "non-struct type".to_string(), + ))); + } + + if !types.unify(typ, struct_typ) { + return Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!( + "Expected a pattern for type {} but found pattern for type {}", + unparse_type(types, typ, stringtab), + unparse_type(types, struct_typ, stringtab) + ), + ))); + } + + // Fields that have already been used + let mut unused_fields = types + .get_field_names(struct_typ) + .unwrap() + .into_iter() + .collect::<HashSet<_>>(); + let mut res = vec![]; + let mut errors = LinkedList::new(); + + for (field_name, pat) in pats { + let field_nm = intern_id(&field_name, lexer, stringtab); + match types.get_field(struct_typ, field_nm) { + None => { + errors.push_back(ErrorMessage::SemanticError( + span_to_loc(field_name, lexer), + format!( + "Struct {} does not have field {}", + unparse_type(types, struct_typ, stringtab), + stringtab.lookup_id(field_nm).unwrap() + ), + )); + } + Some((idx, field_typ)) => { + if !unused_fields.contains(&field_nm) { + errors.push_back(ErrorMessage::SemanticError( + span_to_loc(field_name, lexer), + format!( + "Field {} appears multiple times in pattern", + stringtab.lookup_id(field_nm).unwrap() + ), + )); + } else { + unused_fields.remove(&field_nm); + let variable = env.uniq(); + res.push(Stmt::AssignStmt { + var: variable, + val: Expr::Read { + index: vec![Index::Field(idx)], + val: Box::new(Expr::Variable { var, typ }), + typ: field_typ, + }, + }); + match process_irrefutable_pattern( + pat, is_const, variable, field_typ, lexer, stringtab, env, + types, + ) { + Ok(stmts) => res.extend(stmts), + Err(errs) => errors.extend(errs), + } + } + } + } + } + + if !unused_fields.is_empty() && !ignore_other { + errors.push_back(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + format!( + "Pattern is missing fields: {}", + unused_fields + .into_iter() + .map(|i| stringtab.lookup_id(i).unwrap()) + .collect::<Vec<_>>() + .join(", ") + ), + )); + } + + if !errors.is_empty() { + Err(errors) + } else { + Ok(res) + } + } + } + } + Pattern::IntLit { span, .. } | Pattern::UnionPattern { span, .. } => { + Err(singleton_error(ErrorMessage::SemanticError( + span_to_loc(span, lexer), + "Expected an irrefutable pattern, but pattern is refutable".to_string(), + ))) + } + } +} diff --git a/juno_frontend/src/types.rs b/juno_frontend/src/types.rs index 5f907cd9370a343eb501e0c4b5b1dcef2a55f651..d4d8b23326fcf624af05eb3c8dde89e8159eeba8 100644 --- a/juno_frontend/src/types.rs +++ b/juno_frontend/src/types.rs @@ -556,15 +556,16 @@ impl TypeSolver { _ => None, } } + */ - fn get_fields(&self, Type { val } : Type) -> Vec<Type> { - match &self.types[val] { - TypeForm::Tuple(fields) => { fields.clone() }, - TypeForm::OtherType(t) => self.get_fields(*t), - _ => panic!("Internal function get_fields used on non-tuple"), - } + // Returns the types of the fields of a tuple + pub fn get_fields(&self, Type { val }: Type) -> Option<&Vec<Type>> { + match &self.types[val] { + TypeForm::Tuple { fields, .. } => Some(fields), + TypeForm::OtherType { other, .. } => self.get_fields(*other), + _ => None, } - */ + } // Return the type of the field (in a tuple) at a particular index pub fn get_index(&self, Type { val }: Type, idx: usize) -> Option<Type> { @@ -626,17 +627,18 @@ impl TypeSolver { } } - /* - pub fn get_field_names(&self, Type { val } : Type) -> Option<Vec<usize>> { - match &self.types[val] { - TypeForm::Struct { name : _, id : _, fields : _, names } => { - Some(names.keys().map(|i| *i).collect::<Vec<_>>()) - }, - TypeForm::OtherType(t) => self.get_field_names(*t), - _ => None, - } + pub fn get_field_names(&self, Type { val }: Type) -> Option<Vec<usize>> { + match &self.types[val] { + TypeForm::Struct { + name: _, + id: _, + fields: _, + names, + } => Some(names.keys().map(|i| *i).collect::<Vec<_>>()), + TypeForm::OtherType { other, .. } => self.get_field_names(*other), + _ => None, } - */ + } pub fn get_num_dimensions(&self, Type { val }: Type) -> Option<usize> { match &self.types[val] { diff --git a/juno_samples/antideps/src/antideps.jn b/juno_samples/antideps/src/antideps.jn index f40640d2d7fdec15f3b4ec18d753b01ae145cb23..85c8b9d487175d01c51f6c4b2c1dfb4d756be86e 100644 --- a/juno_samples/antideps/src/antideps.jn +++ b/juno_samples/antideps/src/antideps.jn @@ -114,7 +114,7 @@ fn very_complex_antideps(x: usize) -> usize { #[entry] fn read_chains(input : i32) -> i32 { let arrs : (i32[2], i32[2]); - let sub = arrs.0; + let (sub, _) = arrs; sub[1] = input + 7; arrs.0[1] = input + 3; let result = sub[1] + arrs.0[1]; diff --git a/juno_samples/concat/src/concat.jn b/juno_samples/concat/src/concat.jn index d901e7e17c527edcbcb01929b14adfcb37e6c642..01a549690336be138f41ce4ba5f5ca2e569e399f 100644 --- a/juno_samples/concat/src/concat.jn +++ b/juno_samples/concat/src/concat.jn @@ -22,3 +22,22 @@ fn concat_entry<a : usize, b: usize>(arr1 : i32[a], arr2 : i32[b]) -> i32 { let arr3 = concat::<i32, a, b>(arr1, arr2); return sum::<i32, a + b>(arr3); } + +#[entry] +fn concat_switch<n: usize>(b: i32, m: i32[n]) -> i32[n + 2] { + let ex : i32[2]; + ex[0] = 0; + ex[1] = 1; + + let x = concat::<_, 2, n>(ex, m); + let y = concat::<_, n, 2>(m, ex); + + let s = 0; + + s += sum::<i32, n + 2>(x); + s += sum::<i32, 2 + n>(x); + s += sum::<i32, n + 2>(y); + s += sum::<i32, 2 + n>(y); + + return if s < b then x else y; +} diff --git a/juno_samples/concat/src/main.rs b/juno_samples/concat/src/main.rs index 78932421df9890fbebfa4e136b4f85f875166130..83534c9d605c1a61c7854ac1e76d9bd7fa596a50 100644 --- a/juno_samples/concat/src/main.rs +++ b/juno_samples/concat/src/main.rs @@ -2,32 +2,16 @@ use hercules_rt::runner; use hercules_rt::HerculesCPURef; -#[cfg(feature = "cuda")] -use hercules_rt::CUDABox; juno_build::juno!("concat"); fn main() { async_std::task::block_on(async { let mut r = runner!(concat_entry); - let mut a_data = [7, 7, 0]; - let mut b_data = [7, 7, 0, 0, 7, 7]; - #[cfg(not(feature = "cuda"))] - { - let a = HerculesCPURef::from_slice(&mut a_data); - let b = HerculesCPURef::from_slice(&mut b_data); - let output = r.run(3, 6, a, b).await; - assert_eq!(output, 42); - } - #[cfg(feature = "cuda")] - { - let mut a_data = [7, 7, 0]; - let a = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a_data)); - let mut b_data = [7, 7, 0, 0, 7, 7]; - let b = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b_data)); - let output = r.run(3, 6, a.get_ref(), b.get_ref()).await; - assert_eq!(output, 42); - } + let output = r.run(7).await; + println!("{}", output); + assert_eq!(output, 42); + }); } diff --git a/juno_samples/patterns/Cargo.toml b/juno_samples/patterns/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a8dda157ff331ae9b1c5e1cb2a120db9bab3bb82 --- /dev/null +++ b/juno_samples/patterns/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "juno_patterns" +version = "0.1.0" +authors = ["Aaron Councilman <aaronjc4@illinois.edu>"] +edition = "2021" + +[[bin]] +name = "juno_patterns" +path = "src/main.rs" + +[build-dependencies] +juno_build = { path = "../../juno_build" } + +[dependencies] +juno_build = { path = "../../juno_build" } +hercules_rt = { path = "../../hercules_rt" } +with_builtin_macros = "0.1.0" +async-std = "*" diff --git a/juno_samples/patterns/build.rs b/juno_samples/patterns/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..8ac92f003e549a9aeb289001b8792ee4dcb51284 --- /dev/null +++ b/juno_samples/patterns/build.rs @@ -0,0 +1,9 @@ +use juno_build::JunoCompiler; + +fn main() { + JunoCompiler::new() + .file_in_src("patterns.jn") + .unwrap() + .build() + .unwrap(); +} diff --git a/juno_samples/patterns/src/main.rs b/juno_samples/patterns/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..5cc2e7c874c590ab4f7ce313b697ea9ed3ae3a30 --- /dev/null +++ b/juno_samples/patterns/src/main.rs @@ -0,0 +1,19 @@ +#![feature(concat_idents)] + +use hercules_rt::{runner}; + +juno_build::juno!("patterns"); + +fn main() { + async_std::task::block_on(async { + let mut r = runner!(entry); + let c = r.run(3, 8.0).await; + println!("{}", c); + assert_eq!(c, 14.0); + }); +} + +#[test] +fn simple3_test() { + main(); +} diff --git a/juno_samples/patterns/src/patterns.jn b/juno_samples/patterns/src/patterns.jn new file mode 100644 index 0000000000000000000000000000000000000000..923c258d202bca3d23c7bbd8b06531755c954ac2 --- /dev/null +++ b/juno_samples/patterns/src/patterns.jn @@ -0,0 +1,12 @@ +type Record = struct { a: i32; b: f64; }; + +#[entry] +fn entry(x: i32, f: f64) -> f64 { + let r = Record { a=x, b=f }; + let p = (f, f, x); + + let Record { a, .. } = r; + let (_, b, c) = p; + + return (a + c) as f64 + b; +}