Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • llvm/hercules
1 result
Show changes
Commits on Source (10)
Showing
with 564 additions and 1199 deletions
......@@ -259,12 +259,6 @@ dependencies = [
"arrayvec",
]
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "bincode"
version = "1.3.3"
......@@ -291,9 +285,6 @@ name = "bitflags"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be"
dependencies = [
"serde",
]
[[package]]
name = "bitstream-io"
......@@ -367,6 +358,7 @@ name = "call"
version = "0.1.0"
dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"rand",
"with_builtin_macros",
......@@ -388,6 +380,7 @@ name = "ccp"
version = "0.1.0"
dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"rand",
"with_builtin_macros",
......@@ -639,6 +632,7 @@ version = "0.1.0"
dependencies = [
"async-std",
"clap",
"hercules_rt",
"juno_build",
"rand",
"with_builtin_macros",
......@@ -825,17 +819,6 @@ dependencies = [
"serde",
]
[[package]]
name = "hercules_driver"
version = "0.1.0"
dependencies = [
"clap",
"hercules_ir",
"hercules_opt",
"postcard",
"ron",
]
[[package]]
name = "hercules_ir"
version = "0.1.0"
......@@ -1013,6 +996,7 @@ name = "juno_casts_and_intrinsics"
version = "0.1.0"
dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"with_builtin_macros",
]
......@@ -1048,6 +1032,7 @@ dependencies = [
"hercules_ir",
"hercules_opt",
"juno_scheduler",
"juno_utils",
"lrlex",
"lrpar",
"num-rational",
......@@ -1087,14 +1072,29 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_schedule_test"
version = "0.1.0"
dependencies = [
"async-std",
"hercules_rt",
"juno_build",
"rand",
"with_builtin_macros",
]
[[package]]
name = "juno_scheduler"
version = "0.0.1"
dependencies = [
"cfgrammar",
"hercules_cg",
"hercules_ir",
"hercules_opt",
"juno_utils",
"lrlex",
"lrpar",
"tempfile",
]
[[package]]
......@@ -1107,6 +1107,13 @@ dependencies = [
"with_builtin_macros",
]
[[package]]
name = "juno_utils"
version = "0.1.0"
dependencies = [
"serde",
]
[[package]]
name = "kv-log-macro"
version = "1.0.7"
......@@ -1744,18 +1751,6 @@ version = "0.8.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a"
[[package]]
name = "ron"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94"
dependencies = [
"base64",
"bitflags 2.7.0",
"serde",
"serde_derive",
]
[[package]]
name = "rustc_version"
version = "0.4.1"
......
......@@ -6,8 +6,7 @@ members = [
"hercules_opt",
"hercules_rt",
"hercules_tools/hercules_driver",
"juno_utils",
"juno_frontend",
"juno_scheduler",
"juno_build",
......@@ -27,7 +26,7 @@ members = [
"juno_samples/nested_ccp",
"juno_samples/antideps",
"juno_samples/implicit_clone",
"juno_samples/cava",
"juno_samples/concat",
"juno_samples/cava",
"juno_samples/schedule_test",
]
......@@ -638,7 +638,7 @@ impl<'a> CPUContext<'a> {
fn codegen_index_math(
&self,
collect_name: &str,
collect_ty: TypeID,
mut collect_ty: TypeID,
indices: &[Index],
body: &mut String,
) -> Result<String, Error> {
......@@ -665,11 +665,16 @@ impl<'a> CPUContext<'a> {
get_type_alignment(&self.types, fields[*idx]),
body,
)?;
acc_ptr = Self::gep(collect_name, &acc_offset, body)?;
acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?;
collect_ty = fields[*idx];
}
Index::Variant(_) => {
Index::Variant(idx) => {
// The tag of a summation is at the end of the summation, so
// the variant pointer is just the base pointer. Do nothing.
let Type::Summation(ref variants) = self.types[collect_ty.idx()] else {
panic!()
};
collect_ty = variants[*idx];
}
Index::Position(ref pos) => {
let Type::Array(elem, ref dims) = self.types[collect_ty.idx()] else {
......@@ -690,7 +695,8 @@ impl<'a> CPUContext<'a> {
// Convert offset in # elements -> # bytes.
acc_offset = Self::multiply(&acc_offset, &elem_size, body)?;
acc_ptr = Self::gep(collect_name, &acc_offset, body)?;
acc_ptr = Self::gep(&acc_ptr, &acc_offset, body)?;
collect_ty = elem;
}
}
}
......
......@@ -10,18 +10,6 @@ pub use crate::rt::*;
use hercules_ir::*;
/*
* Basic block info consists of two things:
*
* 1. A map from node to block (named by control nodes).
* 2. For each node, which nodes are in its own block.
*
* Note that for #2, the structure is Vec<NodeID>, meaning the nodes are ordered
* inside the block. This order corresponds to the traversal order of the nodes
* in the block needed by the backend code generators.
*/
pub type BasicBlocks = (Vec<NodeID>, Vec<Vec<NodeID>>);
/*
* The alignment of a type does not depend on dynamic constants.
*/
......
......@@ -53,7 +53,7 @@ impl<'a> RTContext<'a> {
// Dump the function signature.
write!(
w,
"#[allow(unused_variables,unused_mut,unused_parens)]\nasync fn {}<'a>(",
"#[allow(unused_variables,unused_mut,unused_parens,unused_unsafe)]\nasync fn {}<'a>(",
func.name
)?;
let mut first_param = true;
......@@ -149,7 +149,7 @@ impl<'a> RTContext<'a> {
// blocks to drive execution.
write!(
w,
" let mut control_token: i8 = 0;\n loop {{\n match control_token {{\n",
" let mut control_token: i8 = 0;\n let return_value = loop {{\n match control_token {{\n",
)?;
let mut blocks: BTreeMap<_, _> = (0..func.nodes.len())
......@@ -182,8 +182,41 @@ impl<'a> RTContext<'a> {
)?;
}
// Close the match, loop, and function.
write!(w, " _ => panic!()\n }}\n }}\n}}\n")?;
// Close the match and loop.
write!(w, " _ => panic!()\n }}\n }};\n")?;
// Emit the epilogue of the function.
write!(w, " unsafe {{\n")?;
for idx in 0..func.param_types.len() {
if !self.module.types[func.param_types[idx].idx()].is_primitive() {
write!(w, " p{}.__forget();\n", idx)?;
}
}
if !self.module.types[func.return_type.idx()].is_primitive() {
for object in self.collection_objects[&self.func_id].iter_objects() {
if let CollectionObjectOrigin::Constant(_) =
self.collection_objects[&self.func_id].origin(object)
{
write!(
w,
" if obj{}.__cmp_ids(&return_value) {{\n",
object.idx()
)?;
write!(w, " obj{}.__forget();\n", object.idx())?;
write!(w, " }}\n")?;
}
}
}
for idx in 0..func.nodes.len() {
if !func.nodes[idx].is_control()
&& !self.module.types[self.typing[idx].idx()].is_primitive()
{
write!(w, " node_{}.__forget();\n", idx)?;
}
}
write!(w, " }}\n")?;
write!(w, " return_value\n")?;
write!(w, "}}\n")?;
Ok(())
}
......@@ -230,7 +263,15 @@ impl<'a> RTContext<'a> {
}
Node::Return { control: _, data } => {
let block = &mut blocks.get_mut(&id).unwrap();
write!(block, " return {};\n", self.get_value(data))?
if self.module.types[self.typing[data.idx()].idx()].is_primitive() {
write!(block, " break {};\n", self.get_value(data))?
} else {
write!(
block,
" break unsafe {{ {}.__clone() }};\n",
self.get_value(data)
)?
}
}
_ => panic!("PANIC: Can't lower {:?}.", func.nodes[id.idx()]),
}
......@@ -259,7 +300,7 @@ impl<'a> RTContext<'a> {
} else {
write!(
block,
" {} = unsafe {{ p{}.__take() }};\n",
" {} = unsafe {{ p{}.__clone() }};\n",
self.get_value(id),
index
)?
......@@ -284,7 +325,7 @@ impl<'a> RTContext<'a> {
let objects = self.collection_objects[&self.func_id].objects(id);
assert_eq!(objects.len(), 1);
let object = objects[0];
write!(block, "unsafe {{ obj{}.__take() }}", object.idx())?
write!(block, "unsafe {{ obj{}.__clone() }}", object.idx())?
}
}
write!(block, ";\n")?
......@@ -374,7 +415,7 @@ impl<'a> RTContext<'a> {
)?;
write!(
block,
" {} = unsafe {{ {}.__take() }};\n",
" {} = unsafe {{ {}.__clone() }};\n",
self.get_value(id),
self.get_value(*arg)
)?;
......@@ -407,13 +448,84 @@ impl<'a> RTContext<'a> {
if self.module.types[self.typing[arg.idx()].idx()].is_primitive() {
write!(block, "{}, ", self.get_value(*arg))?;
} else {
write!(block, "unsafe {{ {}.__take() }}, ", self.get_value(*arg))?;
write!(block, "unsafe {{ {}.__clone() }}, ", self.get_value(*arg))?;
}
}
write!(block, ").await;\n")?;
}
}
}
Node::Read {
collect,
ref indices,
} => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let collect_ty = self.typing[collect.idx()];
let out_size = self.codegen_type_size(self.typing[id.idx()]);
let offset = self.codegen_index_math(collect_ty, indices)?;
write!(
block,
" let mut read_offset_obj = unsafe {{ {}.__clone() }};\n",
self.get_value(collect)
)?;
write!(
block,
" unsafe {{ read_offset_obj.__offset({}, {}) }};\n",
offset, out_size,
)?;
if self.module.types[self.typing[id.idx()].idx()].is_primitive() {
write!(
block,
" {} = unsafe {{ *(read_offset_obj.__cpu_ptr() as *const _) }};\n",
self.get_value(id)
)?;
write!(
block,
" unsafe {{ read_offset_obj.__forget() }};\n",
)?;
} else {
write!(
block,
" {} = read_offset_obj;\n",
self.get_value(id)
)?;
}
}
Node::Write {
collect,
data,
ref indices,
} => {
let block = &mut blocks.get_mut(&self.bbs.0[id.idx()]).unwrap();
let collect_ty = self.typing[collect.idx()];
let data_size = self.codegen_type_size(self.typing[data.idx()]);
let offset = self.codegen_index_math(collect_ty, indices)?;
write!(
block,
" let mut write_offset_obj = unsafe {{ {}.__clone() }};\n",
self.get_value(collect)
)?;
write!(block, " let write_offset_ptr = unsafe {{ write_offset_obj.__cpu_ptr_mut().byte_add({}) }};\n", offset)?;
if self.module.types[self.typing[data.idx()].idx()].is_primitive() {
write!(
block,
" unsafe {{ *(write_offset_ptr as *mut _) = {} }};\n",
self.get_value(data)
)?;
} else {
write!(
block,
" unsafe {{ ::core::ptr::copy_nonoverlapping({}.__cpu_ptr(), write_offset_ptr as *mut _, {} as usize) }};\n",
self.get_value(data),
data_size,
)?;
}
write!(
block,
" {} = write_offset_obj;\n",
self.get_value(id),
)?;
}
_ => panic!(
"PANIC: Can't lower {:?} in {}.",
func.nodes[id.idx()],
......@@ -487,6 +599,78 @@ impl<'a> RTContext<'a> {
Ok(())
}
/*
* Emit logic to index into an collection.
*/
fn codegen_index_math(
&self,
mut collect_ty: TypeID,
indices: &[Index],
) -> Result<String, Error> {
let mut acc_offset = "0".to_string();
for index in indices {
match index {
Index::Field(idx) => {
let Type::Product(ref fields) = self.module.types[collect_ty.idx()] else {
panic!()
};
// Get the offset of the field at index `idx` by calculating
// the product's size up to field `idx`, then offseting the
// base pointer by that amount.
for field in &fields[..*idx] {
let field_align = get_type_alignment(&self.module.types, *field);
let field = self.codegen_type_size(*field);
acc_offset = format!(
"((({} + {}) & !{}) + {})",
acc_offset,
field_align - 1,
field_align - 1,
field
);
}
let last_align = get_type_alignment(&self.module.types, fields[*idx]);
acc_offset = format!(
"(({} + {}) & !{})",
acc_offset,
last_align - 1,
last_align - 1
);
collect_ty = fields[*idx];
}
Index::Variant(idx) => {
// The tag of a summation is at the end of the summation, so
// the variant pointer is just the base pointer. Do nothing.
let Type::Summation(ref variants) = self.module.types[collect_ty.idx()] else {
panic!()
};
collect_ty = variants[*idx];
}
Index::Position(ref pos) => {
let Type::Array(elem, ref dims) = self.module.types[collect_ty.idx()] else {
panic!()
};
// The offset of the position into an array is:
//
// ((0 * s1 + p1) * s2 + p2) * s3 + p3 ...
let elem_size = self.codegen_type_size(elem);
for (p, s) in zip(pos, dims) {
let p = self.get_value(*p);
acc_offset = format!("{} * ", acc_offset);
self.codegen_dynamic_constant(*s, &mut acc_offset)?;
acc_offset = format!("({} + {})", acc_offset, p);
}
// Convert offset in # elements -> # bytes.
acc_offset = format!("({} * {})", acc_offset, elem_size);
collect_ty = elem;
}
}
}
Ok(acc_offset)
}
/*
* Lower the size of a type into a Rust expression.
*/
......
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use crate::*;
......@@ -11,10 +11,11 @@ pub struct Builder<'a> {
// Intern function names.
function_ids: HashMap<&'a str, FunctionID>,
// Intern types, constants, and dynamic constants on a per-module basis.
// Intern types, constants, dynamic constants, and labels on a per-module basis.
interned_types: HashMap<Type, TypeID>,
interned_constants: HashMap<Constant, ConstantID>,
interned_dynamic_constants: HashMap<DynamicConstant, DynamicConstantID>,
interned_labels: HashMap<String, LabelID>,
// For product, summation, and array constant creation, it's useful to know
// the type of each constant.
......@@ -37,6 +38,7 @@ pub struct NodeBuilder {
function_id: FunctionID,
node: Node,
schedules: Vec<Schedule>,
labels: Vec<LabelID>,
}
/*
......@@ -79,6 +81,17 @@ impl<'a> Builder<'a> {
}
}
pub fn add_label(&mut self, label: &String) -> LabelID {
if let Some(id) = self.interned_labels.get(label) {
*id
} else {
let id = LabelID::new(self.interned_labels.len());
self.interned_labels.insert(label.clone(), id);
self.module.labels.push(label.clone());
id
}
}
pub fn create() -> Self {
Self::default()
}
......@@ -452,6 +465,10 @@ impl<'a> Builder<'a> {
Index::Position(idx)
}
pub fn get_labels(&self, func: FunctionID, node: NodeID) -> &HashSet<LabelID> {
&self.module.functions[func.idx()].labels[node.idx()]
}
pub fn create_function(
&mut self,
name: &str,
......@@ -473,6 +490,7 @@ impl<'a> Builder<'a> {
entry,
nodes: vec![Node::Start],
schedules: vec![vec![]],
labels: vec![HashSet::new()],
device: None,
});
Ok((id, NodeID::new(0)))
......@@ -484,11 +502,15 @@ impl<'a> Builder<'a> {
.nodes
.push(Node::Start);
self.module.functions[function.idx()].schedules.push(vec![]);
self.module.functions[function.idx()]
.labels
.push(HashSet::new());
NodeBuilder {
id,
function_id: function,
node: Node::Start,
schedules: vec![],
labels: vec![],
}
}
......@@ -499,6 +521,8 @@ impl<'a> Builder<'a> {
self.module.functions[builder.function_id.idx()].nodes[builder.id.idx()] = builder.node;
self.module.functions[builder.function_id.idx()].schedules[builder.id.idx()] =
builder.schedules;
self.module.functions[builder.function_id.idx()].labels[builder.id.idx()] =
builder.labels.into_iter().collect();
Ok(())
}
}
......@@ -617,4 +641,15 @@ impl NodeBuilder {
pub fn add_schedule(&mut self, schedule: Schedule) {
self.schedules.push(schedule);
}
pub fn add_label(&mut self, label: LabelID) {
self.labels.push(label);
}
pub fn add_labels<I>(&mut self, labels: I)
where
I: Iterator<Item = LabelID>,
{
self.labels.extend(labels);
}
}
......@@ -79,10 +79,9 @@ impl CallGraph {
/*
* Top level function to calculate the call graph of a Hercules module.
*/
pub fn callgraph(module: &Module) -> CallGraph {
pub fn callgraph(functions: &Vec<Function>) -> CallGraph {
// Step 1: collect the functions called in each function.
let callee_functions: Vec<Vec<FunctionID>> = module
.functions
let callee_functions: Vec<Vec<FunctionID>> = functions
.iter()
.map(|func| {
let mut called: Vec<_> = func
......
......@@ -135,7 +135,8 @@ impl Semilattice for CollectionObjectLattice {
* Top level function to analyze collection objects in a Hercules module.
*/
pub fn collection_objects(
module: &Module,
functions: &Vec<Function>,
types: &Vec<Type>,
reverse_postorders: &Vec<Vec<NodeID>>,
typing: &ModuleTyping,
callgraph: &CallGraph,
......@@ -146,7 +147,7 @@ pub fn collection_objects(
let topo = callgraph.topo();
for func_id in topo {
let func = &module.functions[func_id.idx()];
let func = &functions[func_id.idx()];
let typing = &typing[func_id.idx()];
let reverse_postorder = &reverse_postorders[func_id.idx()];
......@@ -156,14 +157,14 @@ pub fn collection_objects(
.param_types
.iter()
.enumerate()
.filter(|(_, ty_id)| !module.types[ty_id.idx()].is_primitive())
.filter(|(_, ty_id)| !types[ty_id.idx()].is_primitive())
.map(|(idx, _)| CollectionObjectOrigin::Parameter(idx));
let other_origins = func
.nodes
.iter()
.enumerate()
.filter_map(|(idx, node)| match node {
Node::Constant { id: _ } if !module.types[typing[idx].idx()].is_primitive() => {
Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => {
Some(CollectionObjectOrigin::Constant(NodeID::new(idx)))
}
Node::Call {
......@@ -185,7 +186,7 @@ pub fn collection_objects(
// this is determined later.
Some(CollectionObjectOrigin::Call(NodeID::new(idx)))
}
Node::Undef { ty: _ } if !module.types[typing[idx].idx()].is_primitive() => {
Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => {
Some(CollectionObjectOrigin::Undef(NodeID::new(idx)))
}
_ => None,
......@@ -255,7 +256,7 @@ pub fn collection_objects(
function: callee,
dynamic_constants: _,
args: _,
} if !module.types[typing[id.idx()].idx()].is_primitive() => {
} if !types[typing[id.idx()].idx()].is_primitive() => {
let new_obj = origins
.iter()
.position(|origin| *origin == CollectionObjectOrigin::Call(id))
......@@ -285,7 +286,7 @@ pub fn collection_objects(
Node::Read {
collect: _,
indices: _,
} if !module.types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(),
} if !types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(),
Node::Write {
collect: _,
data: _,
......
......@@ -18,6 +18,7 @@ pub fn xdot_module(
reverse_postorders: &Vec<Vec<NodeID>>,
doms: Option<&Vec<DomTree>>,
fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>,
bbs: Option<&Vec<BasicBlocks>>,
) {
let mut tmp_path = temp_dir();
let mut rng = rand::thread_rng();
......@@ -30,6 +31,7 @@ pub fn xdot_module(
&reverse_postorders,
doms,
fork_join_maps,
bbs,
&mut contents,
)
.expect("PANIC: Unable to generate output file contents.");
......@@ -51,6 +53,7 @@ pub fn write_dot<W: Write>(
reverse_postorders: &Vec<Vec<NodeID>>,
doms: Option<&Vec<DomTree>>,
fork_join_maps: Option<&Vec<HashMap<NodeID, NodeID>>>,
bbs: Option<&Vec<BasicBlocks>>,
w: &mut W,
) -> std::fmt::Result {
write_digraph_header(w)?;
......@@ -165,6 +168,26 @@ pub fn write_dot<W: Write>(
}
}
// Step 4: draw basic block edges in indigo.
if let Some(bbs) = bbs {
let bbs = &bbs[function_id.idx()].0;
for (idx, bb) in bbs.into_iter().enumerate() {
if idx != bb.idx() {
write_edge(
NodeID::new(idx),
function_id,
*bb,
function_id,
true,
"indigo",
"dotted",
&module,
w,
)?;
}
}
}
write_graph_footer(w)?;
}
......
use std::cmp::{max, min};
use std::collections::HashSet;
use std::fmt::Write;
use std::ops::Coroutine;
use std::ops::CoroutineState;
......@@ -23,6 +24,7 @@ pub struct Module {
pub types: Vec<Type>,
pub constants: Vec<Constant>,
pub dynamic_constants: Vec<DynamicConstant>,
pub labels: Vec<String>,
}
/*
......@@ -43,7 +45,8 @@ pub struct Function {
pub nodes: Vec<Node>,
pub schedules: FunctionSchedule,
pub schedules: FunctionSchedules,
pub labels: FunctionLabels,
pub device: Option<Device>,
}
......@@ -341,7 +344,24 @@ pub enum Device {
/*
* A single node may have multiple schedules.
*/
pub type FunctionSchedule = Vec<Vec<Schedule>>;
pub type FunctionSchedules = Vec<Vec<Schedule>>;
/*
* A single node may have multiple labels.
*/
pub type FunctionLabels = Vec<HashSet<LabelID>>;
/*
* Basic block info consists of two things:
*
* 1. A map from node to block (named by control nodes).
* 2. For each node, which nodes are in its own block.
*
* Note that for #2, the structure is Vec<NodeID>, meaning the nodes are ordered
* inside the block. This order corresponds to the traversal order of the nodes
* in the block needed by the backend code generators.
*/
pub type BasicBlocks = (Vec<NodeID>, Vec<Vec<NodeID>>);
impl Module {
/*
......@@ -734,6 +754,7 @@ impl Function {
// Step 4: update the schedules.
self.schedules.fix_gravestones(&node_mapping);
self.labels.fix_gravestones(&node_mapping);
node_mapping
}
......@@ -1767,3 +1788,4 @@ define_id_type!(NodeID);
define_id_type!(TypeID);
define_id_type!(ConstantID);
define_id_type!(DynamicConstantID);
define_id_type!(LabelID);
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use crate::*;
......@@ -124,6 +124,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a
entry: true,
nodes: vec![],
schedules: vec![],
labels: vec![],
device: None,
};
context.function_ids.len()
......@@ -157,6 +158,7 @@ fn parse_module<'a>(ir_text: &'a str, context: Context<'a>) -> nom::IResult<&'a
types,
constants,
dynamic_constants,
labels: vec![],
};
Ok((rest, module))
}
......@@ -262,6 +264,7 @@ fn parse_function<'a>(
entry: true,
nodes: fixed_nodes,
schedules: vec![vec![]; num_nodes],
labels: vec![HashSet::new(); num_nodes],
device: None,
},
))
......
......@@ -82,19 +82,16 @@ pub type ModuleTyping = Vec<Vec<TypeID>>;
* Returns a type for every node in every function.
*/
pub fn typecheck(
module: &mut Module,
functions: &Vec<Function>,
types: &mut Vec<Type>,
constants: &Vec<Constant>,
dynamic_constants: &mut Vec<DynamicConstant>,
reverse_postorders: &Vec<Vec<NodeID>>,
) -> Result<ModuleTyping, String> {
// Step 1: assemble a reverse type map. This is needed to get or create the
// ID of potentially new types. Break down module into references to
// individual elements at this point, so that borrows don't overlap each
// other.
let Module {
ref functions,
ref mut types,
ref constants,
ref mut dynamic_constants,
} = module;
let mut reverse_type_map: HashMap<Type, TypeID> = types
.iter()
.enumerate()
......
......@@ -29,7 +29,13 @@ pub fn verify(
let reverse_postorders: Vec<_> = def_uses.iter().map(reverse_postorder).collect();
// Typecheck the module.
let typing = typecheck(module, &reverse_postorders)?;
let typing = typecheck(
&module.functions,
&mut module.types,
&module.constants,
&mut module.dynamic_constants,
&reverse_postorders,
)?;
// Assemble fork join maps for module.
let subgraphs: Vec<_> = zip(module.functions.iter(), def_uses.iter())
......
use hercules_ir::*;
use crate::*;
/*
* Top level function to collapse read chains in a function.
*/
pub fn crc(editor: &mut FunctionEditor) {
let mut changed = true;
while changed {
changed = false;
for id in editor.node_ids() {
if let Node::Read {
collect: lower_collect,
indices: ref lower_indices,
} = editor.func().nodes[id.idx()]
&& let Node::Read {
collect: upper_collect,
indices: ref upper_indices,
} = editor.func().nodes[lower_collect.idx()]
{
let collapsed_read = Node::Read {
collect: upper_collect,
indices: upper_indices
.iter()
.chain(lower_indices.iter())
.map(|idx| idx.clone())
.collect(),
};
let success = editor.edit(|mut edit| {
let new_id = edit.add_node(collapsed_read);
let edit = edit.replace_all_uses(id, new_id)?;
edit.delete_node(id)
});
changed = changed || success;
}
}
}
}
......@@ -26,6 +26,9 @@ pub struct FunctionEditor<'a> {
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
// Keep a RefCell to the string table that tracks labels, so that new labels
// can be added as needed
labels: &'a RefCell<Vec<String>>,
// Most optimizations need def use info, so provide an iteratively updated
// mutable version that's automatically updated based on recorded edits.
mut_def_use: Vec<HashSet<NodeID>>,
......@@ -34,6 +37,9 @@ pub struct FunctionEditor<'a> {
// are off limits for deletion (equivalently modification) or being replaced
// as a use.
mutable_nodes: BitVec<u8, Lsb0>,
// Tracks whether this editor has been used to make any edits to the IR of
// this function
modified: bool,
}
/*
......@@ -51,10 +57,13 @@ pub struct FunctionEdit<'a: 'b, 'b> {
added_and_updated_nodes: BTreeMap<NodeID, Node>,
// Keep track of added and updated schedules.
added_and_updated_schedules: BTreeMap<NodeID, Vec<Schedule>>,
// Keep track of added (dynamic) constants and types
// Keep track of added and updated labels.
added_and_updated_labels: BTreeMap<NodeID, HashSet<LabelID>>,
// Keep track of added (dynamic) constants, types, and labels
added_constants: Vec<Constant>,
added_dynamic_constants: Vec<DynamicConstant>,
added_types: Vec<Type>,
added_labels: Vec<String>,
// Compute a def-use map entries iteratively.
updated_def_use: BTreeMap<NodeID, HashSet<NodeID>>,
updated_param_types: Option<Vec<TypeID>>,
......@@ -70,6 +79,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
labels: &'a RefCell<Vec<String>>,
def_use: &ImmutableDefUseMap,
) -> Self {
let mut_def_use = (0..function.nodes.len())
......@@ -89,11 +99,60 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
constants,
dynamic_constants,
types,
labels,
mut_def_use,
mutable_nodes,
modified: false,
}
}
// Constructs an editor but only makes the nodes with at least one of the set of labels as
// mutable
pub fn new_labeled(
function: &'a mut Function,
function_id: FunctionID,
constants: &'a RefCell<Vec<Constant>>,
dynamic_constants: &'a RefCell<Vec<DynamicConstant>>,
types: &'a RefCell<Vec<Type>>,
labels: &'a RefCell<Vec<String>>,
def_use: &ImmutableDefUseMap,
with_labels: &HashSet<LabelID>,
) -> Self {
let mut_def_use = (0..function.nodes.len())
.map(|idx| {
def_use
.get_users(NodeID::new(idx))
.into_iter()
.map(|x| *x)
.collect()
})
.collect();
let mut mutable_nodes = bitvec![u8, Lsb0; 0; function.nodes.len()];
// Add all nodes which have some label which is in the with_labels set
for (idx, labels) in function.labels.iter().enumerate() {
if !labels.is_disjoint(with_labels) {
mutable_nodes.set(idx, true);
}
}
FunctionEditor {
function,
function_id,
constants,
dynamic_constants,
types,
labels,
mut_def_use,
mutable_nodes,
modified: false,
}
}
pub fn modified(&self) -> bool {
self.modified
}
pub fn edit<F>(&'b mut self, edit: F) -> bool
where
F: FnOnce(FunctionEdit<'a, 'b>) -> Result<FunctionEdit<'a, 'b>, FunctionEdit<'a, 'b>>,
......@@ -105,9 +164,11 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
added_nodeids: HashSet::new(),
added_and_updated_nodes: BTreeMap::new(),
added_and_updated_schedules: BTreeMap::new(),
added_and_updated_labels: BTreeMap::new(),
added_constants: Vec::new().into(),
added_dynamic_constants: Vec::new().into(),
added_types: Vec::new().into(),
added_labels: Vec::new().into(),
updated_def_use: BTreeMap::new(),
updated_param_types: None,
updated_return_type: None,
......@@ -120,17 +181,28 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
let FunctionEdit {
editor,
deleted_nodeids,
added_nodeids: _,
added_nodeids,
added_and_updated_nodes,
added_and_updated_schedules,
added_and_updated_labels,
added_constants,
added_dynamic_constants,
added_types,
added_labels,
updated_def_use,
updated_param_types,
updated_return_type,
sub_edits,
} = populated_edit;
// Step 0: determine whether the edit changed the IR by checking if
// any nodes were deleted, added, or updated in any way
editor.modified |= !deleted_nodeids.is_empty()
|| !added_nodeids.is_empty()
|| !added_and_updated_nodes.is_empty()
|| !added_and_updated_schedules.is_empty()
|| !added_and_updated_labels.is_empty();
// Step 1: update the mutable def use map.
for (u, new_users) in updated_def_use {
// Go through new def-use entries in order. These are either
......@@ -160,7 +232,7 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
}
}
// Step 3: add and update schedules.
// Step 3.0: add and update schedules.
editor
.function
.schedules
......@@ -169,6 +241,15 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
editor.function.schedules[id.idx()] = schedule;
}
// Step 3.1: add and update labels.
editor
.function
.labels
.resize(editor.function.nodes.len(), HashSet::new());
for (id, label) in added_and_updated_labels {
editor.function.labels[id.idx()] = label;
}
// Step 4: delete nodes. This is done using "gravestones", where a
// node other than node ID 0 being a start node is considered a
// gravestone.
......@@ -178,8 +259,8 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
editor.function.nodes[id.idx()] = Node::Start;
}
// Step 5: propagate schedules along sub-edit edges.
for (src, dst) in sub_edits {
// Step 5.0: propagate schedules along sub-edit edges.
for (src, dst) in sub_edits.iter() {
let mut dst_schedules = take(&mut editor.function.schedules[dst.idx()]);
for src_schedule in editor.function.schedules[src.idx()].iter() {
if !dst_schedules.contains(src_schedule) {
......@@ -189,6 +270,32 @@ impl<'a: 'b, 'b> FunctionEditor<'a> {
editor.function.schedules[dst.idx()] = dst_schedules;
}
// Step 5.1: update and propagate labels
editor.labels.borrow_mut().extend(added_labels);
// We propagate labels in two steps, first along sub-edits and then
// all the labels on any deleted node not used in any sub-edit to all
// added nodes not in any sub-edit
let mut sources = deleted_nodeids.clone();
let mut dests = added_nodeids.clone();
for (src, dst) in sub_edits {
let mut dst_labels = take(&mut editor.function.labels[dst.idx()]);
dst_labels.extend(editor.function.labels[src.idx()].iter());
editor.function.labels[dst.idx()] = dst_labels;
sources.remove(&src);
dests.remove(&dst);
}
let mut src_labels = HashSet::new();
for src in sources {
src_labels.extend(editor.function.labels[src.idx()].clone());
}
for dst in dests {
editor.function.labels[dst.idx()].extend(src_labels.clone());
}
// Step 6: update the length of mutable_nodes. All added nodes are
// mutable.
editor
......@@ -446,6 +553,57 @@ impl<'a, 'b> FunctionEdit<'a, 'b> {
}
}
pub fn get_label(&self, id: NodeID) -> &HashSet<LabelID> {
// The user may get the labels of a to-be deleted node.
if let Some(label) = self.added_and_updated_labels.get(&id) {
// Refer to added or updated label.
label
} else {
// Refer to the origin label of this code.
&self.editor.function.labels[id.idx()]
}
}
pub fn add_label(mut self, id: NodeID, label: LabelID) -> Result<Self, Self> {
if self.is_mutable(id) {
if let Some(labels) = self.added_and_updated_labels.get_mut(&id) {
labels.insert(label);
} else {
let mut labels = self.editor.function.labels[id.idx()].clone();
labels.insert(label);
self.added_and_updated_labels.insert(id, labels);
}
Ok(self)
} else {
Err(self)
}
}
// Creates or returns the LabelID for a given label name
pub fn new_label(&mut self, name: String) -> LabelID {
let pos = self
.editor
.labels
.borrow()
.iter()
.chain(self.added_labels.iter())
.position(|l| *l == name);
if let Some(idx) = pos {
LabelID::new(idx)
} else {
let idx = self.editor.labels.borrow().len() + self.added_labels.len();
self.added_labels.push(name);
LabelID::new(idx)
}
}
// Creates an entirely fresh label and returns its LabelID
pub fn fresh_label(&mut self) -> LabelID {
let idx = self.editor.labels.borrow().len() + self.added_labels.len();
self.added_labels.push(format!("#fresh_{}", idx));
LabelID::new(idx)
}
pub fn get_users(&self, id: NodeID) -> impl Iterator<Item = NodeID> + '_ {
assert!(!self.deleted_nodeids.contains(&id));
if let Some(users) = self.updated_def_use.get(&id) {
......@@ -671,6 +829,7 @@ fn func(x: i32) -> i32
let constants_ref = RefCell::new(src_module.constants);
let dynamic_constants_ref = RefCell::new(src_module.dynamic_constants);
let types_ref = RefCell::new(src_module.types);
let labels_ref = RefCell::new(src_module.labels);
// Edit the function by replacing the add with a multiply.
let mut editor = FunctionEditor::new(
func,
......@@ -678,6 +837,7 @@ fn func(x: i32) -> i32
&constants_ref,
&dynamic_constants_ref,
&types_ref,
&labels_ref,
&def_use(func),
);
let success = editor.edit(|mut edit| {
......
......@@ -209,6 +209,11 @@ fn inline_func(
for schedule in callee_schedule {
edit = edit.add_schedule(add_id, schedule.clone())?;
}
// Copy the labels from the callee.
let callee_labels = &called_func.labels[idx];
for label in callee_labels {
edit = edit.add_label(add_id, *label)?;
}
}
// Stitch the control use of the inlined start node with the
......
#![feature(let_chains)]
pub mod ccp;
pub mod crc;
pub mod dce;
pub mod delete_uncalled;
pub mod editor;
......@@ -13,7 +14,6 @@ pub mod gvn;
pub mod inline;
pub mod interprocedural_sroa;
pub mod outline;
pub mod pass;
pub mod phi_elim;
pub mod pred;
pub mod schedule;
......@@ -23,6 +23,7 @@ pub mod unforkify;
pub mod utils;
pub use crate::ccp::*;
pub use crate::crc::*;
pub use crate::dce::*;
pub use crate::delete_uncalled::*;
pub use crate::editor::*;
......@@ -35,7 +36,6 @@ pub use crate::gvn::*;
pub use crate::inline::*;
pub use crate::interprocedural_sroa::*;
pub use crate::outline::*;
pub use crate::pass::*;
pub use crate::phi_elim::*;
pub use crate::pred::*;
pub use crate::schedule::*;
......
use std::collections::{BTreeMap, BTreeSet};
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::iter::zip;
use std::sync::atomic::{AtomicUsize, Ordering};
......@@ -203,6 +203,7 @@ pub fn outline(
entry: false,
nodes: vec![],
schedules: vec![],
labels: vec![],
device: None,
};
......@@ -420,6 +421,13 @@ pub fn outline(
outlined.schedules[callee_id.idx()] = edit.get_schedule(*id).clone();
}
// Copy the labels into the new callee.
outlined.labels.resize(outlined.nodes.len(), HashSet::new());
for id in partition.iter() {
let callee_id = convert_id(*id);
outlined.labels[callee_id.idx()] = edit.get_label(*id).clone();
}
// Step 3: edit the original function to call the outlined function.
let dynamic_constants = (0..edit.get_num_dynamic_constant_params())
.map(|idx| edit.add_dynamic_constant(DynamicConstant::Parameter(idx as usize)))
......
This diff is collapsed.
......@@ -243,7 +243,7 @@ pub(crate) fn substitute_dynamic_constants_in_node(
/*
* Top level function to make a function have only a single return.
*/
pub(crate) fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
pub fn collapse_returns(editor: &mut FunctionEditor) -> Option<NodeID> {
let returns: Vec<NodeID> = (0..editor.func().nodes.len())
.filter(|idx| editor.func().nodes[*idx].is_return())
.map(NodeID::new)
......@@ -293,7 +293,7 @@ pub(crate) fn contains_between_control_flow(func: &Function) -> bool {
* Top level function to ensure a Hercules function contains at least one
* control node that isn't the start or return nodes.
*/
pub(crate) fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
pub fn ensure_between_control_flow(editor: &mut FunctionEditor) -> Option<NodeID> {
if !contains_between_control_flow(editor.func()) {
let ret = editor
.node_ids()
......