提交 6082c353 编写于 作者: M Megvii Engine Team

feat(dnn/rocm): support bool in type_cvt and elemwise

GitOrigin-RevId: ad5ec7bc1ce9a0d7588538ad30dbdbdf0b48d640
上级 0ad377c7
......@@ -72,6 +72,7 @@ namespace indexing_multi_axis_vec {
#define cb(_dtype) \
MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef cb
#undef INST
......
......@@ -38,6 +38,11 @@ __device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) {
((int*)0)[0] = 1;
}
__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) {
asm("s_trap 2;");
((int*)0)[0] = 1;
}
#define KERN_APPLY_OPR_OPR \
::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr
#include "./kern_apply_opr_impl.hipinl"
......
......@@ -71,6 +71,7 @@ void exec_src_normal(const TensorND& dst, const TensorND& src,
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
cb(::megdnn::dtype::Bool);
#undef cb
default:
megdnn_assert_internal(0);
......@@ -106,6 +107,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool);
#undef cb
default:
megdnn_assert_internal(0);
......
......@@ -136,6 +136,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#if !MEGDNN_DISABLE_FLOAT16
#define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \
cb(dtype_src, dt_bool) \
cb(dtype_src, dt_int8) \
cb(dtype_src, dt_int32) \
cb(dtype_src, dt_int16) \
......@@ -147,6 +148,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#else
#define MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, cb) \
cb(dtype_src, dt_bool) \
cb(dtype_src, dt_int8) \
cb(dtype_src, dt_int32) \
cb(dtype_src, dt_int16) \
......@@ -171,6 +173,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#if !MEGDNN_DISABLE_FLOAT16
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
cb(dt_bool) \
cb(dt_int8) \
cb(dt_int32) \
cb(dt_int16) \
......@@ -181,6 +184,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src,
#else
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
cb(dt_bool) \
cb(dt_int8) \
cb(dt_int32) \
cb(dt_int16) \
......
......@@ -176,6 +176,7 @@ void init_common(py::module m) {
py::enum_<CompNode::DeviceType>(m, "DeviceType")
.value("UNSPEC", CompNode::DeviceType::UNSPEC)
.value("CUDA", CompNode::DeviceType::CUDA)
.value("ROCM", CompNode::DeviceType::ROCM)
.value("CPU", CompNode::DeviceType::CPU)
.value("CAMBRICON", CompNode::DeviceType::CAMBRICON)
.value("ATLAS", CompNode::DeviceType::ATLAS)
......
......@@ -378,6 +378,7 @@ public:
if (is_finalized()) return;
for (auto&& i : m_used_comp_node) {
if (i.device_type() == CompNode::DeviceType::CUDA) continue;
if (i.device_type() == CompNode::DeviceType::ROCM) continue;
i.sync();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册