From a4ded058cca94c3140b25c5ef5cc63f0c36d4cba Mon Sep 17 00:00:00 2001
From: prrathi <prrathi10@gmail.com>
Date: Wed, 29 Jan 2025 09:29:02 -0600
Subject: [PATCH] matmul dot chngs

---
 hercules_cg/src/lib.rs              |  3 +++
 hercules_rt/src/lib.rs              | 26 +++++++++++++++++++++++
 hercules_samples/dot/src/cpu.sch    |  7 +++++++
 hercules_samples/dot/src/gpu.sch    |  6 ++++++
 hercules_samples/dot/src/main.rs    | 32 +++++++++++++++++++++--------
 hercules_samples/matmul/src/cpu.sch |  2 ++
 hercules_samples/matmul/src/gpu.sch |  2 ++
 hercules_samples/matmul/src/main.rs | 27 +++++++++++++++++++-----
 8 files changed, 92 insertions(+), 13 deletions(-)

diff --git a/hercules_cg/src/lib.rs b/hercules_cg/src/lib.rs
index 6910df9e..dab4dbac 100644
--- a/hercules_cg/src/lib.rs
+++ b/hercules_cg/src/lib.rs
@@ -1,6 +1,7 @@
 #![feature(if_let_guard, let_chains)]
 
 pub mod cpu;
+pub mod gpu;
 pub mod rt;
 
 pub mod fork_tree;
@@ -9,6 +10,8 @@ pub use crate::cpu::*;
 pub use crate::gpu::*;
 pub use crate::rt::*;
 
+pub use crate::fork_tree::*;
+
 use std::collections::BTreeMap;
 
 use hercules_ir::*;
diff --git a/hercules_rt/src/lib.rs b/hercules_rt/src/lib.rs
index db2dee77..a23ab3e9 100644
--- a/hercules_rt/src/lib.rs
+++ b/hercules_rt/src/lib.rs
@@ -147,6 +147,19 @@ impl<'a> HerculesCPURefMut<'a> {
 
 #[cfg(feature = "cuda")]
 impl<'a> HerculesCUDARef<'a> {
+    pub fn to_cpu_ref<T>(self, dst: &mut [T]) -> HerculesCPURef<'a> {
+        unsafe {
+            let size = self.size;
+            let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap();
+            __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size);
+            HerculesCPURef {
+                ptr,
+                size,
+                _phantom: PhantomData,
+            }
+        }
+    }
+
     pub unsafe fn __ptr(&self) -> *mut u8 {
         self.ptr.as_ptr()
     }
@@ -174,6 +187,19 @@ impl<'a> HerculesCUDARefMut<'a> {
         }
     }
 
+    pub fn to_cpu_ref<T>(self, dst: &mut [T]) -> HerculesCPURef<'a> {
+        unsafe {
+            let size = self.size;
+            let ptr = NonNull::new(dst.as_ptr() as *mut u8).unwrap();
+            __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size);
+            HerculesCPURef {
+                ptr,
+                size,
+                _phantom: PhantomData,
+            }
+        }
+    }
+
     pub unsafe fn __ptr(&self) -> *mut u8 {
         self.ptr.as_ptr()
     }
diff --git a/hercules_samples/dot/src/cpu.sch b/hercules_samples/dot/src/cpu.sch
index 58a7266d..4c684da2 100644
--- a/hercules_samples/dot/src/cpu.sch
+++ b/hercules_samples/dot/src/cpu.sch
@@ -6,7 +6,14 @@ auto-outline(*);
 
 ip-sroa(*);
 sroa(*);
+fork-split(*);
 unforkify(*);
 dce(*);
+float-collections(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+infer-schedules(*);
 
 gcm(*);
diff --git a/hercules_samples/dot/src/gpu.sch b/hercules_samples/dot/src/gpu.sch
index 956eb996..a1a51088 100644
--- a/hercules_samples/dot/src/gpu.sch
+++ b/hercules_samples/dot/src/gpu.sch
@@ -9,5 +9,11 @@ host(dot);
 ip-sroa(*);
 sroa(*);
 dce(*);
+float-collections(*);
+gvn(*);
+phi-elim(*);
+dce(*);
+
+infer-schedules(*);
 
 gcm(*);
diff --git a/hercules_samples/dot/src/main.rs b/hercules_samples/dot/src/main.rs
index 335e8909..4e651fa8 100644
--- a/hercules_samples/dot/src/main.rs
+++ b/hercules_samples/dot/src/main.rs
@@ -1,19 +1,35 @@
 #![feature(concat_idents)]
 
 use hercules_rt::{runner, HerculesCPURef};
+#[cfg(feature = "cuda")]
+use hercules_rt::CUDABox;
 
 juno_build::juno!("dot");
 
 fn main() {
     async_std::task::block_on(async {
-        let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
-        let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
-        let a = HerculesCPURef::from_slice(&a);
-        let b = HerculesCPURef::from_slice(&b);
-        let mut r = runner!(dot);
-        let c = r.run(8, a, b).await;
-        println!("{}", c);
-        assert_eq!(c, 70.0);
+        let mut a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
+        let mut b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
+        #[cfg(not(feature = "cuda"))]
+        {
+            let a = HerculesCPURef::from_slice(&a);
+            let b = HerculesCPURef::from_slice(&b);
+            let mut r = runner!(dot);
+            let c = r.run(8, a, b).await;
+            println!("{}", c);
+            assert_eq!(c, 70.0);
+        }
+        #[cfg(feature = "cuda")]
+        {
+            let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
+            let a = a_box.get_ref();
+            let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
+            let b = b_box.get_ref();
+            let mut r = runner!(dot);
+            let c = r.run(8, a, b).await;
+            println!("{}", c);
+            assert_eq!(c, 70.0);
+        }
     });
 }
 
diff --git a/hercules_samples/matmul/src/cpu.sch b/hercules_samples/matmul/src/cpu.sch
index f7891b9b..4c684da2 100644
--- a/hercules_samples/matmul/src/cpu.sch
+++ b/hercules_samples/matmul/src/cpu.sch
@@ -14,4 +14,6 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
+infer-schedules(*);
+
 gcm(*);
diff --git a/hercules_samples/matmul/src/gpu.sch b/hercules_samples/matmul/src/gpu.sch
index 2bdcc83c..c9d6b336 100644
--- a/hercules_samples/matmul/src/gpu.sch
+++ b/hercules_samples/matmul/src/gpu.sch
@@ -14,4 +14,6 @@ gvn(*);
 phi-elim(*);
 dce(*);
 
+infer-schedules(*);
+
 gcm(*);
diff --git a/hercules_samples/matmul/src/main.rs b/hercules_samples/matmul/src/main.rs
index 8757a0fd..762644f1 100644
--- a/hercules_samples/matmul/src/main.rs
+++ b/hercules_samples/matmul/src/main.rs
@@ -3,6 +3,8 @@
 use rand::random;
 
 use hercules_rt::{runner, HerculesCPURef};
+#[cfg(feature = "cuda")]
+use hercules_rt::CUDABox;
 
 juno_build::juno!("matmul");
 
@@ -21,11 +23,26 @@ fn main() {
                 }
             }
         }
-        let a = HerculesCPURef::from_slice(&mut a);
-        let b = HerculesCPURef::from_slice(&mut b);
-        let mut r = runner!(matmul);
-        let c = r.run(I as u64, J as u64, K as u64, a, b).await;
-        assert_eq!(c.as_slice::<i32>(), &*correct_c);
+        #[cfg(not(feature = "cuda"))]
+        {
+            let a = HerculesCPURef::from_slice(&mut a);
+            let b = HerculesCPURef::from_slice(&mut b);
+            let mut r = runner!(matmul);
+            let c = r.run(I as u64, J as u64, K as u64, a, b).await;
+            assert_eq!(c.as_slice::<i32>(), &*correct_c);
+        }
+        #[cfg(feature = "cuda")]
+        {
+            let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
+            let a = a_box.get_ref();
+            let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
+            let b = b_box.get_ref();
+            let mut r = runner!(matmul);
+            let c = r.run(I as u64, J as u64, K as u64, a, b).await;
+            let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
+            c.to_cpu_ref(&mut c_cpu);
+            assert_eq!(c_cpu.as_ref(), correct_c.as_ref());
+        }
     });
 }
 
-- 
GitLab