提交 fc1ce273 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix elemwise add cuda int8 bcast

GitOrigin-RevId: 568b60e8c9f4d138b57b3f4e715f35cf5ca9d0b4
上级 57bc3657
...@@ -34,9 +34,9 @@ namespace elemwise_intl { ...@@ -34,9 +34,9 @@ namespace elemwise_intl {
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds" #pragma GCC diagnostic ignored "-Warray-bounds"
template <int ndim, typename ctype> template <int ndim, typename ctype>
void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv, void ParamVisitorBase<ndim, ctype, BCAST_OTHER>::host_init(
int /*grid_size*/, const TensorND& rv, int /*grid_size*/, int /*block_size*/,
int /*block_size*/) { int /*packed_size*/) {
megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim); megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
m_ptr = rv.ptr<ctype>(); m_ptr = rv.ptr<ctype>();
for (size_t i = 0; i < rv.layout.ndim; ++i) { for (size_t i = 0; i < rv.layout.ndim; ++i) {
...@@ -54,9 +54,10 @@ void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv, ...@@ -54,9 +54,10 @@ void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv,
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
template <typename ctype> template <typename ctype>
void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv, void ParamVisitorBase<3, ctype, BCAST_101>::host_init(const TensorND& rv,
int grid_size, int grid_size,
int block_size) { int block_size,
int packed_size) {
uint32_t shape2, shape1; uint32_t shape2, shape1;
int stride1; int stride1;
if (rv.layout.ndim == 3) { if (rv.layout.ndim == 3) {
...@@ -76,9 +77,10 @@ void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv, ...@@ -76,9 +77,10 @@ void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv,
} }
template <typename ctype> template <typename ctype>
void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv, void ParamVisitorBase<2, ctype, BCAST_10>::host_init(const TensorND& rv,
int grid_size, int grid_size,
int block_size) { int block_size,
int packed_size) {
megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]); megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]);
m_ptr = rv.ptr<ctype>(); m_ptr = rv.ptr<ctype>();
m_stride1 = rv.layout.stride[1]; m_stride1 = rv.layout.stride[1];
...@@ -87,9 +89,10 @@ void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv, ...@@ -87,9 +89,10 @@ void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv,
} }
template <typename ctype> template <typename ctype>
void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv, void ParamVisitorBase<2, ctype, BCAST_01>::host_init(const TensorND& rv,
int grid_size, int grid_size,
int block_size) { int block_size,
int packed_size) {
megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[1]); megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[1]);
m_ptr = rv.ptr<ctype>(); m_ptr = rv.ptr<ctype>();
m_stride0 = rv.layout.stride[0]; m_stride0 = rv.layout.stride[0];
...@@ -98,9 +101,10 @@ void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv, ...@@ -98,9 +101,10 @@ void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv,
} }
template <typename ctype> template <typename ctype>
void ParamElemVisitor<1, ctype, BCAST_FULL>::host_init(const TensorND& rv, void ParamVisitorBase<1, ctype, BCAST_FULL>::host_init(const TensorND& rv,
int /*grid_size*/, int /*grid_size*/,
int /*block_size*/) { int /*block_size*/,
int /*packed_size*/) {
megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]); megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]);
m_ptr = rv.ptr<ctype>(); m_ptr = rv.ptr<ctype>();
} }
...@@ -122,6 +126,53 @@ void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv, ...@@ -122,6 +126,53 @@ void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv,
m_shape3.host_init(packed_size * grid_size * block_size, shape3); m_shape3.host_init(packed_size * grid_size * block_size, shape3);
} }
#define INST(ndim, ctype, brd) template class ParamVisitorBase<ndim, ctype, brd>
#define INST_FOR_CTYPE \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
INST(3, ct, BCAST_101); \
INST(2, ct, BCAST_10); \
INST(2, ct, BCAST_01); \
INST(1, ct, BCAST_FULL);
#define ndim_cb(_ndim) INST(_ndim, ct, BCAST_OTHER);
#define ct dt_byte
INST_FOR_CTYPE
#undef ct
#define ct dt_int32
INST_FOR_CTYPE
#undef ct
#define ct dt_float32
INST_FOR_CTYPE
#undef ct
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct
#define ct dt_uint8
INST_FOR_CTYPE
#undef ct
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
#define ct dt_qint8
INST_FOR_CTYPE
#undef ct
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#undef INST_FOR_CTYPE
#undef INST
#define INST(ndim, ctype, brd) template class ParamElemVisitor<ndim, ctype, brd> #define INST(ndim, ctype, brd) template class ParamElemVisitor<ndim, ctype, brd>
#define INST_FOR_CTYPE \ #define INST_FOR_CTYPE \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \ MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
......
...@@ -142,6 +142,9 @@ INST(dt_qint32, int4); ...@@ -142,6 +142,9 @@ INST(dt_qint32, int4);
* ptr()[offset(idx)] * ptr()[offset(idx)]
* *
*/ */
template <int ndim, typename ctype, BcastType brd_type>
class ParamVisitorBase;
template <int ndim, typename ctype, BcastType brd_type> template <int ndim, typename ctype, BcastType brd_type>
class ParamElemVisitor; class ParamElemVisitor;
...@@ -157,6 +160,7 @@ class ParamElemVisitor; ...@@ -157,6 +160,7 @@ class ParamElemVisitor;
* ptr()[offset(idx)] * ptr()[offset(idx)]
* *
*/ */
template <int ndim, typename ctype, BcastType brd_type> template <int ndim, typename ctype, BcastType brd_type>
class ParamVectVisitor; class ParamVectVisitor;
...@@ -169,11 +173,9 @@ class ParamVectVisitor; ...@@ -169,11 +173,9 @@ class ParamVectVisitor;
//! specialization for BCAST_OTHER //! specialization for BCAST_OTHER
template <int ndim, typename ctype> template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, BCAST_OTHER> { class ParamVisitorBase<ndim, ctype, BCAST_OTHER> {
protected: protected:
ctype* __restrict m_ptr; ctype* __restrict m_ptr;
private:
int m_stride[ndim]; int m_stride[ndim];
//! m_shape_highdim[i] = original_shape[i + 1] //! m_shape_highdim[i] = original_shape[i + 1]
...@@ -185,10 +187,9 @@ private: ...@@ -185,10 +187,9 @@ private:
public: public:
static const int NDIM = ndim; static const int NDIM = ndim;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size);
void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t) {} devfunc void thread_init(uint32_t) {}
...@@ -211,6 +212,18 @@ public: ...@@ -211,6 +212,18 @@ public:
#endif #endif
}; };
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, BCAST_OTHER>
: public ParamVisitorBase<ndim, ctype, BCAST_OTHER> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<ndim, ctype, BCAST_OTHER>::host_init(
rv, grid_size, block_size, packed_size);
}
};
/*! /*!
* \brief specialization for ndim == 3 and BCAST_101 * \brief specialization for ndim == 3 and BCAST_101
* (for dimshuffle 'x', 0, 'x') * (for dimshuffle 'x', 0, 'x')
...@@ -218,7 +231,7 @@ public: ...@@ -218,7 +231,7 @@ public:
* visit: idx / m_shape2 % m_shape1 * visit: idx / m_shape2 % m_shape1
*/ */
template <typename ctype> template <typename ctype>
class ParamElemVisitor<3, ctype, BCAST_101> { class ParamVisitorBase<3, ctype, BCAST_101> {
StridedDivSeq2 m_shape12; StridedDivSeq2 m_shape12;
int m_stride1; int m_stride1;
...@@ -227,9 +240,9 @@ protected: ...@@ -227,9 +240,9 @@ protected:
public: public:
static const int NDIM = 3; static const int NDIM = 3;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size); void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); } devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); }
...@@ -242,13 +255,25 @@ public: ...@@ -242,13 +255,25 @@ public:
#endif #endif
}; };
template <typename ctype>
class ParamElemVisitor<3, ctype, BCAST_101>
: public ParamVisitorBase<3, ctype, BCAST_101> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<3, ctype, BCAST_101>::host_init(
rv, grid_size, block_size, packed_size);
}
};
/*! /*!
* \brief specialization for ndim == 2 and BCAST_10 * \brief specialization for ndim == 2 and BCAST_10
* *
* visit: idx % m_shape1 * visit: idx % m_shape1
*/ */
template <typename ctype> template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_10> { class ParamVisitorBase<2, ctype, BCAST_10> {
StridedDivSeq<false> m_shape1; StridedDivSeq<false> m_shape1;
int m_stride1; int m_stride1;
...@@ -257,9 +282,9 @@ protected: ...@@ -257,9 +282,9 @@ protected:
public: public:
static const int NDIM = 2; static const int NDIM = 2;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size); void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }
...@@ -272,13 +297,25 @@ public: ...@@ -272,13 +297,25 @@ public:
#endif #endif
}; };
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_10>
: public ParamVisitorBase<2, ctype, BCAST_10> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<2, ctype, BCAST_10>::host_init(
rv, grid_size, block_size, packed_size);
}
};
/*! /*!
* \brief specialization for ndim == 2 and BCAST_01 * \brief specialization for ndim == 2 and BCAST_01
* *
* visit: idx / shape1 * visit: idx / shape1
*/ */
template <typename ctype> template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_01> { class ParamVisitorBase<2, ctype, BCAST_01> {
StridedDivSeq<true> m_shape1; StridedDivSeq<true> m_shape1;
int m_stride0; int m_stride0;
...@@ -287,9 +324,9 @@ protected: ...@@ -287,9 +324,9 @@ protected:
public: public:
static const int NDIM = 2; static const int NDIM = 2;
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size); void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }
...@@ -302,9 +339,21 @@ public: ...@@ -302,9 +339,21 @@ public:
#endif #endif
}; };
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_01>
: public ParamVisitorBase<2, ctype, BCAST_01> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<2, ctype, BCAST_01>::host_init(
rv, grid_size, block_size, packed_size);
}
};
//! specialization for ndim == 1 and BCAST_FULL //! specialization for ndim == 1 and BCAST_FULL
template <typename ctype> template <typename ctype>
class ParamElemVisitor<1, ctype, BCAST_FULL> { class ParamVisitorBase<1, ctype, BCAST_FULL> {
protected: protected:
ctype* __restrict m_ptr; ctype* __restrict m_ptr;
...@@ -312,7 +361,8 @@ public: ...@@ -312,7 +361,8 @@ public:
static const int NDIM = 1; static const int NDIM = 1;
PARAM_ELEM_VISITOR_COMMON_HOST PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size); void host_init(const TensorND& rv, int grid_size, int block_size,
int packed_size);
#if MEGDNN_CC_CUDA #if MEGDNN_CC_CUDA
devfunc void thread_init(uint32_t) {} devfunc void thread_init(uint32_t) {}
...@@ -328,6 +378,18 @@ public: ...@@ -328,6 +378,18 @@ public:
#endif #endif
}; };
template <typename ctype>
class ParamElemVisitor<1, ctype, BCAST_FULL>
: public ParamVisitorBase<1, ctype, BCAST_FULL> {
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void host_init(const TensorND& rv, int grid_size, int block_size) {
ParamVisitorBase<1, ctype, BCAST_FULL>::host_init(
rv, grid_size, block_size, packed_size);
}
};
#undef PARAM_ELEM_VISITOR_COMMON_DEV #undef PARAM_ELEM_VISITOR_COMMON_DEV
#undef PARAM_ELEM_VISITOR_COMMON_HOST #undef PARAM_ELEM_VISITOR_COMMON_HOST
...@@ -340,17 +402,21 @@ public: ...@@ -340,17 +402,21 @@ public:
#else #else
#define DEVICE_WRAPPER(x) #define DEVICE_WRAPPER(x)
#endif #endif
#define INST_PARAM_VECT_VISITOR \ #define INST_PARAM_VECT_VISITOR \
template <int ndim, typename ctype> \ template <int ndim, typename ctype> \
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \ class ParamVectVisitor<ndim, ctype, _brdcast_mask> \
: public ParamElemVisitor<ndim, ctype, _brdcast_mask> { \ : public ParamVisitorBase<ndim, ctype, _brdcast_mask> { \
public: \ public: \
using Super = ParamElemVisitor<ndim, ctype, _brdcast_mask>; \ using Super = ParamVisitorBase<ndim, ctype, _brdcast_mask>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \ void host_init(const TensorND& rv, int grid_size, int block_size) { \
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \ ParamVisitorBase<ndim, ctype, _brdcast_mask>::host_init( \
}) \ rv, grid_size, block_size, packed_size); \
} \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \
}) \
}; };
#define _brdcast_mask BCAST_OTHER #define _brdcast_mask BCAST_OTHER
INST_PARAM_VECT_VISITOR; INST_PARAM_VECT_VISITOR;
...@@ -367,11 +433,15 @@ INST_PARAM_VECT_VISITOR; ...@@ -367,11 +433,15 @@ INST_PARAM_VECT_VISITOR;
#define INST_DT_IBYTE(ctype) \ #define INST_DT_IBYTE(ctype) \
template <int ndim> \ template <int ndim> \
class ParamVectVisitor<ndim, ctype, BCAST_FULL> \ class ParamVectVisitor<ndim, ctype, BCAST_FULL> \
: public ParamElemVisitor<ndim, ctype, BCAST_FULL> { \ : public ParamVisitorBase<ndim, ctype, BCAST_FULL> { \
public: \ public: \
using Super = ParamElemVisitor<ndim, ctype, BCAST_FULL>; \ using Super = ParamVisitorBase<ndim, ctype, BCAST_FULL>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \ using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
void host_init(const TensorND& rv, int grid_size, int block_size) { \
ParamVisitorBase<ndim, ctype, BCAST_FULL>::host_init( \
rv, grid_size, block_size, packed_size); \
} \
DEVICE_WRAPPER(rwtype vect_scalar; \ DEVICE_WRAPPER(rwtype vect_scalar; \
devfunc rwtype & at(uint32_t /* idx */) { \ devfunc rwtype & at(uint32_t /* idx */) { \
ctype v = Super::m_ptr[0]; \ ctype v = Super::m_ptr[0]; \
......
...@@ -269,6 +269,43 @@ TEST_F(CUDA, ELEMWISE_BFLOAT16) { ...@@ -269,6 +269,43 @@ TEST_F(CUDA, ELEMWISE_BFLOAT16) {
#undef BUILD_TERNARY_COMPLATE_TEST_CASE #undef BUILD_TERNARY_COMPLATE_TEST_CASE
} }
TEST_F(CUDA, ELEMWISE_ADD_BCAST_10_INT8_INPLACE) {
constexpr size_t A = 2, B = 48, C0 = 14, C1 = 14, C = C0 * C1;
SyncedTensor<dt_int8> t0(handle_cuda(),
{TensorShape{A, B, C0, C1}, dtype::Int8()}),
t1(handle_cuda(), {TensorShape{1, B, C0, C1}, dtype::Int8()}),
t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Int8()});
UniformIntRNG rng{-128, 127};
rng.gen(t0.tensornd_host());
rng.gen(t1.tensornd_host());
auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
auto p2 = t2.ptr_mutable_host();
for (size_t i = 0; i < A; ++i) {
for (size_t j = 0; j < B; ++j) {
for (size_t k = 0; k < C; ++k) {
auto off0 = j * C + k;
auto off1 = i * B * C + j * C + k;
p2[off1] = p0[off1] + p1[off0];
}
}
}
auto opr = handle_cuda()->create_operator<ElemwiseForward>();
opr->param().mode = ElemwiseForward::Mode::ADD;
opr->exec({t0.tensornd_dev(), t1.tensornd_dev()}, t0.tensornd_dev());
auto pt = t0.ptr_host();
for (size_t i = 0; i < A; ++i) {
for (size_t j = 0; j < B; ++j) {
for (size_t k = 0; k < C; ++k) {
auto off = i * B * C + j * C + k;
ASSERT_EQ(pt[off], p2[off]);
}
}
}
}
//! the memory of this test case is too large, sometimes will fail on tx1 //! the memory of this test case is too large, sometimes will fail on tx1
TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) { TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
constexpr size_t A = 256 * 1024 * 64, S0 = 16, S1 = 256, S2 = 64, S3 = 64; constexpr size_t A = 256 * 1024 * 64, S0 = 16, S1 = 256, S2 = 64, S3 = 64;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册