提交 fc1ce273 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix elemwise add cuda int8 bcast

GitOrigin-RevId: 568b60e8c9f4d138b57b3f4e715f35cf5ca9d0b4
上级 57bc3657
......@@ -34,9 +34,9 @@ namespace elemwise_intl {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
template <int ndim, typename ctype>
void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv,
int /*grid_size*/,
int /*block_size*/) {
void ParamVisitorBase<ndim, ctype, BCAST_OTHER>::host_init(
const TensorND& rv, int /*grid_size*/, int /*block_size*/,
int /*packed_size*/) {
megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
m_ptr = rv.ptr<ctype>();
for (size_t i = 0; i < rv.layout.ndim; ++i) {
......@@ -54,9 +54,10 @@ void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv,
#pragma GCC diagnostic pop
template <typename ctype>
void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv,
void ParamVisitorBase<3, ctype, BCAST_101>::host_init(const TensorND& rv,
int grid_size,
int block_size) {
int block_size,
int packed_size) {
uint32_t shape2, shape1;
int stride1;
if (rv.layout.ndim == 3) {
......@@ -76,9 +77,10 @@ void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv,
}
template <typename ctype>
void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv,
void ParamVisitorBase<2, ctype, BCAST_10>::host_init(const TensorND& rv,
int grid_size,
int block_size) {
int block_size,
int packed_size) {
megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]);
m_ptr = rv.ptr<ctype>();
m_stride1 = rv.layout.stride[1];
......@@ -87,9 +89,10 @@ void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv,
}
template <typename ctype>
void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv,
void ParamVisitorBase<2, ctype, BCAST_01>::host_init(const TensorND& rv,
int grid_size,
int block_size) {
int block_size,
int packed_size) {
megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[1]);
m_ptr = rv.ptr<ctype>();
m_stride0 = rv.layout.stride[0];
......@@ -98,9 +101,10 @@ void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv,
}
template <typename ctype>
void ParamElemVisitor<1, ctype, BCAST_FULL>::host_init(const TensorND& rv,
void ParamVisitorBase<1, ctype, BCAST_FULL>::host_init(const TensorND& rv,
int /*grid_size*/,
int /*block_size*/) {
int /*block_size*/,
int /*packed_size*/) {
megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]);
m_ptr = rv.ptr<ctype>();
}
......@@ -122,6 +126,53 @@ void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv,
m_shape3.host_init(packed_size * grid_size * block_size, shape3);
}
#define INST(ndim, ctype, brd) template class ParamVisitorBase<ndim, ctype, brd>
#define INST_FOR_CTYPE \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
INST(3, ct, BCAST_101); \
INST(2, ct, BCAST_10); \
INST(2, ct, BCAST_01); \
INST(1, ct, BCAST_FULL);
#define ndim_cb(_ndim) INST(_ndim, ct, BCAST_OTHER);
#define ct dt_byte
INST_FOR_CTYPE
#undef ct
#define ct dt_int32
INST_FOR_CTYPE
#undef ct
#define ct dt_float32
INST_FOR_CTYPE
#undef ct
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct
#define ct dt_uint8
INST_FOR_CTYPE
#undef ct
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
#define ct dt_qint8
INST_FOR_CTYPE
#undef ct
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#undef INST_FOR_CTYPE
#undef INST
#define INST(ndim, ctype, brd) template class ParamElemVisitor<ndim, ctype, brd>
#define INST_FOR_CTYPE \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
......
......@@ -142,6 +142,9 @@ INST(dt_qint32, int4);
* ptr()[offset(idx)]
*
*/
template <int ndim, typename ctype, BcastType brd_type>
class ParamVisitorBase;
template <int ndim, typename ctype, BcastType brd_type>
class ParamElemVisitor;
......@@ -157,6 +160,7 @@ class ParamElemVisitor;
* ptr()[offset(idx)]
*
*/
template <int ndim, typename ctype, BcastType brd_type>
class ParamVectVisitor;
......@@ -169,11 +173,9 @@ class ParamVectVisitor;
//! specialization for BCAST_OTHER
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, BCAST_OTHER> {
class ParamVisitorBase<ndim, ctype, BCAST_OTHER> {
protected:
ctype* __restrict m_ptr;
private:
int m_stride[ndim];
//! m_shape_highdim[i] = original_shape[i + 1]
......@@ -185,10 +187,9 @@ private:
public:
static const int NDIM = ndim;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size);
void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t) {}
......@@ -211,6 +212,18 @@ public:
#endif
};
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, BCAST_OTHER>
: public ParamVisitorBase<ndim, ctype, BCAST_OTHER> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<ndim, ctype, BCAST_OTHER>::host_init(
rv, grid_size, block_size, packed_size);
}
};
/*!
* \brief specialization for ndim == 3 and BCAST_101
* (for dimshuffle 'x', 0, 'x')
......@@ -218,7 +231,7 @@ public:
* visit: idx / m_shape2 % m_shape1
*/
template <typename ctype>
class ParamElemVisitor<3, ctype, BCAST_101> {
class ParamVisitorBase<3, ctype, BCAST_101> {
StridedDivSeq2 m_shape12;
int m_stride1;
......@@ -227,9 +240,9 @@ protected:
public:
static const int NDIM = 3;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size);
void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); }
......@@ -242,13 +255,25 @@ public:
#endif
};
template <typename ctype>
class ParamElemVisitor<3, ctype, BCAST_101>
: public ParamVisitorBase<3, ctype, BCAST_101> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<3, ctype, BCAST_101>::host_init(
rv, grid_size, block_size, packed_size);
}
};
/*!
* \brief specialization for ndim == 2 and BCAST_10
*
* visit: idx % m_shape1
*/
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_10> {
class ParamVisitorBase<2, ctype, BCAST_10> {
StridedDivSeq<false> m_shape1;
int m_stride1;
......@@ -257,9 +282,9 @@ protected:
public:
static const int NDIM = 2;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size);
void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }
......@@ -272,13 +297,25 @@ public:
#endif
};
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_10>
: public ParamVisitorBase<2, ctype, BCAST_10> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<2, ctype, BCAST_10>::host_init(
rv, grid_size, block_size, packed_size);
}
};
/*!
* \brief specialization for ndim == 2 and BCAST_01
*
* visit: idx / shape1
*/
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_01> {
class ParamVisitorBase<2, ctype, BCAST_01> {
StridedDivSeq<true> m_shape1;
int m_stride0;
......@@ -287,9 +324,9 @@ protected:
public:
static const int NDIM = 2;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size);
void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }
......@@ -302,9 +339,21 @@ public:
#endif
};
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_01>
: public ParamVisitorBase<2, ctype, BCAST_01> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<2, ctype, BCAST_01>::host_init(
rv, grid_size, block_size, packed_size);
}
};
//! specialization for ndim == 1 and BCAST_FULL
template <typename ctype>
class ParamElemVisitor<1, ctype, BCAST_FULL> {
class ParamVisitorBase<1, ctype, BCAST_FULL> {
protected:
ctype* __restrict m_ptr;
......@@ -312,7 +361,8 @@ public:
static const int NDIM = 1;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size);
void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t) {}
......@@ -328,6 +378,18 @@ public:
#endif
};
template <typename ctype>
class ParamElemVisitor<1, ctype, BCAST_FULL>
: public ParamVisitorBase<1, ctype, BCAST_FULL> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<1, ctype, BCAST_FULL>::host_init(
rv, grid_size, block_size, packed_size);
}
};
#undef PARAM_ELEM_VISITOR_COMMON_DEV
#undef PARAM_ELEM_VISITOR_COMMON_HOST
......@@ -340,17 +402,21 @@ public:
#else
#define DEVICE_WRAPPER(x)
#endif
#define INST_PARAM_VECT_VISITOR \
template <int ndim, typename ctype> \
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \
: public ParamElemVisitor<ndim, ctype, _brdcast_mask> { \
public: \
using Super = ParamElemVisitor<ndim, ctype, _brdcast_mask>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \
}) \
#define INST_PARAM_VECT_VISITOR \
template <int ndim, typename ctype> \
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \
: public ParamVisitorBase<ndim, ctype, _brdcast_mask> { \
public: \
using Super = ParamVisitorBase<ndim, ctype, _brdcast_mask>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
void host_init(const TensorND& rv, int grid_size, int block_size) { \
ParamVisitorBase<ndim, ctype, _brdcast_mask>::host_init( \
rv, grid_size, block_size, packed_size); \
} \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \
}) \
};
#define _brdcast_mask BCAST_OTHER
INST_PARAM_VECT_VISITOR;
......@@ -367,11 +433,15 @@ INST_PARAM_VECT_VISITOR;
#define INST_DT_IBYTE(ctype) \
template <int ndim> \
class ParamVectVisitor<ndim, ctype, BCAST_FULL> \
: public ParamElemVisitor<ndim, ctype, BCAST_FULL> { \
: public ParamVisitorBase<ndim, ctype, BCAST_FULL> { \
public: \
using Super = ParamElemVisitor<ndim, ctype, BCAST_FULL>; \
using Super = ParamVisitorBase<ndim, ctype, BCAST_FULL>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
void host_init(const TensorND& rv, int grid_size, int block_size) { \
ParamVisitorBase<ndim, ctype, BCAST_FULL>::host_init( \
rv, grid_size, block_size, packed_size); \
} \
DEVICE_WRAPPER(rwtype vect_scalar; \
devfunc rwtype & at(uint32_t /* idx */) { \
ctype v = Super::m_ptr[0]; \
......
......@@ -269,6 +269,43 @@ TEST_F(CUDA, ELEMWISE_BFLOAT16) {
#undef BUILD_TERNARY_COMPLATE_TEST_CASE
}
TEST_F(CUDA, ELEMWISE_ADD_BCAST_10_INT8_INPLACE) {
constexpr size_t A = 2, B = 48, C0 = 14, C1 = 14, C = C0 * C1;
SyncedTensor<dt_int8> t0(handle_cuda(),
{TensorShape{A, B, C0, C1}, dtype::Int8()}),
t1(handle_cuda(), {TensorShape{1, B, C0, C1}, dtype::Int8()}),
t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Int8()});
UniformIntRNG rng{-128, 127};
rng.gen(t0.tensornd_host());
rng.gen(t1.tensornd_host());
auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
auto p2 = t2.ptr_mutable_host();
for (size_t i = 0; i < A; ++i) {
for (size_t j = 0; j < B; ++j) {
for (size_t k = 0; k < C; ++k) {
auto off0 = j * C + k;
auto off1 = i * B * C + j * C + k;
p2[off1] = p0[off1] + p1[off0];
}
}
}
auto opr = handle_cuda()->create_operator<ElemwiseForward>();
opr->param().mode = ElemwiseForward::Mode::ADD;
opr->exec({t0.tensornd_dev(), t1.tensornd_dev()}, t0.tensornd_dev());
auto pt = t0.ptr_host();
for (size_t i = 0; i < A; ++i) {
for (size_t j = 0; j < B; ++j) {
for (size_t k = 0; k < C; ++k) {
auto off = i * B * C + j * C + k;
ASSERT_EQ(pt[off], p2[off]);
}
}
}
}
//! the memory of this test case is too large, sometimes will fail on tx1
TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
constexpr size_t A = 256 * 1024 * 64, S0 = 16, S1 = 256, S2 = 64, S3 = 64;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册