From 6082c353e78827bfb1a7264fb73f85de228e1ebf Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 30 Apr 2021 13:25:41 +0800 Subject: [PATCH] feat(dnn/rocm): support bool in type_cvt and elemwise GitOrigin-RevId: ad5ec7bc1ce9a0d7588538ad30dbdbdf0b48d640 --- .../rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl | 1 + .../rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip | 5 +++++ dnn/src/rocm/type_cvt/opr_impl.cpp | 2 ++ dnn/src/rocm/type_cvt/type_cvt.cpp.hip | 4 ++++ imperative/python/src/common.cpp | 1 + imperative/src/impl/proxy_graph.cpp | 1 + 6 files changed, 14 insertions(+) diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl index 54112725c..f1804bca2 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl @@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { #define cb(_dtype) \ MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb #undef INST diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip index 6bccf0ef3..26a9fd6f7 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip @@ -38,6 +38,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ((int*)0)[0] = 1; } +__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { + asm("s_trap 2;"); + ((int*)0)[0] = 1; +} + #define KERN_APPLY_OPR_OPR \ ::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr #include "./kern_apply_opr_impl.hipinl" diff --git a/dnn/src/rocm/type_cvt/opr_impl.cpp b/dnn/src/rocm/type_cvt/opr_impl.cpp index 3dfac6b38..52ede4c5e 100644 --- a/dnn/src/rocm/type_cvt/opr_impl.cpp +++ b/dnn/src/rocm/type_cvt/opr_impl.cpp @@ -71,6 +71,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb); + cb(::megdnn::dtype::Bool); #undef cb default: megdnn_assert_internal(0); @@ -106,6 +107,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool); #undef cb default: megdnn_assert_internal(0); diff --git a/dnn/src/rocm/type_cvt/type_cvt.cpp.hip b/dnn/src/rocm/type_cvt/type_cvt.cpp.hip index 22662adf4..67d0e6c2a 100644 --- a/dnn/src/rocm/type_cvt/type_cvt.cpp.hip +++ b/dnn/src/rocm/type_cvt/type_cvt.cpp.hip @@ -136,6 +136,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #if !MEGDNN_DISABLE_FLOAT16 #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ + cb(dtype_src, dt_bool) \ cb(dtype_src, dt_int8) \ cb(dtype_src, dt_int32) \ cb(dtype_src, dt_int16) \ @@ -147,6 +148,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #else #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ + cb(dtype_src, dt_bool) \ cb(dtype_src, dt_int8) \ cb(dtype_src, dt_int32) \ cb(dtype_src, dt_int16) \ @@ -171,6 +173,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #if !MEGDNN_DISABLE_FLOAT16 #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ + cb(dt_bool) \ cb(dt_int8) \ cb(dt_int32) \ cb(dt_int16) \ @@ -181,6 +184,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #else #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ + cb(dt_bool) \ cb(dt_int8) \ cb(dt_int32) \ cb(dt_int16) \ diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index 3c2c4e7cb..5a3d98b2a 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -176,6 +176,7 @@ void init_common(py::module m) { py::enum_(m, "DeviceType") .value("UNSPEC", CompNode::DeviceType::UNSPEC) .value("CUDA", CompNode::DeviceType::CUDA) + .value("ROCM", CompNode::DeviceType::ROCM) .value("CPU", CompNode::DeviceType::CPU) .value("CAMBRICON", CompNode::DeviceType::CAMBRICON) .value("ATLAS", CompNode::DeviceType::ATLAS) diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index 2263f9f7b..a0c7aaee4 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -378,6 +378,7 @@ public: if (is_finalized()) return; for (auto&& i : m_used_comp_node) { if (i.device_type() == CompNode::DeviceType::CUDA) continue; + if (i.device_type() == CompNode::DeviceType::ROCM) continue; i.sync(); } } -- GitLab