MegEngine 天元 / MegEngine
Commit 57bc3657
Authored July 16, 2020 by Megvii Engine Team

style(dnn/cuda): format cuda elemwise code
GitOrigin-RevId: 246755ce20d708b5b35b48452996deeb63491513
Parent 09eaa398
Showing 3 changed files with 943 additions and 969 deletions (+943 −969):
dnn/src/cuda/elemwise_helper.cpp    +30   −31
dnn/src/cuda/elemwise_helper.cuh    +832  −851
dnn/test/cuda/elemwise.cpp          +81   −87
dnn/src/cuda/elemwise_helper.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/cuda/elemwise_helper.cuh"
@@ -21,7 +22,7 @@
#define _cb_check_ndim(n) megdnn::TensorShape::MAX_NDIM == n ||
static_assert(MEGDNN_FOREACH_TENSOR_NDIM(_cb_check_ndim) false,
              "bad foreach ndim");
#undef _cb_check_ndim

namespace megdnn {
@@ -32,28 +33,30 @@ namespace elemwise_intl {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
template <int ndim, typename ctype>
void ParamElemVisitor<ndim, ctype, BCAST_OTHER>::host_init(
        const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
    megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
    m_ptr = rv.ptr<ctype>();
    for (size_t i = 0; i < rv.layout.ndim; ++i) {
        m_stride[i] = rv.layout.stride[i];
        if (i + 1 < rv.layout.ndim)
            m_shape_highdim[i] = rv.layout.shape[i + 1];
    }
    for (int i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
        m_shape_highdim[i] = 1;
    }
    for (int i = rv.layout.ndim; i < ndim; ++i) {
        m_stride[i] = 0;
    }
}
#pragma GCC diagnostic pop

template <typename ctype>
void ParamElemVisitor<3, ctype, BCAST_101>::host_init(
        const TensorND& rv, int grid_size, int block_size) {
    uint32_t shape2, shape1;
    int stride1;
    if (rv.layout.ndim == 3) {
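For context, host_init above precomputes everything the device-side visitor needs to address a (possibly broadcast) operand: m_stride[i] holds the operand's stride along dimension i (0 for padded or broadcast dimensions), and m_shape_highdim[i] caches shape[i + 1]. Below is a minimal host-side sketch of how such fields can turn a flattened element index into a memory offset; this is illustrative only, and the actual device code (which uses a fast integer-division helper) lives in elemwise_helper.cuh:

#include <cstdint>

// Illustrative sketch, mirroring the fields filled in by host_init above.
template <int ndim>
struct VisitorSketch {
    int m_stride[ndim];              // stride per dim; 0 for broadcast/padded dims
    uint32_t m_shape_highdim[ndim];  // m_shape_highdim[i] == shape[i + 1]

    // Decompose a flattened element index from the innermost dimension
    // outwards and accumulate each coordinate times its stride.
    int offset(uint32_t idx) const {
        int off = 0;
        for (int i = ndim - 1; i >= 1; --i) {
            uint32_t up = idx / m_shape_highdim[i - 1];
            off += (idx - up * m_shape_highdim[i - 1]) * m_stride[i];
            idx = up;
        }
        return off + idx * m_stride[0];
    }
};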
@@ -74,8 +77,8 @@ void ParamElemVisitor<3, ctype, BCAST_101>::host_init(
template <typename ctype>
void ParamElemVisitor<2, ctype, BCAST_10>::host_init(
        const TensorND& rv, int grid_size, int block_size) {
    megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]);
    m_ptr = rv.ptr<ctype>();
    m_stride1 = rv.layout.stride[1];
@@ -85,8 +88,8 @@ void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv,
template <typename ctype>
void ParamElemVisitor<2, ctype, BCAST_01>::host_init(
        const TensorND& rv, int grid_size, int block_size) {
    megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[1]);
    m_ptr = rv.ptr<ctype>();
    m_stride0 = rv.layout.stride[0];
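The assertions in the two 2D overloads above encode the naming scheme: BCAST_10 requires stride[0] == 0, i.e. a (1, N) operand repeated across rows, while BCAST_01 requires stride[1] == 0, i.e. an (M, 1) operand repeated across columns. A small sketch of that classification, with hypothetical names not present in the source:

// Hypothetical helper: classify a 2D operand's broadcast pattern the same
// way the assertions in host_init above do.
enum class Bcast2D { BCAST_10, BCAST_01, OTHER };

inline Bcast2D classify_bcast_2d(const int (&stride)[2]) {
    if (stride[0] == 0)
        return Bcast2D::BCAST_10;  // shape (1, N): row vector, repeated over rows
    if (stride[1] == 0)
        return Bcast2D::BCAST_01;  // shape (M, 1): column vector, repeated over columns
    return Bcast2D::OTHER;
}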
@@ -94,9 +97,10 @@ void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv,
                  rv.layout.shape[1]);
}

template <typename ctype>
void ParamElemVisitor<1, ctype, BCAST_FULL>::host_init(
        const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
    megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]);
    m_ptr = rv.ptr<ctype>();
}
@@ -119,14 +123,13 @@ void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv,
}

#define INST(ndim, ctype, brd) template class ParamElemVisitor<ndim, ctype, brd>
#define INST_FOR_CTYPE                  \
    MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \
    INST(3, ct, BCAST_101);             \
    INST(2, ct, BCAST_10);              \
    INST(2, ct, BCAST_01);              \
    INST(1, ct, BCAST_FULL);

#define ndim_cb(_ndim) INST(_ndim, ct, BCAST_OTHER);

#define ct dt_byte
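For readability, this is roughly what the instantiation block above expands to once ct is defined as dt_byte; a sketch, since the exact ndim list is whatever MEGDNN_FOREACH_TENSOR_NDIM enumerates:

// Approximate expansion of INST_FOR_CTYPE for ct == dt_byte:
template class ParamElemVisitor<1, dt_byte, BCAST_OTHER>;  // via ndim_cb
template class ParamElemVisitor<2, dt_byte, BCAST_OTHER>;
template class ParamElemVisitor<3, dt_byte, BCAST_OTHER>;
// ... one BCAST_OTHER instantiation per ndim, up to TensorShape::MAX_NDIM ...
template class ParamElemVisitor<3, dt_byte, BCAST_101>;
template class ParamElemVisitor<2, dt_byte, BCAST_10>;
template class ParamElemVisitor<2, dt_byte, BCAST_01>;
template class ParamElemVisitor<1, dt_byte, BCAST_FULL>;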
@@ -175,11 +178,10 @@ INST(dt_qint8);
INST(dt_quint8);
#undef dt_ibyte
}  // namespace elemwise_intl

void elemwise_intl::get_launch_spec(
        const void* kern, size_t size, int* grid_size, int* block_size) {
    safe_size_in_kern(size);
    auto config = query_launch_config_for_kernel(kern);
    *block_size = config.block_size;
@@ -202,11 +204,8 @@ void elemwise_intl::get_launch_spec(
void elemwise_intl::on_bad_ndim(int ndim) {
    megdnn_throw(ssprintf("invalid ndim: %d", ndim));
    MEGDNN_MARK_USED_VAR(ndim);
}
}  // namespace cuda
}  // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
dnn/src/cuda/elemwise_helper.cuh

This diff is collapsed (+832 −851).
dnn/test/cuda/elemwise.cpp
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "test/common/elemwise.h"
@@ -26,66 +27,61 @@ using namespace test;
#define cudnn_check(e) megdnn_assert((e) == CUDNN_STATUS_SUCCESS)

namespace {

__attribute__((unused)) cudnnTensorDescriptor_t make_cudnn_tensor_desc(
        const TensorLayout& ly) {
    megdnn_assert(ly.ndim && ly.ndim <= 4 && ly.is_contiguous());
    int dim[4] = {1, 1, 1, 1}, stride[4] = {1, 1, 1, 1};
    for (size_t i = 0; i < ly.ndim; ++i) {
        dim[i] = ly.shape[i];
        stride[i] = ly.stride[i];
    }
    cudnnTensorDescriptor_t ret;
    cudnn_check(cudnnCreateTensorDescriptor(&ret));
    // cudnn requires tensors to be at-least 4D
    cudnn_check(cudnnSetTensor4dDescriptorEx(
            ret, CUDNN_DATA_FLOAT, dim[0], dim[1], dim[2], dim[3], stride[0],
            stride[1], stride[2], stride[3]));
    return ret;
}

void run_tensor_add(
        Handle* handle_cuda, const TensorND& a, const TensorND& b,
        const TensorND& c) {
#if 1
    cudnnHandle_t cudnn_handle;
    cudnn_check(cudnnCreate(&cudnn_handle));
    cuda_check(cudaDeviceSynchronize());
    cuda_check(cudaMemcpy(
            c.raw_ptr, a.raw_ptr, a.layout.span().dist_byte(),
            cudaMemcpyDeviceToDevice));
    auto bdesc = make_cudnn_tensor_desc(b.layout),
         cdesc = make_cudnn_tensor_desc(c.layout);
    float alpha = 1, beta = 1;
    cudaProfilerStart();
    cudnn_check(cudnnAddTensor(
            cudnn_handle, &alpha, bdesc, b.raw_ptr, &beta, cdesc, c.raw_ptr));
    cudaProfilerStop();
    cudnn_check(cudnnDestroyTensorDescriptor(cdesc));
    cudnn_check(cudnnDestroyTensorDescriptor(bdesc));
    cudnn_check(cudnnDestroy(cudnn_handle));
    cuda_check(cudaMemset(c.raw_ptr, 0, c.layout.span().dist_byte()));
    cuda_check(cudaDeviceSynchronize());
#endif
    auto opr = handle_cuda->create_operator<ElemwiseForward>();
    opr->param().mode = ElemwiseForward::Mode::ADD;
    cudaProfilerStart();
    opr->exec({a, b}, c);
    cudaProfilerStop();
}

}  // anonymous namespace

template <typename tag>
class CUDA_ELEMWISE : public CUDA {};
TYPED_TEST_CASE(CUDA_ELEMWISE, elemwise::test_types);
TYPED_TEST(CUDA_ELEMWISE, run) {
    elemwise::run_test<TypeParam>(this->handle_cuda());
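run_tensor_add computes c = a + b twice: first through cuDNN as a reference and profiling baseline (c starts as a device-to-device copy of a, then cudnnAddTensor applies c = alpha * b + beta * c with alpha = beta = 1), then through MegDNN's ElemwiseForward ADD; both paths are bracketed by cudaProfilerStart/cudaProfilerStop so they show up side by side in a profile. The cuDNN arithmetic in miniature:

#include <cassert>

// Scalar sketch of the cuDNN path above: after the copy, c == a, and
// cudnnAddTensor computes c = alpha * b + beta * c.
int main() {
    float a = 3.f, b = 4.f;
    float c = a;               // cudaMemcpy(c <- a, device to device)
    const float alpha = 1.f, beta = 1.f;
    c = alpha * b + beta * c;  // cudnnAddTensor(..., &alpha, b, &beta, c)
    assert(c == a + b);
    return 0;
}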
@@ -275,18 +271,17 @@ TEST_F(CUDA, ELEMWISE_BFLOAT16) {
//! the memory of this test case is too large, sometimes will fail on tx1
TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
    constexpr size_t A = 256 * 1024 * 64, S0 = 16, S1 = 256, S2 = 64, S3 = 64;
    static_assert(A == S0 * S1 * S2 * S3, "bad value");
    SyncedTensor<> t0(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()}),
            t1(handle_cuda(), {TensorShape{S0, S1, S2, S3}, dtype::Float32()});
    UniformFloatRNG rng{-2.f, 2.f};
    rng.gen(t0.tensornd_host());
    run_tensor_add(
            handle_cuda(), t0.tensornd_dev(), t0.tensornd_dev(),
            t1.tensornd_dev());
    auto p0 = t0.ptr_host(), p1 = t1.ptr_host();
    for (size_t i = 0; i < A; ++i) {
        ASSERT_EQ(p0[i] + p0[i], p1[i]) << "at index " << i << "/" << A;
    }
}
@@ -294,19 +289,19 @@ TEST_F(CUDA, ELEMWISE_BENCHMARK_DENSE) {
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_101) {
    constexpr size_t A = 511, B = 509, C0 = 23, C1 = 23, C = C0 * C1;
    SyncedTensor<> t0(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()}),
            t1(handle_cuda(), {TensorShape{1, B, 1, 1}, dtype::Float32()}),
            t2(handle_cuda(), {TensorShape{A, B, C0, C1}, dtype::Float32()});
    UniformFloatRNG rng{-2.f, 2.f};
    rng.gen(t0.tensornd_host());
    rng.gen(t1.tensornd_host());
    run_tensor_add(
            handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(),
            t2.tensornd_dev());
    auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
    for (size_t i = 0; i < A; ++i) {
        for (size_t j = 0; j < B; ++j) {
            for (size_t k = 0; k < C; ++k) {
                auto off = i * B * C + j * C + k;
                ASSERT_EQ(p0[off] + p1[j], p2[off]);
            }
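The nested loops above are the CPU reference for a BCAST_101-shaped add: the (1, B, 1, 1) operand contributes p1[j] to every output element whose second coordinate is j, regardless of i and k. The same check as a standalone sketch:

#include <cstddef>
#include <vector>

// CPU reference for the broadcast add verified above: t1 of shape (1, B, 1)
// added to t0 of shape (A, B, C), with C = C0 * C1 flattened.
std::vector<float> bcast101_add(
        const std::vector<float>& p0, const std::vector<float>& p1, size_t A,
        size_t B, size_t C) {
    std::vector<float> out(A * B * C);
    for (size_t i = 0; i < A; ++i)
        for (size_t j = 0; j < B; ++j)
            for (size_t k = 0; k < C; ++k) {
                size_t off = i * B * C + j * C + k;
                out[off] = p0[off] + p1[j];  // j indexes the broadcast dim
            }
    return out;
}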
@@ -317,16 +312,16 @@ TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_101) {
TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_10) {
    constexpr size_t A = 11583, B = 11587;
    SyncedTensor<> t0(handle_cuda(), {TensorShape{A, B}, dtype::Float32()}),
            t1(handle_cuda(), {TensorShape{1, B}, dtype::Float32()}),
            t2(handle_cuda(), {TensorShape{A, B}, dtype::Float32()});
    UniformFloatRNG rng{-2.f, 2.f};
    rng.gen(t0.tensornd_host());
    rng.gen(t1.tensornd_host());
    run_tensor_add(
            handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(),
            t2.tensornd_dev());
    auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
    for (size_t i = 0; i < A; ++i) {
        for (size_t j = 0; j < B; ++j) {
            auto off = i * B + j;
            ASSERT_EQ(p0[off] + p1[j], p2[off]);
        }
@@ -336,16 +331,16 @@ TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_10) {
TEST_F(CUDA, ELEMWISE_BENCHMARK_BCAST_01) {
    constexpr size_t A = 11583, B = 11587;
    SyncedTensor<> t0(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()}),
            t1(handle_cuda(), {TensorShape{1, A, 1}, dtype::Float32()}),
            t2(handle_cuda(), {TensorShape{1, A, B}, dtype::Float32()});
    UniformFloatRNG rng{-2.f, 2.f};
    rng.gen(t0.tensornd_host());
    rng.gen(t1.tensornd_host());
    run_tensor_add(
            handle_cuda(), t0.tensornd_dev(), t1.tensornd_dev(),
            t2.tensornd_dev());
    auto p0 = t0.ptr_host(), p1 = t1.ptr_host(), p2 = t2.ptr_host();
    for (size_t i = 0; i < A; ++i) {
        for (size_t j = 0; j < B; ++j) {
            auto off = i * B + j;
            ASSERT_EQ(p0[off] + p1[i], p2[off]);
        }
@@ -361,8 +356,9 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_IBYTE) {
            .set_param(Mode::FUSE_ADD_RELU)
            .set_dtype(0, dtype::Int8())
            .set_dtype(1, dtype::Int8());
    auto time = bencher.execs(
                        {{N * C * H * W + 1}, {N * C * H * W + 1}, {}}) /
                nr_times;
    printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
           (3.0 * (N * C * H * W + 1)) / (time * 1e6));
    time = bencher.execs({{N, C / 4, H, W, 4}, {N, C / 4, H, W, 4}, {}}) /
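The bandwidth figure printed above follows from a simple traffic model: FUSE_ADD_RELU on int8 reads two one-byte tensors and writes one, so about 3 bytes move per element, and with time measured in milliseconds, bytes / (time * 1e6) gives GB/s (1 GB/s is 1e6 bytes per millisecond). A worked example with a hypothetical timing:

#include <cstdio>

int main() {
    // run_bench(256, 256, 56, 56) below: N * C * H * W + 1 elements per tensor.
    const double elems = 256.0 * 256.0 * 56.0 * 56.0 + 1.0;
    const double time_ms = 1.0;  // hypothetical measured time per run
    // two int8 inputs + one int8 output: ~3 bytes per element
    printf("bandwidth = %.2f GB/s\n", 3.0 * elems / (time_ms * 1e6));
    return 0;
}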
@@ -395,7 +391,6 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_IBYTE) {
                   nr_times;
    printf("time = %.2fms, bandwidth = %.2fGB/s\n", time,
           (C + 2.0 * N * C * H * W) / (time * 1e6));
    };
    run_bench(256, 256, 56, 56);
}
@@ -428,4 +423,3 @@ TEST_F(CUDA, BENCHMARK_ELEMWISE_MIN_MAX) {
#endif

// vim: syntax=cpp.doxygen