From a4ded058cca94c3140b25c5ef5cc63f0c36d4cba Mon Sep 17 00:00:00 2001 From: prrathi <prrathi10@gmail.com> Date: Wed, 29 Jan 2025 09:29:02 -0600 Subject: [PATCH] matmul dot chngs --- hercules_cg/src/lib.rs | 3 +++ hercules_rt/src/lib.rs | 26 +++++++++++++++++++++++ hercules_samples/dot/src/cpu.sch | 7 +++++++ hercules_samples/dot/src/gpu.sch | 6 ++++++ hercules_samples/dot/src/main.rs | 32 +++++++++++++++++++++-------- hercules_samples/matmul/src/cpu.sch | 2 ++ hercules_samples/matmul/src/gpu.sch | 2 ++ hercules_samples/matmul/src/main.rs | 27 +++++++++++++++++++----- 8 files changed, 92 insertions(+), 13 deletions(-) diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs index 6910df9e..dab4dbac 100644 --- a/hercules_cg/src/lib.rs +++ b/hercules_cg/src/lib.rs @@ -1,6 +1,7 @@ #![feature(if_let_guard, let_chains)] pub mod cpu; +pub mod gpu; pub mod rt; pub mod fork_tree; @@ -9,6 +10,8 @@ pub use crate::cpu::*; pub use crate::gpu::*; pub use crate::rt::*; +pub use crate::fork_tree::*; + use std::collections::BTreeMap; use hercules_ir::*; diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs index db2dee77..a23ab3e9 100644 --- a/hercules_rt/src/lib.rs +++ b/hercules_rt/src/lib.rs @@ -147,6 +147,19 @@ impl<'a> HerculesCPURefMut<'a> { #[cfg(feature = "cuda")] impl<'a> HerculesCUDARef<'a> { + pub fn to_cpu_ref<T>(self, dst: &mut [T]) -> HerculesCPURef<'a> { + unsafe { + let size = self.size; + let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap(); + __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size); + HerculesCPURef { + ptr, + size, + _phantom: PhantomData, + } + } + } + pub unsafe fn __ptr(&self) -> *mut u8 { self.ptr.as_ptr() } @@ -174,6 +187,19 @@ impl<'a> HerculesCUDARefMut<'a> { } } + pub fn to_cpu_ref<T>(self, dst: &mut [T]) -> HerculesCPURef<'a> { + unsafe { + let size = self.size; + let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap(); + __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size); + HerculesCPURef { + ptr, + size, + _phantom: PhantomData, + } + } + } + pub unsafe fn __ptr(&self) -> *mut u8 { self.ptr.as_ptr() } diff --git a/hercules_samples/dot/src/cpu.sch b/hercules_samples/dot/src/cpu.sch index 58a7266d..4c684da2 100644 --- a/hercules_samples/dot/src/cpu.sch +++ b/hercules_samples/dot/src/cpu.sch @@ -6,7 +6,14 @@ auto-outline(*); ip-sroa(*); sroa(*); +fork-split(*); unforkify(*); dce(*); +float-collections(*); +gvn(*); +phi-elim(*); +dce(*); + +infer-schedules(*); gcm(*); diff --git a/hercules_samples/dot/src/gpu.sch b/hercules_samples/dot/src/gpu.sch index 956eb996..a1a51088 100644 --- a/hercules_samples/dot/src/gpu.sch +++ b/hercules_samples/dot/src/gpu.sch @@ -9,5 +9,11 @@ host(dot); ip-sroa(*); sroa(*); dce(*); +float-collections(*); +gvn(*); +phi-elim(*); +dce(*); + +infer-schedules(*); gcm(*); diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs index 335e8909..4e651fa8 100644 --- a/hercules_samples/dot/src/main.rs +++ b/hercules_samples/dot/src/main.rs @@ -1,19 +1,35 @@ #![feature(concat_idents)] use hercules_rt::{runner, HerculesCPURef}; +#[cfg(feature = "cuda")] +use hercules_rt::CUDABox; juno_build::juno!("dot"); fn main() { async_std::task::block_on(async { - let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; - let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; - let a = HerculesCPURef::from_slice(&a); - let b = HerculesCPURef::from_slice(&b); - let mut r = runner!(dot); - let c = r.run(8, a, b).await; - println!("{}", c); - assert_eq!(c, 70.0); + let mut a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0]; + let mut b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]; + #[cfg(not(feature = "cuda"))] + { + let a = HerculesCPURef::from_slice(&a); + let b = HerculesCPURef::from_slice(&b); + let mut r = runner!(dot); + let c = r.run(8, a, b).await; + println!("{}", c); + assert_eq!(c, 70.0); + } + #[cfg(feature = "cuda")] + { + let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); + let a = a_box.get_ref(); + let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); + let b = b_box.get_ref(); + let mut r = runner!(dot); + let c = r.run(8, a, b).await; + println!("{}", c); + assert_eq!(c, 70.0); + } }); } diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch index f7891b9b..4c684da2 100644 --- a/hercules_samples/matmul/src/cpu.sch +++ b/hercules_samples/matmul/src/cpu.sch @@ -14,4 +14,6 @@ gvn(*); phi-elim(*); dce(*); +infer-schedules(*); + gcm(*); diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch index 2bdcc83c..c9d6b336 100644 --- a/hercules_samples/matmul/src/gpu.sch +++ b/hercules_samples/matmul/src/gpu.sch @@ -14,4 +14,6 @@ gvn(*); phi-elim(*); dce(*); +infer-schedules(*); + gcm(*); diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs index 8757a0fd..762644f1 100644 --- a/hercules_samples/matmul/src/main.rs +++ b/hercules_samples/matmul/src/main.rs @@ -3,6 +3,8 @@ use rand::random; use hercules_rt::{runner, HerculesCPURef}; +#[cfg(feature = "cuda")] +use hercules_rt::CUDABox; juno_build::juno!("matmul"); @@ -21,11 +23,26 @@ fn main() { } } } - let a = HerculesCPURef::from_slice(&mut a); - let b = HerculesCPURef::from_slice(&mut b); - let mut r = runner!(matmul); - let c = r.run(I as u64, J as u64, K as u64, a, b).await; - assert_eq!(c.as_slice::<i32>(), &*correct_c); + #[cfg(not(feature = "cuda"))] + { + let a = HerculesCPURef::from_slice(&mut a); + let b = HerculesCPURef::from_slice(&mut b); + let mut r = runner!(matmul); + let c = r.run(I as u64, J as u64, K as u64, a, b).await; + assert_eq!(c.as_slice::<i32>(), &*correct_c); + } + #[cfg(feature = "cuda")] + { + let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a)); + let a = a_box.get_ref(); + let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b)); + let b = b_box.get_ref(); + let mut r = runner!(matmul); + let c = r.run(I as u64, J as u64, K as u64, a, b).await; + let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice(); + c.to_cpu_ref(&mut c_cpu); + assert_eq!(c_cpu.as_ref(), correct_c.as_ref()); + } }); } -- GitLab