未验证 提交 9a6926f5 编写于 作者: C cc 提交者: GitHub

[cherry-pick] Add mkldnn interpolate op, support manual enable mkldnn interpolate op (#30083)

上级 c06350c9
......@@ -108,6 +108,7 @@ if(WITH_MKLDNN)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/interpolate_mkldnn_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
class OpDesc;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class Graph;
void InterpolateMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
VLOG(3) << "Do not handle interpolate_mkldnn_pass";
return;
}
VLOG(4) << "Handle interpolate_mkldnn_pass";
Init("interpolate_mkldnn_pass", graph);
int found_count = 0;
const std::vector<std::string> interpolate_op_types = {
"bilinear_interp", "nearest_interp", "trilinear_interp", "bicubic_interp",
"linear_interp"};
for (const Node* node : graph->Nodes()) {
if (node->IsOp() &&
std::find(interpolate_op_types.begin(), interpolate_op_types.end(),
node->Name()) != interpolate_op_types.end()) {
auto* op_desc = node->Op();
op_desc->SetAttr("use_mkldnn", true);
++found_count;
}
}
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(interpolate_mkldnn_pass,
paddle::framework::ir::InterpolateMKLDNNPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Change the interpolate op to run MKLDNN.
*/
class Graph;
class InterpolateMKLDNNPass : public FusePassBase {
public:
virtual ~InterpolateMKLDNNPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/placement_pass_base.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
......@@ -33,7 +34,7 @@ void PlacementPassBase::ApplyImpl(ir::Graph* graph) const {
auto* op = n->Op();
if ((op->HasAttr(attr_name) || op->HasProtoAttr(attr_name)) &&
IsSupport(op->Type())) {
if (op_types_list.empty()) {
if (op_types_list.empty() && IsDefaultOpTypes(op->Type())) {
op->SetAttr(attr_name, true);
} else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
......@@ -59,7 +60,30 @@ bool PlacementPassBase::IsSupport(const std::string& op_type) const {
}
}
} else if (GetAttrName() == "use_mkldnn") {
// This ops have use_mkldnn attr, but not support for now.
const std::vector<std::string> op_types = {
"trilinear_interp", "bicubic_interp", "linear_interp"};
return std::find(op_types.begin(), op_types.end(), op_type) ==
op_types.end();
}
return false;
}
bool PlacementPassBase::IsDefaultOpTypes(const std::string& op_type) const {
if (GetAttrName() == "use_cudnn") {
return true;
} else if (GetAttrName() == "use_mkldnn") {
// For interpolate ops, there's a little difference between Paddle and
// MKLDNN.
// If run MKLDNN interpolate ops, manual set AnalysisConfig and apply
// the corresponding pass.
const std::vector<std::string> not_default_op_types = {
"bilinear_interp", "nearest_interp", "trilinear_interp",
"bicubic_interp", "linear_interp"};
bool is_interpolate_op =
std::find(not_default_op_types.begin(), not_default_op_types.end(),
op_type) != not_default_op_types.end();
return !is_interpolate_op;
}
return false;
}
......
......@@ -38,6 +38,7 @@ class PlacementPassBase : public Pass {
private:
bool IsSupport(const std::string& op_type) const;
bool IsDefaultOpTypes(const std::string& op_type) const;
#if PADDLE_WITH_TESTING
friend class PlacementPassTest;
......
......@@ -14,6 +14,9 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -302,7 +305,6 @@ class InterpolateOp : public framework::OperatorWithKernel {
platform::errors::Unimplemented(
"Input(X) dimension must be 3, 4 or 5, but got dimension = %d .",
dim_x.size()));
if (dim_x.size() == 3) {
// shape check for 1D interpolate for input tensor shape NCHW
Interpolate1DInferShapeCheck(ctx);
......@@ -318,13 +320,42 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_layout");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), dl);
}
}
#endif
if (var_name == "SizeTensor" || var_name == "Scale") {
return expected_kernel_type;
}
......@@ -394,6 +425,9 @@ class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
"can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
"can be \'1\' for src_idx = scale*dst_index .")
.SetDefault(1);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using framework::DataLayout;
using dnnl::memory;
using dnnl::primitive;
using dnnl::reorder;
using dnnl::stream;
using dnnl::resampling_forward;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;
template <typename T = float>
class InterpolateMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> {
public:
InterpolateMKLDNNHandler(const dnnl::algorithm algo,
const paddle::platform::MKLDNNDeviceContext& dev_ctx,
const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, Tensor* z,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::resampling_forward>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
const auto src_x_tz = framework::vectorize(x->dims());
const auto dst_tz = framework::vectorize(z->dims());
const auto src_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, algo, src_md, dst_md);
}
}
};
template <typename T = float>
class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int> ComputeOutputShape(
const framework::ExecutionContext& ctx) const {
const auto* x = ctx.Input<Tensor>("X");
auto in_dims = x->dims();
const bool is_channel_last = false; // In mkldnn kernel, always use NCHW
framework::DDim in_dhw_dims;
if (is_channel_last) { // NDHWC, NHWC, NWC
in_dhw_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else { // NCDHW, NCHW, NCW
in_dhw_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
}
std::vector<int> out_dims;
if (in_dhw_dims.size() == 1) {
out_dims.push_back(ctx.Attr<int>("out_w"));
} else if (in_dhw_dims.size() == 2) {
out_dims.push_back(ctx.Attr<int>("out_h"));
out_dims.push_back(ctx.Attr<int>("out_w"));
} else if (in_dhw_dims.size() == 3) {
out_dims.push_back(ctx.Attr<int>("out_d"));
out_dims.push_back(ctx.Attr<int>("out_h"));
out_dims.push_back(ctx.Attr<int>("out_w"));
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
auto out_size = ctx.Input<Tensor>("OutSize");
if (list_new_size_tensor.size() > 0) {
auto new_size = get_new_shape(list_new_size_tensor);
if (new_size.size() == out_dims.size()) {
out_dims = new_size;
}
} else if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor<int>(out_size);
if (out_size_data.size() == out_dims.size()) {
out_dims = out_size_data;
}
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
std::vector<int64_t> in_dhw_vec = framework::vectorize(in_dhw_dims);
std::transform(
in_dhw_vec.begin(), in_dhw_vec.end(), out_dims.begin(),
[&](int64_t i) -> int { return static_cast<int>(i * scale); });
}
}
PADDLE_ENFORCE_GT(std::all_of(out_dims.begin(), out_dims.end(),
[](int i) { return i > 0; }),
0, platform::errors::InvalidArgument(
"out_d, out_h, out_w of Op(interpolate) "
"should be greater than 0."));
out_dims.insert(out_dims.begin(), in_dims[0]);
if (is_channel_last) {
out_dims.push_back(in_dims[in_dims.size() - 1]);
} else {
out_dims.insert(out_dims.begin() + 1, in_dims[1]);
}
return out_dims;
}
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<Tensor>("X");
std::vector<float> scale_prior;
auto* z = ctx.Output<Tensor>("Out");
auto interp_method = ctx.Attr<std::string>("interp_method");
dnnl::algorithm algo = (interp_method == "nearest")
? dnnl::algorithm::resampling_nearest
: dnnl::algorithm::resampling_linear;
auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = framework::make_ddim(out_dims_vec);
z->mutable_data<T>(dim_out, ctx.GetPlace());
InterpolateMKLDNNHandler<T> handler(algo, dev_ctx, mkldnn_engine,
ctx.GetPlace(), x, z,
ctx.OutputName("Out"));
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(z);
auto resampling_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
mkldnn::stream astream(mkldnn_engine);
resampling_prim->execute(astream, args);
astream.wait();
z->set_layout(DataLayout::kMKLDNN);
z->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(nearest_interp, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import math
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
def bilinear_interp_mkldnn_np(input,
out_h,
out_w,
out_size=None,
actual_shape=None,
data_layout='NCHW'):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
batch_size, channel, in_h, in_w = input.shape
out = np.zeros((batch_size, channel, out_h, out_w))
for oh in range(out_h):
h0 = int(math.floor((oh + 0.5) * in_h / out_h - 0.5))
h1 = int(math.ceil((oh + 0.5) * in_h / out_h - 0.5))
h0 = max(h0, 0)
h1 = min(h1, in_h - 1)
Wh = (oh + 0.5) * in_h / out_h - 0.5 - h0
for ow in range(out_w):
w0 = int(math.floor((ow + 0.5) * in_w / out_w - 0.5))
w1 = int(math.ceil((ow + 0.5) * in_w / out_w - 0.5))
w0 = max(w0, 0)
w1 = min(w1, in_w - 1)
Ww = (ow + 0.5) * in_w / out_w - 0.5 - w0
input_h0_w0 = input[:, :, h0, w0]
input_h1_w0 = input[:, :, h1, w0]
input_h0_w1 = input[:, :, h0, w1]
input_h1_w1 = input[:, :, h1, w1]
out[:, :, oh, ow] = input_h0_w0 * (1 - Wh) * (
1 - Ww) + input_h1_w0 * Wh * (1 - Ww) + input_h0_w1 * (
1 - Wh) * Ww + input_h1_w1 * Wh * Ww
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(input.dtype)
@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
class TestBilinearInterpMKLDNNOp(OpTest):
def init_test_case(self):
pass
def setUp(self):
self.op_type = "bilinear_interp"
self.interp_method = 'bilinear'
self._cpu_only = True
self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
self.scale = 2.0
self.out_size = None
self.actual_shape = None
self.init_test_case()
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale > 0:
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_mkldnn_np(input_np, out_h, out_w,
self.out_size, self.actual_shape,
self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'interp_method': self.interp_method,
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'data_layout': self.data_layout,
'use_mkldnn': self.use_mkldnn
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [3, 2, 32, 16]
self.out_h = 27
self.out_w = 49
self.scale = 2.0
self.data_layout = 'NHWC'
class TestBilinearNeighborInterpMKLDNNCase2(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 1.
class TestBilinearNeighborInterpDataLayout(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [2, 4, 4, 5]
self.out_h = 6
self.out_w = 7
self.scale = 0.
self.data_layout = "NHWC"
class TestBilinearNeighborInterpCase3(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 128
self.scale = 0.
class TestBilinearNeighborInterpCase4(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
class TestBilinearNeighborInterpCase5(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([13, 13]).astype("int32")
class TestBilinearNeighborInterpCase6(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
def nearest_neighbor_interp_mkldnn_np(X,
out_h,
out_w,
out_size=None,
actual_shape=None,
data_layout='NCHW'):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
X = np.transpose(X, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
n, c, in_h, in_w = X.shape
fh = fw = 0.0
if (out_h > 1):
fh = out_h * 1.0 / in_h
if (out_w > 1):
fw = out_w * 1.0 / in_w
out = np.zeros((n, c, out_h, out_w))
for oh in range(out_h):
ih = int(round((oh + 0.5) / fh - 0.5))
for ow in range(out_w):
iw = int(round((ow + 0.5) / fw - 0.5))
out[:, :, oh, ow] = X[:, :, ih, iw]
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(X.dtype)
@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
class TestNearestInterpMKLDNNOp(OpTest):
def init_test_case(self):
pass
def setUp(self):
self.op_type = "nearest_interp"
self.interp_method = 'nearest'
self._cpu_only = True
self.use_mkldnn = True
self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW'
# priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1
self.out_w = 1
self.scale = 2.0
self.out_size = None
self.actual_shape = None
self.init_test_case()
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale > 0:
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = nearest_neighbor_interp_mkldnn_np(
input_np, out_h, out_w, self.out_size, self.actual_shape,
self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'interp_method': self.interp_method,
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'data_layout': self.data_layout,
'use_mkldnn': self.use_mkldnn
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output(check_dygraph=False)
class TestNearestInterpOpMKLDNNNHWC(TestNearestInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [3, 2, 32, 16]
self.out_h = 27
self.out_w = 49
self.scale = 2.0
self.data_layout = 'NHWC'
class TestNearestNeighborInterpMKLDNNCase2(TestNearestInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 1.
class TestNearestNeighborInterpCase3(TestNearestInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 128
self.scale = 0.
class TestNearestNeighborInterpCase4(TestNearestInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([65, 129]).astype("int32")
class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp):
def init_test_case(self):
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
if __name__ == "__main__":
unittest.main()
......@@ -596,6 +596,8 @@ STATIC_MODE_TESTING_LIST = [
'test_elementwise_mul_bf16_mkldnn_op',
'test_fc_mkldnn_op',
'test_fc_bf16_mkldnn_op',
'test_nearest_interp_mkldnn_op',
'test_bilinear_interp_mkldnn_op',
'test_fusion_gru_int8_mkldnn_op',
'test_fusion_gru_mkldnn_op',
'test_gaussian_random_mkldnn_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册