Skip to content
Snippets Groups Projects
Commit a4ded058 authored by prrathi's avatar prrathi
Browse files

matmul dot chngs

parent 41dfb2ab
No related branches found
No related tags found
1 merge request: !115 GPU backend
#![feature(if_let_guard, let_chains)]
pub mod cpu;
pub mod gpu;
pub mod rt;
pub mod fork_tree;
......@@ -9,6 +10,8 @@ pub use crate::cpu::*;
pub use crate::gpu::*;
pub use crate::rt::*;
pub use crate::fork_tree::*;
use std::collections::BTreeMap;
use hercules_ir::*;
......
......@@ -147,6 +147,19 @@ impl<'a> HerculesCPURefMut<'a> {
#[cfg(feature = "cuda")]
impl<'a> HerculesCUDARef<'a> {
/// Copies the device buffer behind this reference into `dst` and returns a CPU reference
/// that points at the copied bytes.
///
/// Exactly `self.size` bytes are transferred with `__copy_cuda_to_cpu`, so `dst` must offer
/// at least that many bytes of storage.
///
/// # Panics
///
/// Panics if `dst` holds fewer bytes than the device buffer.
///
/// NOTE(review): the returned reference carries lifetime `'a` rather than the lifetime of
/// `dst`, so the borrow checker does not tie it to the destination slice — confirm callers
/// keep `dst` alive (and unmodified) for as long as the result is used.
pub fn to_cpu_ref<T>(self, dst: &mut [T]) -> HerculesCPURef<'a> {
    let size = self.size;
    // Refuse to write past the end of the destination instead of corrupting memory.
    assert!(
        size <= std::mem::size_of_val(&*dst),
        "destination slice is too small to hold the CUDA buffer"
    );
    // The copy writes through this pointer, so derive it from the mutable borrow.
    let ptr = NonNull::new(dst.as_mut_ptr().cast::<u8>())
        .expect("slice pointers are never null");
    // SAFETY: `ptr` is valid for writes of `size` bytes (checked above) and is exclusively
    // borrowed through `dst`; the source pointer and `size` both come from `self`.
    unsafe {
        __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size);
    }
    HerculesCPURef {
        ptr,
        size,
        _phantom: PhantomData,
    }
}
pub unsafe fn __ptr(&self) -> *mut u8 {
self.ptr.as_ptr()
}
......@@ -174,6 +187,19 @@ impl<'a> HerculesCUDARefMut<'a> {
}
}
/// Copies the device buffer behind this mutable reference into `dst` and returns a CPU
/// reference that points at the copied bytes.
///
/// Exactly `self.size` bytes are transferred with `__copy_cuda_to_cpu`, so `dst` must offer
/// at least that many bytes of storage.
///
/// # Panics
///
/// Panics if `dst` holds fewer bytes than the device buffer.
///
/// NOTE(review): the returned reference carries lifetime `'a` rather than the lifetime of
/// `dst`, so the borrow checker does not tie it to the destination slice — confirm callers
/// keep `dst` alive (and unmodified) for as long as the result is used.
pub fn to_cpu_ref<T>(self, dst: &mut [T]) -> HerculesCPURef<'a> {
    let size = self.size;
    // Refuse to write past the end of the destination instead of corrupting memory.
    assert!(
        size <= std::mem::size_of_val(&*dst),
        "destination slice is too small to hold the CUDA buffer"
    );
    // The copy writes through this pointer, so derive it from the mutable borrow.
    let ptr = NonNull::new(dst.as_mut_ptr().cast::<u8>())
        .expect("slice pointers are never null");
    // SAFETY: `ptr` is valid for writes of `size` bytes (checked above) and is exclusively
    // borrowed through `dst`; the source pointer and `size` both come from `self`.
    unsafe {
        __copy_cuda_to_cpu(ptr.as_ptr(), self.ptr.as_ptr(), size);
    }
    HerculesCPURef {
        ptr,
        size,
        _phantom: PhantomData,
    }
}
pub unsafe fn __ptr(&self) -> *mut u8 {
self.ptr.as_ptr()
}
......
......@@ -6,7 +6,14 @@ auto-outline(*);
ip-sroa(*);
sroa(*);
fork-split(*);
unforkify(*);
dce(*);
float-collections(*);
gvn(*);
phi-elim(*);
dce(*);
infer-schedules(*);
gcm(*);
......@@ -9,5 +9,11 @@ host(dot);
ip-sroa(*);
sroa(*);
dce(*);
float-collections(*);
gvn(*);
phi-elim(*);
dce(*);
infer-schedules(*);
gcm(*);
#![feature(concat_idents)]
use hercules_rt::{runner, HerculesCPURef};
#[cfg(feature = "cuda")]
use hercules_rt::CUDABox;
juno_build::juno!("dot");
/// Runs the `dot` kernel once on two fixed 8-element vectors and checks the result.
///
/// The inputs are declared a single time and exactly one of the two `cfg` blocks below is
/// compiled, so the kernel is invoked once per build: with CPU references when the `cuda`
/// feature is off, and with references obtained from `CUDABox` when it is on.
fn main() {
    async_std::task::block_on(async {
        let a: [f32; 8] = [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0];
        let b: [f32; 8] = [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0];
        // Expected dot product: 1*5 + 2*6 + 3*7 + 4*8 = 70.
        #[cfg(not(feature = "cuda"))]
        {
            let a = HerculesCPURef::from_slice(&a);
            let b = HerculesCPURef::from_slice(&b);
            let mut r = runner!(dot);
            let c = r.run(8, a, b).await;
            println!("{}", c);
            assert_eq!(c, 70.0);
        }
        #[cfg(feature = "cuda")]
        {
            // The boxes are bound to names so they outlive the references taken from them.
            let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&a));
            let a = a_box.get_ref();
            let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&b));
            let b = b_box.get_ref();
            let mut r = runner!(dot);
            let c = r.run(8, a, b).await;
            println!("{}", c);
            assert_eq!(c, 70.0);
        }
    });
}
......
......@@ -14,4 +14,6 @@ gvn(*);
phi-elim(*);
dce(*);
infer-schedules(*);
gcm(*);
......@@ -14,4 +14,6 @@ gvn(*);
phi-elim(*);
dce(*);
infer-schedules(*);
gcm(*);
......@@ -3,6 +3,8 @@
use rand::random;
use hercules_rt::{runner, HerculesCPURef};
#[cfg(feature = "cuda")]
use hercules_rt::CUDABox;
juno_build::juno!("matmul");
......@@ -21,11 +23,26 @@ fn main() {
}
}
}
let a = HerculesCPURef::from_slice(&mut a);
let b = HerculesCPURef::from_slice(&mut b);
let mut r = runner!(matmul);
let c = r.run(I as u64, J as u64, K as u64, a, b).await;
assert_eq!(c.as_slice::<i32>(), &*correct_c);
#[cfg(not(feature = "cuda"))]
{
let a = HerculesCPURef::from_slice(&mut a);
let b = HerculesCPURef::from_slice(&mut b);
let mut r = runner!(matmul);
let c = r.run(I as u64, J as u64, K as u64, a, b).await;
assert_eq!(c.as_slice::<i32>(), &*correct_c);
}
#[cfg(feature = "cuda")]
{
let a_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut a));
let a = a_box.get_ref();
let b_box = CUDABox::from_cpu_ref(HerculesCPURef::from_slice(&mut b));
let b = b_box.get_ref();
let mut r = runner!(matmul);
let c = r.run(I as u64, J as u64, K as u64, a, b).await;
let mut c_cpu: Box<[i32]> = vec![0; correct_c.len()].into_boxed_slice();
c.to_cpu_ref(&mut c_cpu);
assert_eq!(c_cpu.as_ref(), correct_c.as_ref());
}
});
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment