Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
fc1ce273
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
fc1ce273
编写于
7月 16, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn/cuda): fix elemwise add cuda int8 bcast
GitOrigin-RevId: 568b60e8c9f4d138b57b3f4e715f35cf5ca9d0b4
上级
57bc3657
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
199 addition
and
41 deletion
+199
-41
dnn/src/cuda/elemwise_helper.cpp
dnn/src/cuda/elemwise_helper.cpp
+62
-11
dnn/src/cuda/elemwise_helper.cuh
dnn/src/cuda/elemwise_helper.cuh
+100
-30
dnn/test/cuda/elemwise.cpp
dnn/test/cuda/elemwise.cpp
+37
-0
未找到文件。
dnn/src/cuda/elemwise_helper.cpp
浏览文件 @
fc1ce273
...
@@ -34,9 +34,9 @@ namespace elemwise_intl {
...
@@ -34,9 +34,9 @@ namespace elemwise_intl {
#pragma GCC diagnostic push
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#pragma GCC diagnostic ignored "-Warray-bounds"
template
<
int
ndim
,
typename
ctype
>
template
<
int
ndim
,
typename
ctype
>
void
Param
ElemVisitor
<
ndim
,
ctype
,
BCAST_OTHER
>::
host_init
(
const
TensorND
&
rv
,
void
Param
VisitorBase
<
ndim
,
ctype
,
BCAST_OTHER
>::
host_init
(
int
/*grid
_size*/
,
const
TensorND
&
rv
,
int
/*grid_size*/
,
int
/*block
_size*/
,
int
/*block
_size*/
)
{
int
/*packed
_size*/
)
{
megdnn_assert
(
rv
.
layout
.
ndim
&&
rv
.
layout
.
ndim
<=
ndim
);
megdnn_assert
(
rv
.
layout
.
ndim
&&
rv
.
layout
.
ndim
<=
ndim
);
m_ptr
=
rv
.
ptr
<
ctype
>
();
m_ptr
=
rv
.
ptr
<
ctype
>
();
for
(
size_t
i
=
0
;
i
<
rv
.
layout
.
ndim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
rv
.
layout
.
ndim
;
++
i
)
{
...
@@ -54,9 +54,10 @@ void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv,
...
@@ -54,9 +54,10 @@ void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(const TensorND& rv,
#pragma GCC diagnostic pop
#pragma GCC diagnostic pop
template
<
typename
ctype
>
template
<
typename
ctype
>
void
Param
ElemVisitor
<
3
,
ctype
,
BCAST_101
>::
host_init
(
const
TensorND
&
rv
,
void
Param
VisitorBase
<
3
,
ctype
,
BCAST_101
>::
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
grid_size
,
int
block_size
)
{
int
block_size
,
int
packed_size
)
{
uint32_t
shape2
,
shape1
;
uint32_t
shape2
,
shape1
;
int
stride1
;
int
stride1
;
if
(
rv
.
layout
.
ndim
==
3
)
{
if
(
rv
.
layout
.
ndim
==
3
)
{
...
@@ -76,9 +77,10 @@ void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv,
...
@@ -76,9 +77,10 @@ void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv,
}
}
template
<
typename
ctype
>
template
<
typename
ctype
>
void
Param
ElemVisitor
<
2
,
ctype
,
BCAST_10
>::
host_init
(
const
TensorND
&
rv
,
void
Param
VisitorBase
<
2
,
ctype
,
BCAST_10
>::
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
grid_size
,
int
block_size
)
{
int
block_size
,
int
packed_size
)
{
megdnn_assert
(
rv
.
layout
.
ndim
==
NDIM
&&
!
rv
.
layout
.
stride
[
0
]);
megdnn_assert
(
rv
.
layout
.
ndim
==
NDIM
&&
!
rv
.
layout
.
stride
[
0
]);
m_ptr
=
rv
.
ptr
<
ctype
>
();
m_ptr
=
rv
.
ptr
<
ctype
>
();
m_stride1
=
rv
.
layout
.
stride
[
1
];
m_stride1
=
rv
.
layout
.
stride
[
1
];
...
@@ -87,9 +89,10 @@ void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv,
...
@@ -87,9 +89,10 @@ void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv,
}
}
template
<
typename
ctype
>
template
<
typename
ctype
>
void
Param
ElemVisitor
<
2
,
ctype
,
BCAST_01
>::
host_init
(
const
TensorND
&
rv
,
void
Param
VisitorBase
<
2
,
ctype
,
BCAST_01
>::
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
grid_size
,
int
block_size
)
{
int
block_size
,
int
packed_size
)
{
megdnn_assert
(
rv
.
layout
.
ndim
==
NDIM
&&
!
rv
.
layout
.
stride
[
1
]);
megdnn_assert
(
rv
.
layout
.
ndim
==
NDIM
&&
!
rv
.
layout
.
stride
[
1
]);
m_ptr
=
rv
.
ptr
<
ctype
>
();
m_ptr
=
rv
.
ptr
<
ctype
>
();
m_stride0
=
rv
.
layout
.
stride
[
0
];
m_stride0
=
rv
.
layout
.
stride
[
0
];
...
@@ -98,9 +101,10 @@ void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv,
...
@@ -98,9 +101,10 @@ void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv,
}
}
template
<
typename
ctype
>
template
<
typename
ctype
>
void
Param
ElemVisitor
<
1
,
ctype
,
BCAST_FULL
>::
host_init
(
const
TensorND
&
rv
,
void
Param
VisitorBase
<
1
,
ctype
,
BCAST_FULL
>::
host_init
(
const
TensorND
&
rv
,
int
/*grid_size*/
,
int
/*grid_size*/
,
int
/*block_size*/
)
{
int
/*block_size*/
,
int
/*packed_size*/
)
{
megdnn_assert
(
rv
.
layout
.
ndim
==
NDIM
&&
!
rv
.
layout
.
stride
[
0
]);
megdnn_assert
(
rv
.
layout
.
ndim
==
NDIM
&&
!
rv
.
layout
.
stride
[
0
]);
m_ptr
=
rv
.
ptr
<
ctype
>
();
m_ptr
=
rv
.
ptr
<
ctype
>
();
}
}
...
@@ -122,6 +126,53 @@ void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv,
...
@@ -122,6 +126,53 @@ void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv,
m_shape3
.
host_init
(
packed_size
*
grid_size
*
block_size
,
shape3
);
m_shape3
.
host_init
(
packed_size
*
grid_size
*
block_size
,
shape3
);
}
}
#define INST(ndim, ctype, brd) template class ParamVisitorBase<ndim, ctype, brd>
#define INST_FOR_CTYPE \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
INST(3, ct, BCAST_101); \
INST(2, ct, BCAST_10); \
INST(2, ct, BCAST_01); \
INST(1, ct, BCAST_FULL);
#define ndim_cb(_ndim) INST(_ndim, ct, BCAST_OTHER);
#define ct dt_byte
INST_FOR_CTYPE
#undef ct
#define ct dt_int32
INST_FOR_CTYPE
#undef ct
#define ct dt_float32
INST_FOR_CTYPE
#undef ct
#define ct dt_float16
INST_FOR_CTYPE
#undef ct
#define ct dt_bfloat16
INST_FOR_CTYPE
#undef ct
#define ct dt_int8
INST_FOR_CTYPE
#undef ct
#define ct dt_uint8
INST_FOR_CTYPE
#undef ct
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
#define ct dt_qint8
INST_FOR_CTYPE
#undef ct
#define ct dt_qint32
INST_FOR_CTYPE
#undef ct
#undef INST_FOR_CTYPE
#undef INST
#define INST(ndim, ctype, brd) template class ParamElemVisitor<ndim, ctype, brd>
#define INST(ndim, ctype, brd) template class ParamElemVisitor<ndim, ctype, brd>
#define INST_FOR_CTYPE \
#define INST_FOR_CTYPE \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
...
...
dnn/src/cuda/elemwise_helper.cuh
浏览文件 @
fc1ce273
...
@@ -142,6 +142,9 @@ INST(dt_qint32, int4);
...
@@ -142,6 +142,9 @@ INST(dt_qint32, int4);
* ptr()[offset(idx)]
* ptr()[offset(idx)]
*
*
*/
*/
template
<
int
ndim
,
typename
ctype
,
BcastType
brd_type
>
class
ParamVisitorBase
;
template
<
int
ndim
,
typename
ctype
,
BcastType
brd_type
>
template
<
int
ndim
,
typename
ctype
,
BcastType
brd_type
>
class
ParamElemVisitor
;
class
ParamElemVisitor
;
...
@@ -157,6 +160,7 @@ class ParamElemVisitor;
...
@@ -157,6 +160,7 @@ class ParamElemVisitor;
* ptr()[offset(idx)]
* ptr()[offset(idx)]
*
*
*/
*/
template
<
int
ndim
,
typename
ctype
,
BcastType
brd_type
>
template
<
int
ndim
,
typename
ctype
,
BcastType
brd_type
>
class
ParamVectVisitor
;
class
ParamVectVisitor
;
...
@@ -169,11 +173,9 @@ class ParamVectVisitor;
...
@@ -169,11 +173,9 @@ class ParamVectVisitor;
//! specialization for BCAST_OTHER
//! specialization for BCAST_OTHER
template
<
int
ndim
,
typename
ctype
>
template
<
int
ndim
,
typename
ctype
>
class
Param
ElemVisitor
<
ndim
,
ctype
,
BCAST_OTHER
>
{
class
Param
VisitorBase
<
ndim
,
ctype
,
BCAST_OTHER
>
{
protected:
protected:
ctype
*
__restrict
m_ptr
;
ctype
*
__restrict
m_ptr
;
private:
int
m_stride
[
ndim
];
int
m_stride
[
ndim
];
//! m_shape_highdim[i] = original_shape[i + 1]
//! m_shape_highdim[i] = original_shape[i + 1]
...
@@ -185,10 +187,9 @@ private:
...
@@ -185,10 +187,9 @@ private:
public:
public:
static
const
int
NDIM
=
ndim
;
static
const
int
NDIM
=
ndim
;
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
);
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
,
int
packed_size
);
#if MEGDNN_CC_CUDA
#if MEGDNN_CC_CUDA
devfunc
void
thread_init
(
uint32_t
)
{}
devfunc
void
thread_init
(
uint32_t
)
{}
...
@@ -211,6 +212,18 @@ public:
...
@@ -211,6 +212,18 @@ public:
#endif
#endif
};
};
template
<
int
ndim
,
typename
ctype
>
class
ParamElemVisitor
<
ndim
,
ctype
,
BCAST_OTHER
>
:
public
ParamVisitorBase
<
ndim
,
ctype
,
BCAST_OTHER
>
{
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
)
{
ParamVisitorBase
<
ndim
,
ctype
,
BCAST_OTHER
>::
host_init
(
rv
,
grid_size
,
block_size
,
packed_size
);
}
};
/*!
/*!
* \brief specialization for ndim == 3 and BCAST_101
* \brief specialization for ndim == 3 and BCAST_101
* (for dimshuffle 'x', 0, 'x')
* (for dimshuffle 'x', 0, 'x')
...
@@ -218,7 +231,7 @@ public:
...
@@ -218,7 +231,7 @@ public:
* visit: idx / m_shape2 % m_shape1
* visit: idx / m_shape2 % m_shape1
*/
*/
template
<
typename
ctype
>
template
<
typename
ctype
>
class
Param
ElemVisitor
<
3
,
ctype
,
BCAST_101
>
{
class
Param
VisitorBase
<
3
,
ctype
,
BCAST_101
>
{
StridedDivSeq2
m_shape12
;
StridedDivSeq2
m_shape12
;
int
m_stride1
;
int
m_stride1
;
...
@@ -227,9 +240,9 @@ protected:
...
@@ -227,9 +240,9 @@ protected:
public:
public:
static
const
int
NDIM
=
3
;
static
const
int
NDIM
=
3
;
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
);
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
,
int
packed_size
);
#if MEGDNN_CC_CUDA
#if MEGDNN_CC_CUDA
devfunc
void
thread_init
(
uint32_t
idx
)
{
m_shape12
.
device_init
(
idx
);
}
devfunc
void
thread_init
(
uint32_t
idx
)
{
m_shape12
.
device_init
(
idx
);
}
...
@@ -242,13 +255,25 @@ public:
...
@@ -242,13 +255,25 @@ public:
#endif
#endif
};
};
template
<
typename
ctype
>
class
ParamElemVisitor
<
3
,
ctype
,
BCAST_101
>
:
public
ParamVisitorBase
<
3
,
ctype
,
BCAST_101
>
{
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
)
{
ParamVisitorBase
<
3
,
ctype
,
BCAST_101
>::
host_init
(
rv
,
grid_size
,
block_size
,
packed_size
);
}
};
/*!
/*!
* \brief specialization for ndim == 2 and BCAST_10
* \brief specialization for ndim == 2 and BCAST_10
*
*
* visit: idx % m_shape1
* visit: idx % m_shape1
*/
*/
template
<
typename
ctype
>
template
<
typename
ctype
>
class
Param
ElemVisitor
<
2
,
ctype
,
BCAST_10
>
{
class
Param
VisitorBase
<
2
,
ctype
,
BCAST_10
>
{
StridedDivSeq
<
false
>
m_shape1
;
StridedDivSeq
<
false
>
m_shape1
;
int
m_stride1
;
int
m_stride1
;
...
@@ -257,9 +282,9 @@ protected:
...
@@ -257,9 +282,9 @@ protected:
public:
public:
static
const
int
NDIM
=
2
;
static
const
int
NDIM
=
2
;
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
);
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
,
int
packed_size
);
#if MEGDNN_CC_CUDA
#if MEGDNN_CC_CUDA
devfunc
void
thread_init
(
uint32_t
idx
)
{
m_shape1
.
device_init
(
idx
);
}
devfunc
void
thread_init
(
uint32_t
idx
)
{
m_shape1
.
device_init
(
idx
);
}
...
@@ -272,13 +297,25 @@ public:
...
@@ -272,13 +297,25 @@ public:
#endif
#endif
};
};
template
<
typename
ctype
>
class
ParamElemVisitor
<
2
,
ctype
,
BCAST_10
>
:
public
ParamVisitorBase
<
2
,
ctype
,
BCAST_10
>
{
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
)
{
ParamVisitorBase
<
2
,
ctype
,
BCAST_10
>::
host_init
(
rv
,
grid_size
,
block_size
,
packed_size
);
}
};
/*!
/*!
* \brief specialization for ndim == 2 and BCAST_01
* \brief specialization for ndim == 2 and BCAST_01
*
*
* visit: idx / shape1
* visit: idx / shape1
*/
*/
template
<
typename
ctype
>
template
<
typename
ctype
>
class
Param
ElemVisitor
<
2
,
ctype
,
BCAST_01
>
{
class
Param
VisitorBase
<
2
,
ctype
,
BCAST_01
>
{
StridedDivSeq
<
true
>
m_shape1
;
StridedDivSeq
<
true
>
m_shape1
;
int
m_stride0
;
int
m_stride0
;
...
@@ -287,9 +324,9 @@ protected:
...
@@ -287,9 +324,9 @@ protected:
public:
public:
static
const
int
NDIM
=
2
;
static
const
int
NDIM
=
2
;
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
);
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
,
int
packed_size
);
#if MEGDNN_CC_CUDA
#if MEGDNN_CC_CUDA
devfunc
void
thread_init
(
uint32_t
idx
)
{
m_shape1
.
device_init
(
idx
);
}
devfunc
void
thread_init
(
uint32_t
idx
)
{
m_shape1
.
device_init
(
idx
);
}
...
@@ -302,9 +339,21 @@ public:
...
@@ -302,9 +339,21 @@ public:
#endif
#endif
};
};
template
<
typename
ctype
>
class
ParamElemVisitor
<
2
,
ctype
,
BCAST_01
>
:
public
ParamVisitorBase
<
2
,
ctype
,
BCAST_01
>
{
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
)
{
ParamVisitorBase
<
2
,
ctype
,
BCAST_01
>::
host_init
(
rv
,
grid_size
,
block_size
,
packed_size
);
}
};
//! specialization for ndim == 1 and BCAST_FULL
//! specialization for ndim == 1 and BCAST_FULL
template
<
typename
ctype
>
template
<
typename
ctype
>
class
Param
ElemVisitor
<
1
,
ctype
,
BCAST_FULL
>
{
class
Param
VisitorBase
<
1
,
ctype
,
BCAST_FULL
>
{
protected:
protected:
ctype
*
__restrict
m_ptr
;
ctype
*
__restrict
m_ptr
;
...
@@ -312,7 +361,8 @@ public:
...
@@ -312,7 +361,8 @@ public:
static
const
int
NDIM
=
1
;
static
const
int
NDIM
=
1
;
PARAM_ELEM_VISITOR_COMMON_HOST
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
);
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
,
int
packed_size
);
#if MEGDNN_CC_CUDA
#if MEGDNN_CC_CUDA
devfunc
void
thread_init
(
uint32_t
)
{}
devfunc
void
thread_init
(
uint32_t
)
{}
...
@@ -328,6 +378,18 @@ public:
...
@@ -328,6 +378,18 @@ public:
#endif
#endif
};
};
template
<
typename
ctype
>
class
ParamElemVisitor
<
1
,
ctype
,
BCAST_FULL
>
:
public
ParamVisitorBase
<
1
,
ctype
,
BCAST_FULL
>
{
public:
PARAM_ELEM_VISITOR_COMMON_HOST
void
host_init
(
const
TensorND
&
rv
,
int
grid_size
,
int
block_size
)
{
ParamVisitorBase
<
1
,
ctype
,
BCAST_FULL
>::
host_init
(
rv
,
grid_size
,
block_size
,
packed_size
);
}
};
#undef PARAM_ELEM_VISITOR_COMMON_DEV
#undef PARAM_ELEM_VISITOR_COMMON_DEV
#undef PARAM_ELEM_VISITOR_COMMON_HOST
#undef PARAM_ELEM_VISITOR_COMMON_HOST
...
@@ -340,17 +402,21 @@ public:
...
@@ -340,17 +402,21 @@ public:
#else
#else
#define DEVICE_WRAPPER(x)
#define DEVICE_WRAPPER(x)
#endif
#endif
#define INST_PARAM_VECT_VISITOR \
#define INST_PARAM_VECT_VISITOR \
template <int ndim, typename ctype> \
template <int ndim, typename ctype> \
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \
class ParamVectVisitor<ndim, ctype, _brdcast_mask> \
: public ParamElemVisitor<ndim, ctype, _brdcast_mask> { \
: public ParamVisitorBase<ndim, ctype, _brdcast_mask> { \
public: \
public: \
using Super = ParamElemVisitor<ndim, ctype, _brdcast_mask>; \
using Super = ParamVisitorBase<ndim, ctype, _brdcast_mask>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \
void host_init(const TensorND& rv, int grid_size, int block_size) { \
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \
ParamVisitorBase<ndim, ctype, _brdcast_mask>::host_init( \
}) \
rv, grid_size, block_size, packed_size); \
} \
DEVICE_WRAPPER(devfunc rwtype& at(uint32_t idx) { \
return *(rwtype*)(&Super::m_ptr[Super::offset(idx)]); \
}) \
};
};
#define _brdcast_mask BCAST_OTHER
#define _brdcast_mask BCAST_OTHER
INST_PARAM_VECT_VISITOR
;
INST_PARAM_VECT_VISITOR
;
...
@@ -367,11 +433,15 @@ INST_PARAM_VECT_VISITOR;
...
@@ -367,11 +433,15 @@ INST_PARAM_VECT_VISITOR;
#define INST_DT_IBYTE(ctype) \
#define INST_DT_IBYTE(ctype) \
template <int ndim> \
template <int ndim> \
class ParamVectVisitor<ndim, ctype, BCAST_FULL> \
class ParamVectVisitor<ndim, ctype, BCAST_FULL> \
: public Param
ElemVisitor
<ndim, ctype, BCAST_FULL> { \
: public Param
VisitorBase
<ndim, ctype, BCAST_FULL> { \
public: \
public: \
using Super = Param
ElemVisitor
<ndim, ctype, BCAST_FULL>; \
using Super = Param
VisitorBase
<ndim, ctype, BCAST_FULL>; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
using rwtype = typename VectTypeTrait<ctype>::vect_type; \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
static const int packed_size = sizeof(rwtype) / sizeof(ctype); \
void host_init(const TensorND& rv, int grid_size, int block_size) { \
ParamVisitorBase<ndim, ctype, BCAST_FULL>::host_init( \
rv, grid_size, block_size, packed_size); \
} \
DEVICE_WRAPPER(rwtype vect_scalar; \
DEVICE_WRAPPER(rwtype vect_scalar; \
devfunc rwtype & at(uint32_t
/* idx */
) { \
devfunc rwtype & at(uint32_t
/* idx */
) { \
ctype v = Super::m_ptr[0]; \
ctype v = Super::m_ptr[0]; \
...
...
dnn/test/cuda/elemwise.cpp
浏览文件 @
fc1ce273
...
@@ -269,6 +269,43 @@ TEST_F(CUDA, ELEMWISE_BFLOAT16) {
...
@@ -269,6 +269,43 @@ TEST_F(CUDA, ELEMWISE_BFLOAT16) {
#undef BUILD_TERNARY_COMPLATE_TEST_CASE
#undef BUILD_TERNARY_COMPLATE_TEST_CASE
}
}
TEST_F
(
CUDA
,
ELEMWISE_ADD_BCAST_10_INT8_INPLACE
)
{
constexpr
size_t
A
=
2
,
B
=
48
,
C0
=
14
,
C1
=
14
,
C
=
C0
*
C1
;
SyncedTensor
<
dt_int8
>
t0
(
handle_cuda
(),
{
TensorShape
{
A
,
B
,
C0
,
C1
},
dtype
::
Int8
()}),
t1
(
handle_cuda
(),
{
TensorShape
{
1
,
B
,
C0
,
C1
},
dtype
::
Int8
()}),
t2
(
handle_cuda
(),
{
TensorShape
{
A
,
B
,
C0
,
C1
},
dtype
::
Int8
()});
UniformIntRNG
rng
{
-
128
,
127
};
rng
.
gen
(
t0
.
tensornd_host
());
rng
.
gen
(
t1
.
tensornd_host
());
auto
p0
=
t0
.
ptr_host
(),
p1
=
t1
.
ptr_host
();
auto
p2
=
t2
.
ptr_mutable_host
();
for
(
size_t
i
=
0
;
i
<
A
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
B
;
++
j
)
{
for
(
size_t
k
=
0
;
k
<
C
;
++
k
)
{
auto
off0
=
j
*
C
+
k
;
auto
off1
=
i
*
B
*
C
+
j
*
C
+
k
;
p2
[
off1
]
=
p0
[
off1
]
+
p1
[
off0
];
}
}
}
auto
opr
=
handle_cuda
()
->
create_operator
<
ElemwiseForward
>
();
opr
->
param
().
mode
=
ElemwiseForward
::
Mode
::
ADD
;
opr
->
exec
({
t0
.
tensornd_dev
(),
t1
.
tensornd_dev
()},
t0
.
tensornd_dev
());
auto
pt
=
t0
.
ptr_host
();
for
(
size_t
i
=
0
;
i
<
A
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
B
;
++
j
)
{
for
(
size_t
k
=
0
;
k
<
C
;
++
k
)
{
auto
off
=
i
*
B
*
C
+
j
*
C
+
k
;
ASSERT_EQ
(
pt
[
off
],
p2
[
off
]);
}
}
}
}
//! the memory of this test case is too large, sometimes will fail on tx1
//! the memory of this test case is too large, sometimes will fail on tx1
TEST_F
(
CUDA
,
ELEMWISE_BENCHMARK_DENSE
)
{
TEST_F
(
CUDA
,
ELEMWISE_BENCHMARK_DENSE
)
{
constexpr
size_t
A
=
256
*
1024
*
64
,
S0
=
16
,
S1
=
256
,
S2
=
64
,
S3
=
64
;
constexpr
size_t
A
=
256
*
1024
*
64
,
S0
=
16
,
S1
=
256
,
S2
=
64
,
S3
=
64
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录