Skip to content
Snippets Groups Projects
collections.rs 13.48 KiB
use std::collections::{BTreeMap, BTreeSet};

use crate::*;

/*
 * Collection objects analysis for Hercules IR.
 *
 * This analysis finds the "collection objects" of every function. It is
 * inter-procedural, because collection objects can be passed to and returned
 * from called functions. It also tracks which collection objects are mutated
 * "in" a function. A collection object is mutated in a function when that
 * function contains either:
 *
 * - a write node that may write to that collection object, or
 * - a call node that may take that collection object as an argument, where the
 *   corresponding collection parameter of the callee is itself mutated.
 *
 * Collection objects are numbered per function. Only the following nodes may
 * originate a collection object:
 *
 * - Parameter: each collection-typed parameter index gets a single collection
 *   object, and every parameter node is assigned the object of its index.
 * - Constant: each collection constant node gets a single collection object.
 * - Call: each function is analyzed to determine which collection objects
 *   (those of its parameters, or ones it originates itself) it may return. A
 *   call node originates a new collection object if the callee may return an
 *   object originated inside the callee.
 * - Undef: each undef node with a non-primitive type gets a single collection
 *   object.
 *
 * The analysis answers the following questions:
 *
 * - For each node in each function, which collection objects may the node
 *   output?
 * - For each function, which collection objects may be mutated inside that
 *   function, and by which nodes?
 * - For each function, which collection objects may be returned?
 * - For each collection object, how was it originated?
 */
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CollectionObjectOrigin {
    /// Originated by the function parameter with this index.
    Parameter(usize),
    /// Originated by this collection constant node.
    Constant(NodeID),
    /// Originated by this call node, whose callee may return an object it
    /// originates itself.
    Call(NodeID),
    /// Originated by this undef node.
    Undef(NodeID),
}

define_id_type!(CollectionObjectID);

/*
 * Result of the collection objects analysis for a single function.
 */
#[derive(Clone, Debug)]
pub struct FunctionCollectionObjects {
    // Indexed by node: the collection objects each node may output.
    objects_per_node: Vec<Vec<CollectionObjectID>>,
    // Indexed by collection object: the nodes that may mutate that object.
    mutated: Vec<Vec<NodeID>>,
    // The collection objects the function may return.
    returned: Vec<CollectionObjectID>,
    // Indexed by collection object: where that object was originated.
    origins: Vec<CollectionObjectOrigin>,
}

// Per-function collection objects analysis results for a whole module.
pub type CollectionObjects = BTreeMap<FunctionID, FunctionCollectionObjects>;

impl CollectionObjectOrigin {
    /// Returns the parameter index when this object was originated by a
    /// parameter, and `None` for every other kind of origin.
    pub fn try_parameter(&self) -> Option<usize> {
        if let CollectionObjectOrigin::Parameter(param_idx) = *self {
            Some(param_idx)
        } else {
            None
        }
    }
}

impl FunctionCollectionObjects {
    /// Collection objects that node `id` may output. Panics if `id` is not a
    /// node of the analyzed function.
    pub fn objects(&self, id: NodeID) -> &Vec<CollectionObjectID> {
        let node_idx = id.idx();
        &self.objects_per_node[node_idx]
    }

    /// How `object` was originated.
    pub fn origin(&self, object: CollectionObjectID) -> CollectionObjectOrigin {
        let origin = &self.origins[object.idx()];
        *origin
    }

    /// The first collection object whose origin is the parameter at `index`,
    /// or `None` if no object was originated by that parameter.
    pub fn param_to_object(&self, index: usize) -> Option<CollectionObjectID> {
        let wanted = CollectionObjectOrigin::Parameter(index);
        self.origins
            .iter()
            .enumerate()
            .find(|(_, origin)| **origin == wanted)
            .map(|(obj_idx, _)| CollectionObjectID::new(obj_idx))
    }

    /// Collection objects the function may return.
    pub fn returned_objects(&self) -> &Vec<CollectionObjectID> {
        &self.returned
    }

    /// Whether any node may mutate `object`.
    pub fn is_mutated(&self, object: CollectionObjectID) -> bool {
        let mutators = &self.mutated[object.idx()];
        !mutators.is_empty()
    }

    /// The nodes that may mutate `object`.
    pub fn mutators(&self, object: CollectionObjectID) -> &Vec<NodeID> {
        let obj_idx = object.idx();
        &self.mutated[obj_idx]
    }

    /// Number of collection objects in the analyzed function.
    pub fn num_objects(&self) -> usize {
        self.origins.len()
    }

    /// Iterates over every collection object ID, in increasing order.
    pub fn iter_objects(&self) -> impl Iterator<Item = CollectionObjectID> {
        let count = self.origins.len();
        (0..count).map(CollectionObjectID::new)
    }
}

/*
 * Each node is assigned the set of collection objects it may output. The
 * lattice element is simply a set of collection object IDs.
 */
#[derive(PartialEq, Eq, Clone, Debug)]
struct CollectionObjectLattice {
    objs: BTreeSet<CollectionObjectID>,
}

impl Semilattice for CollectionObjectLattice {
    // Meet is set union: the result contains every object in either input.
    fn meet(a: &Self, b: &Self) -> Self {
        CollectionObjectLattice {
            objs: a.objs.union(&b.objs).copied().collect(),
        }
    }

    // Top is the empty set, which is the identity of `meet` (union).
    fn top() -> Self {
        CollectionObjectLattice {
            objs: BTreeSet::new(),
        }
    }

    fn bottom() -> Self {
        // Technically, this lattice is unbounded - technically technically, the
        // lattice is bounded by the number of collection objects in a given
        // function, but incorporating this information is not possible in our
        // Semilattice interface. Luckily bottom() isn't necessary if we never
        // call it, which we don't for this analysis.
        unreachable!("CollectionObjectLattice::bottom() is never used by the collection objects analysis")
    }
}

/*
 * Top level function to analyze collection objects in a Hercules module.
 */
pub fn collection_objects(
    functions: &Vec<Function>,
    types: &Vec<Type>,
    reverse_postorders: &Vec<Vec<NodeID>>,
    typing: &ModuleTyping,
    callgraph: &CallGraph,
) -> CollectionObjects {
    // Analyze functions in reverse topological order, since the analysis of a
    // function depends on all functions it calls.
    let mut collection_objects: CollectionObjects = BTreeMap::new();
    let topo = callgraph.topo();

    for func_id in topo {
        let func = &functions[func_id.idx()];
        let typing = &typing[func_id.idx()];
        let reverse_postorder = &reverse_postorders[func_id.idx()];

        // Find collection objects originating at parameters, constants, calls,
        // or undefs. Each node may *originate* one collection object.
        let param_origins = func
            .param_types
            .iter()
            .enumerate()
            .filter(|(_, ty_id)| !types[ty_id.idx()].is_primitive())
            .map(|(idx, _)| CollectionObjectOrigin::Parameter(idx));
        let other_origins = func
            .nodes
            .iter()
            .enumerate()
            .filter_map(|(idx, node)| match node {
                Node::Constant { id: _ } if !types[typing[idx].idx()].is_primitive() => {
                    Some(CollectionObjectOrigin::Constant(NodeID::new(idx)))
                }
                Node::Call {
                    control: _,
                    function: callee,
                    dynamic_constants: _,
                    args: _,
                } if {
                    let fco = &collection_objects[&callee];
                    fco.returned
                        .iter()
                        .any(|returned| fco.origins[returned.idx()].try_parameter().is_none())
                } =>
                {
                    // If the callee may return a new collection object, then
                    // this call node originates a single collection object. The
                    // node may output multiple collection objects, say if the
                    // callee may return an object passed in as a parameter -
                    // this is determined later.
                    Some(CollectionObjectOrigin::Call(NodeID::new(idx)))
                }
                Node::Undef { ty: _ } if !types[typing[idx].idx()].is_primitive() => {
                    Some(CollectionObjectOrigin::Undef(NodeID::new(idx)))
                }
                _ => None,
            });
        let origins: Vec<_> = param_origins.chain(other_origins).collect();

        // Run dataflow analysis to figure out which collection objects each
        // data node may output. Note that there's a strict subset of data nodes
        // that can output collection objects:
        //
        // - Phi: selects between objects in SSA form, may be assigned multiple
        //   possible objects.
        // - Reduce: reduces over an object, similar to phis.
        // - Parameter: may originate an object.
        // - Constant: may originate an object.
        // - Call: may originate an object and may return an object passed in as
        //   a parameter.
        // - Read: may extract a smaller object from the input - this is
        //   considered to be the same object as the input, as no copy takes
        //   place.
        // - Write: updates an object - this is considered to be the same object
        //   as the input object, as the write gets lowered to an in-place
        //   mutation.
        // - Undef: may originate a dummy object.
        // - Ternary (select): selects between two objects, may output either.
        let lattice = forward_dataflow(func, reverse_postorder, |inputs, id| {
            match func.nodes[id.idx()] {
                Node::Phi {
                    control: _,
                    data: _,
                }
                | Node::Reduce {
                    control: _,
                    init: _,
                    reduct: _,
                }
                | Node::Ternary {
                    op: TernaryOperator::Select,
                    first: _,
                    second: _,
                    third: _,
                } => inputs
                    .into_iter()
                    .fold(CollectionObjectLattice::top(), |acc, input| {
                        CollectionObjectLattice::meet(&acc, input)
                    }),
                Node::Parameter { index } => {
                    let obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Parameter(index))
                        .map(CollectionObjectID::new);
                    CollectionObjectLattice {
                        objs: obj.into_iter().collect(),
                    }
                }
                Node::Constant { id: _ } => {
                    let obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Constant(id))
                        .map(CollectionObjectID::new);
                    CollectionObjectLattice {
                        objs: obj.into_iter().collect(),
                    }
                }
                Node::Call {
                    control: _,
                    function: callee,
                    dynamic_constants: _,
                    args: _,
                } if !types[typing[id.idx()].idx()].is_primitive() => {
                    let new_obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Call(id))
                        .map(CollectionObjectID::new);
                    let fco = &collection_objects[&callee];
                    let param_objs = fco
                        .returned
                        .iter()
                        .filter_map(|returned| fco.origins[returned.idx()].try_parameter())
                        .map(|param_index| inputs[param_index + 1]);

                    let mut objs: BTreeSet<_> = new_obj.into_iter().collect();
                    for param_objs in param_objs {
                        objs.extend(&param_objs.objs);
                    }
                    CollectionObjectLattice { objs }
                }
                Node::Undef { ty: _ } => {
                    let obj = origins
                        .iter()
                        .position(|origin| *origin == CollectionObjectOrigin::Undef(id))
                        .map(CollectionObjectID::new);
                    CollectionObjectLattice {
                        objs: obj.into_iter().collect(),
                    }
                }
                Node::Read {
                    collect: _,
                    indices: _,
                } if !types[typing[id.idx()].idx()].is_primitive() => inputs[0].clone(),
                Node::Write {
                    collect: _,
                    data: _,
                    indices: _,
                } => inputs[0].clone(),
                _ => CollectionObjectLattice::top(),
            }
        });
        let objects_per_node: Vec<Vec<_>> = lattice
            .into_iter()
            .map(|l| l.objs.into_iter().collect())
            .collect();

        // Look at the collection objects that each return may take as input.
        let mut returned: BTreeSet<CollectionObjectID> = BTreeSet::new();
        for node in func.nodes.iter() {
            if let Node::Return { control: _, data } = node {
                returned.extend(&objects_per_node[data.idx()]);
            }
        }
        let returned = returned.into_iter().collect();

        // Determine which objects are potentially mutated.
        let mut mutated = vec![vec![]; origins.len()];
        for (idx, node) in func.nodes.iter().enumerate() {
            if node.is_write() {
                // Every object that the write itself corresponds to is mutable
                // in this function.
                for object in objects_per_node[idx].iter() {
                    mutated[object.idx()].push(NodeID::new(idx));
                }
            } else if let Some((_, callee, _, args)) = node.try_call() {
                let fco = &collection_objects[&callee];
                for (param_idx, arg) in args.into_iter().enumerate() {
                    // If this parameter corresponds to an object and it's
                    // mutable in the callee...
                    if let Some(param_callee_object) = fco.param_to_object(param_idx)
                        && fco.is_mutated(param_callee_object)
                    {
                        // Then every object corresponding to the argument node
                        // in this function is mutable.
                        for object in objects_per_node[arg.idx()].iter() {
                            mutated[object.idx()].push(NodeID::new(idx));
                        }
                    }
                }
            }
        }

        let fco = FunctionCollectionObjects {
            objects_per_node,
            mutated,
            returned,
            origins,
        };
        collection_objects.insert(func_id, fco);
    }

    collection_objects
}