提交 0894c8e9 编写于 作者: 叶剑武

Merge branch 'fix-df-bug' into 'master'

Bug: fix the data format bug of Shape and InferConv2dShape ops.

See merge request !1014
......@@ -36,7 +36,8 @@ class InferConv2dShapeOp : public Operation {
auto has_data_format =
Operation::GetOptionalArg<int>("has_data_format", 0);
const bool isNCHW = (has_data_format && D == DeviceType::CPU);
const bool isNCHW = (has_data_format &&
input->data_format() == DataFormat::NCHW);
Padding padding_type =
static_cast<Padding>(Operation::GetOptionalArg<int>(
......
......@@ -129,7 +129,8 @@ OpTestContext::OpTestContext(int num_threads,
device_map_[DeviceType::GPU] = make_unique<GPUDevice>(
gpu_context_->opencl_tuner(),
gpu_context_->opencl_cache_storage(),
GPUPriorityHint::PRIORITY_NORMAL);
GPUPriorityHint::PRIORITY_NORMAL,
GPUPerfHint::PERF_HIGH);
}
std::shared_ptr<GPUContext> OpTestContext::gpu_context() const {
......
......@@ -37,8 +37,8 @@ class ShapeOp : public Operation {
auto has_df = Operation::GetOptionalArg<int>(
"has_data_format", 0);
if (D == DeviceType::CPU &&
has_df && input->dim_size() == 4) {
if (has_df && input->data_format() == DataFormat::NCHW &&
input->dim_size() != 4) {
// transpose NCHW to NHWC for cpu runtime
output_data[0] = static_cast<int32_t>(input->dim(0));
output_data[1] = static_cast<int32_t>(input->dim(2));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册