提交 6082c353 编写于 作者: M Megvii Engine Team

feat(dnn/rocm): support bool in type_cvt and elemwise

GitOrigin-RevId: ad5ec7bc1ce9a0d7588538ad30dbdbdf0b48d640
上级 0ad377c7
...@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec { ...@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec {
#define cb(_dtype) \ #define cb(_dtype) \
MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype) MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb #undef cb
#undef INST #undef INST
......
...@@ -38,6 +38,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { ...@@ -38,6 +38,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) {
((int*)0)[0] = 1; ((int*)0)[0] = 1;
} }
__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) {
asm("s_trap 2;");
((int*)0)[0] = 1;
}
#define KERN_APPLY_OPR_OPR \ #define KERN_APPLY_OPR_OPR \
::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr ::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr
#include "./kern_apply_opr_impl.hipinl" #include "./kern_apply_opr_impl.hipinl"
......
...@@ -71,6 +71,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, ...@@ -71,6 +71,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src,
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb); MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
cb(::megdnn::dtype::Bool);
#undef cb #undef cb
default: default:
megdnn_assert_internal(0); megdnn_assert_internal(0);
...@@ -106,6 +107,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { ...@@ -106,6 +107,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \ return; \
} }
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool);
#undef cb #undef cb
default: default:
megdnn_assert_internal(0); megdnn_assert_internal(0);
......
...@@ -136,6 +136,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, ...@@ -136,6 +136,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
#define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \
cb(dtype_src, dt_bool) \
cb(dtype_src, dt_int8) \ cb(dtype_src, dt_int8) \
cb(dtype_src, dt_int32) \ cb(dtype_src, dt_int32) \
cb(dtype_src, dt_int16) \ cb(dtype_src, dt_int16) \
...@@ -147,6 +148,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, ...@@ -147,6 +148,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#else #else
#define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \ #define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \
cb(dtype_src, dt_bool) \
cb(dtype_src, dt_int8) \ cb(dtype_src, dt_int8) \
cb(dtype_src, dt_int32) \ cb(dtype_src, dt_int32) \
cb(dtype_src, dt_int16) \ cb(dtype_src, dt_int16) \
...@@ -171,6 +173,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, ...@@ -171,6 +173,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
cb(dt_bool) \
cb(dt_int8) \ cb(dt_int8) \
cb(dt_int32) \ cb(dt_int32) \
cb(dt_int16) \ cb(dt_int16) \
...@@ -181,6 +184,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, ...@@ -181,6 +184,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#else #else
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
cb(dt_bool) \
cb(dt_int8) \ cb(dt_int8) \
cb(dt_int32) \ cb(dt_int32) \
cb(dt_int16) \ cb(dt_int16) \
......
...@@ -176,6 +176,7 @@ void init_common(py::module m) { ...@@ -176,6 +176,7 @@ void init_common(py::module m) {
py::enum_<CompNode::DeviceType>(m, "DeviceType") py::enum_<CompNode::DeviceType>(m, "DeviceType")
.value("UNSPEC", CompNode::DeviceType::UNSPEC) .value("UNSPEC", CompNode::DeviceType::UNSPEC)
.value("CUDA", CompNode::DeviceType::CUDA) .value("CUDA", CompNode::DeviceType::CUDA)
.value("ROCM", CompNode::DeviceType::ROCM)
.value("CPU", CompNode::DeviceType::CPU) .value("CPU", CompNode::DeviceType::CPU)
.value("CAMBRICON", CompNode::DeviceType::CAMBRICON) .value("CAMBRICON", CompNode::DeviceType::CAMBRICON)
.value("ATLAS", CompNode::DeviceType::ATLAS) .value("ATLAS", CompNode::DeviceType::ATLAS)
......
...@@ -378,6 +378,7 @@ public: ...@@ -378,6 +378,7 @@ public:
if (is_finalized()) return; if (is_finalized()) return;
for (auto&& i : m_used_comp_node) { for (auto&& i : m_used_comp_node) {
if (i.device_type() == CompNode::DeviceType::CUDA) continue; if (i.device_type() == CompNode::DeviceType::CUDA) continue;
if (i.device_type() == CompNode::DeviceType::ROCM) continue;
i.sync(); i.sync();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册