diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl index 54112725c51adad5dc1eeb075d4cf739d51a1870..f1804bca2d0629b1ebf1eb0dd1f54890d119d005 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl @@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { #define cb(_dtype) \ MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool) #undef cb #undef INST diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip index 6bccf0ef3c1baba2b01460ed0196e57963089fd5..26a9fd6f739ddb36b7203770d238c23bf74b1fb9 100644 --- a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip @@ -38,6 +38,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ((int*)0)[0] = 1; } +__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { + asm("s_trap 2;"); + ((int*)0)[0] = 1; +} + #define KERN_APPLY_OPR_OPR \ ::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr #include "./kern_apply_opr_impl.hipinl" diff --git a/dnn/src/rocm/type_cvt/opr_impl.cpp b/dnn/src/rocm/type_cvt/opr_impl.cpp index 3dfac6b386b5fa52ce684611dd27f84a8e9e8088..52ede4c5ef4fb0fb24e9b35c35816d9b0f721c3a 100644 --- a/dnn/src/rocm/type_cvt/opr_impl.cpp +++ b/dnn/src/rocm/type_cvt/opr_impl.cpp @@ -71,6 +71,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb); + cb(::megdnn::dtype::Bool); #undef cb default: megdnn_assert_internal(0); @@ -106,6 +107,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + cb(::megdnn::dtype::Bool); #undef cb default: megdnn_assert_internal(0); diff --git a/dnn/src/rocm/type_cvt/type_cvt.cpp.hip b/dnn/src/rocm/type_cvt/type_cvt.cpp.hip index 22662adf4f8fa025f069eb17371dce0714da486a..67d0e6c2a9ddf78836fef9e72f033e883bd01001 100644 --- a/dnn/src/rocm/type_cvt/type_cvt.cpp.hip +++ b/dnn/src/rocm/type_cvt/type_cvt.cpp.hip @@ -136,6 +136,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #if !MEGDNN_DISABLE_FLOAT16 #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ + cb(dtype_src, dt_bool) \ cb(dtype_src, dt_int8) \ cb(dtype_src, dt_int32) \ cb(dtype_src, dt_int16) \ @@ -147,6 +148,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #else #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ + cb(dtype_src, dt_bool) \ cb(dtype_src, dt_int8) \ cb(dtype_src, dt_int32) \ cb(dtype_src, dt_int16) \ @@ -171,6 +173,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #if !MEGDNN_DISABLE_FLOAT16 #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ + cb(dt_bool) \ cb(dt_int8) \ cb(dt_int32) \ cb(dt_int16) \ @@ -181,6 +184,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, #else #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ + cb(dt_bool) \ cb(dt_int8) \ cb(dt_int32) \ cb(dt_int16) \ diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index 3c2c4e7cbad6feb07a57ac62033a64b100eff969..5a3d98b2ade40dd016fef843c1d44fbfac2608b4 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -176,6 +176,7 @@ void init_common(py::module m) { py::enum_(m, "DeviceType") .value("UNSPEC", CompNode::DeviceType::UNSPEC) .value("CUDA", CompNode::DeviceType::CUDA) + .value("ROCM", CompNode::DeviceType::ROCM) .value("CPU", CompNode::DeviceType::CPU) .value("CAMBRICON", CompNode::DeviceType::CAMBRICON) .value("ATLAS", CompNode::DeviceType::ATLAS) diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index 2263f9f7b5247ccab0ad1f6cfb01b16a8ea9dea2..a0c7aaee41c7942e5a28e44b8f18b06fd0783b50 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -378,6 +378,7 @@ public: if (is_finalized()) return; for (auto&& i : m_used_comp_node) { if (i.device_type() == CompNode::DeviceType::CUDA) continue; + if (i.device_type() == CompNode::DeviceType::ROCM) continue; i.sync(); } }