提交 115c4592 编写于 作者: M Megvii Engine Team

fix(dnn/opencl): fix opencl elemwise tuning issue

GitOrigin-RevId: 317640547d262cbfec90f79786c60a872253f0b8
上级 b9cbc101
......@@ -158,7 +158,7 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) {
opr->param() = m_param;
auto user_layouts = layouts;
m_proxy->deduce_layout(opr, layouts);
for (size_t i = 0; i < layouts.size(); ++i)
for (size_t i = 0; i < layouts.size(); ++i) {
if (user_layouts[i].ndim > 0) {
auto run = [&]() {
ASSERT_TRUE(layouts[i].eq_shape(user_layouts[i]))
......@@ -169,13 +169,14 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) {
};
run();
}
}
auto allocate = [&layouts](Handle* handle) {
TensorNDArray tensors(layouts.size());
auto trans_func = [handle](const TensorLayout& layout) {
auto span = layout.span();
TensorND res;
res.reset_ptr(
static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte())) +
static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte())) -
span.low_byte);
res.layout = layout;
return res;
......@@ -244,7 +245,9 @@ float BenchmarkerBase<Opr, T>::exec(TensorLayoutArray layouts) {
}
auto free = [](Handle* handle, TensorNDArray& tensors) {
std::for_each(tensors.begin(), tensors.end(), [handle](const TensorND& tensor) {
megdnn_free(handle, tensor.raw_ptr());
megdnn_free(
handle, static_cast<dt_byte*>(tensor.raw_ptr()) +
tensor.layout.span().low_byte);
});
};
free(m_handle, tensors_cur);
......@@ -283,7 +286,7 @@ float BenchmarkerBase<Opr, T>::exect(const TensorValueArray& testcase_in) {
auto span = layout.span();
TensorND res;
res.reset_ptr(
static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte())) +
static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte())) -
span.low_byte);
res.layout = layout;
return res;
......@@ -341,7 +344,9 @@ float BenchmarkerBase<Opr, T>::exect(const TensorValueArray& testcase_in) {
}
auto free = [](Handle* handle, TensorNDArray& tensors) {
std::for_each(tensors.begin(), tensors.end(), [handle](const TensorND& tensor) {
megdnn_free(handle, tensor.raw_ptr());
megdnn_free(
handle, static_cast<dt_byte*>(tensor.raw_ptr()) +
tensor.layout.span().low_byte);
});
};
free(m_handle, tensors_cur);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册