Unverified commit 3b9f040d authored by Qi Li, committed by GitHub

[NPU] add nearest_interp_v2 and nearest_interp_v2_grad, test=develop (#34769)

Parent: e4e8cc9b
@@ -58,6 +58,12 @@ inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
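// NPU memory is not directly host-accessible, so stage the tensor through a
// CPU copy before reading it, mirroring the device branch above.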
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
#endif
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/interpolate_v2_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename DeviceContext, typename T>
class InterpolateV2NPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
PADDLE_ENFORCE_EQ(input_dims.size(), 4UL,
platform::errors::External(
"NPU Interpolate Kernel only support 4-D Tensor."));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w);
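// ExtractNCDWH unpacks the batch, channel and spatial dims according to the
// requested data_layout (NCHW or NHWC).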
PADDLE_ENFORCE_EQ(
input->layout(), data_layout,
platform::errors::InvalidArgument(
"Interpolate OP's input tensor layout should equal to attr "
"data_layout, but got tensor layout <%s>, attr layout <%s>",
framework::DataLayoutToString(input->layout()), data_layout_str));
PADDLE_ENFORCE_EQ(
output->layout(), data_layout,
platform::errors::InvalidArgument(
"Interpolate OP's output tensor layout should equal to attr "
"data_layout, but got tensor layout <%s>, attr layout <%s>",
framework::DataLayoutToString(output->layout()), data_layout_str));
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
// To-do(qili93): need to support align_corners = true case, try ResizeD
PADDLE_ENFORCE_EQ(
align_corners, false,
platform::errors::InvalidArgument(
"NPU Interpolate Kernel has diff when align_corners is true."));
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1;
float scale_w = -1;
// Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_shape_tensor.size() > 0) {
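// SizeTensor supplies out_h and out_w as two separate 1-element tensors.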
std::vector<int32_t> output_h(1);
std::vector<int32_t> output_w(1);
auto dev_ctx =
platform::DeviceContextPool::Instance().Get(ctx.GetPlace());
framework::TensorToVector(*list_new_shape_tensor[0], *dev_ctx, &output_h);
framework::TensorToVector(*list_new_shape_tensor[1], *dev_ctx, &output_w);
out_h = output_h[0];
out_w = output_w[0];
} else if (ctx.HasInput("OutSize")) {
auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_h > 0. && scale_w > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
}
PADDLE_ENFORCE_GT(out_h, 0,
platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0,
platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
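// Same spatial size: interpolation degenerates to an identity copy.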
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner;
// To-do(qili93): need to support bilinear, try ResizeD
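// ResizeNearestNeighborV2 takes the target spatial size (out_h, out_w) as a
// second 1-D int32 input rather than as attributes.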
if ("nearest" == interp_method) {
runner.SetType("ResizeNearestNeighborV2")
.AddInput(*input)
.AddInput(std::vector<int32_t>{out_h, out_w})
.AddOutput(*output)
.AddAttr("align_corners", align_corners)
.AddAttr("half_pixel_centers", false);
}
runner.Run(stream);
}
};
template <typename DeviceContext, typename T>
class InterpolateV2NPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
PADDLE_ENFORCE_EQ(
input->layout(), data_layout,
platform::errors::InvalidArgument(
"Interpolate OP's input tensor layout should equal to attr "
"data_layout, but got tensor layout <%s>, attr layout <%s>",
framework::DataLayoutToString(input->layout()), data_layout_str));
PADDLE_ENFORCE_EQ(output_grad->layout(), data_layout,
platform::errors::InvalidArgument(
"Interpolate OP's output_grad tensor layout should "
"equal to attr data_layout, but got tensor layout is "
"<%s>, and attr layout is <%s>",
framework::DataLayoutToString(output_grad->layout()),
data_layout_str));
PADDLE_ENFORCE_EQ(input_grad->layout(), data_layout,
platform::errors::InvalidArgument(
"Interpolate OP's input_grad tensor layout should "
"equal to attr data_layout, but got tensor layout is "
"<%s>, and attr layout is <%s>",
framework::DataLayoutToString(input_grad->layout()),
data_layout_str));
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
// To-do(qili93): need to support align_corners = true case, try ResizeD
PADDLE_ENFORCE_EQ(
align_corners, false,
platform::errors::InvalidArgument(
"NPU Interpolate Kernel has diff when align_corners is true."));
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale_h = -1;
float scale_w = -1;
// Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
std::vector<int32_t> output_h(1);
std::vector<int32_t> output_w(1);
auto dev_ctx =
platform::DeviceContextPool::Instance().Get(ctx.GetPlace());
framework::TensorToVector(*list_new_size_tensor[0], *dev_ctx, &output_h);
framework::TensorToVector(*list_new_size_tensor[1], *dev_ctx, &output_w);
out_h = output_h[0];
out_w = output_w[0];
} else if (ctx.HasInput("OutSize")) {
auto out_size = ctx.Input<Tensor>("OutSize");
auto out_size_data = get_new_data_from_tensor<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
} else {
auto scale_tensor = ctx.Input<Tensor>("Scale");
auto scale = ctx.Attr<std::vector<float>>("scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
if (scale_data.size() > 1) {
scale_h = scale_data[0];
scale_w = scale_data[1];
} else {
scale_w = scale_data[0];
scale_h = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in input 'Scale' Tensor of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
} else {
if (scale.size() > 1) {
scale_h = scale[0];
scale_w = scale[1];
PADDLE_ENFORCE_EQ(
scale_w > 0, true,
platform::errors::InvalidArgument(
"The scale_w in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_w));
PADDLE_ENFORCE_EQ(
scale_h > 0, true,
platform::errors::InvalidArgument(
"The scale_h in Attr(scale) of Operator(interpolate) "
"should be greater than 0, but received value is %d.",
scale_h));
}
}
if (scale_h > 0. && scale_w > 0.) {
out_h = static_cast<int>(in_h * scale_h);
out_w = static_cast<int>(in_w * scale_w);
}
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
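// Same spatial size: gradients pass straight through as a copy.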
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
NpuOpRunner runner;
// To-do(qili93): need to support bilinear, try ResizeGradD
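// The grad op resizes the output gradient back to the input spatial size
// (in_h, in_w), again passed as a 1-D int32 input.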
if ("nearest" == interp_method) {
runner.SetType("ResizeNearestNeighborV2Grad")
.AddInput(*output_grad)
.AddInput(std::vector<int32_t>{in_h, in_w})
.AddOutput(*input_grad)
.AddAttr("align_corners", align_corners)
.AddAttr("half_pixel_centers", false);
}
runner.Run(stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
nearest_interp_v2,
ops::InterpolateV2NPUKernel<plat::NPUDeviceContext, float>,
ops::InterpolateV2NPUKernel<plat::NPUDeviceContext, plat::float16>);
REGISTER_OP_NPU_KERNEL(
nearest_interp_v2_grad,
ops::InterpolateV2NPUGradKernel<plat::NPUDeviceContext, float>,
ops::InterpolateV2NPUGradKernel<plat::NPUDeviceContext, plat::float16>);
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn as nn
import paddle
from paddle.nn.functional import interpolate
from test_nearest_interp_v2_op import nearest_neighbor_interp_np
paddle.enable_static()
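# The cases below mirror the generic nearest_interp_v2 tests (using the
# imported nearest_neighbor_interp_np reference) but run on NPUPlace(0).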
class TestNearestInterpOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def setUp(self):
self.set_npu()
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "nearest_interp_v2"
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
scale_h = 0
scale_w = 0
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
output_h = int(in_h * scale_h)
output_w = int(in_w * scale_w)
else:
output_h = self.out_h
output_w = self.out_w
output_np = nearest_neighbor_interp_np(
input_np, output_h, output_w, scale_h, scale_w, self.out_size,
self.actual_shape, self.align_corners, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'data_layout': self.data_layout
}
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X'], 'Out', in_place=True, max_relative_error=0.006)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 4, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = False
class TestNearestNeighborInterpCase1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = False
class TestNearestNeighborInterpCase2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.align_corners = False
class TestNearestNeighborInterpCase3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = False
class TestNearestNeighborInterpCase4(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = False
class TestNearestNeighborInterpCase5(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = False
class TestNearestNeighborInterpCase6(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([65, 129]).astype("int32")
self.align_corners = False
class TestNearestNeighborInterpSame(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.align_corners = False
class TestNearestNeighborInterpActualShape(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = False
class TestNearestNeighborInterpScale1(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 7, 5]
self.out_h = 64
self.out_w = 32
self.scale = 2.
self.out_size = None
self.align_corners = False
class TestNearestNeighborInterpScale2(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 5, 7]
self.out_h = 64
self.out_w = 32
self.scale = 1.5
self.out_size = None
self.align_corners = False
class TestNearestNeighborInterpScale3(TestNearestInterpOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 7, 5]
self.out_h = 64
self.out_w = 32
self.scale = [2.0, 3.0]
self.out_size = None
self.align_corners = False
class TestNearestInterpOp_attr_tensor(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def setUp(self):
self.set_npu()
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "nearest_interp_v2"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
}
input_np = np.random.random(self.input_shape).astype("float32")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
scale_h = scale_w = float(self.scale)
if isinstance(self.scale, list) and len(self.scale) == 1:
scale_w = scale_h = self.scale[0]
elif isinstance(self.scale, list) and len(self.scale) > 1:
scale_w = self.scale[1]
scale_h = self.scale[0]
out_h = int(self.input_shape[2] * scale_h)
out_w = int(self.input_shape[3] * scale_w)
else:
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
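# Feed each target dim as a separate 1-element int32 tensor
# ("x0", "x1", ...) via the SizeTensor input.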
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
if self.scale:
if isinstance(self.scale, float) or isinstance(self.scale, int):
if self.scale > 0:
self.scale = [self.scale]
if isinstance(self.scale, list) and len(self.scale) == 1:
self.scale = [self.scale[0], self.scale[0]]
self.attrs['scale'] = self.scale
output_np = nearest_neighbor_interp_np(input_np, out_h, out_w, 0, 0,
self.out_size, self.actual_shape,
self.align_corners)
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 5, 4, 4]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [3, 3]
self.align_corners = False
# out_size is a tensor list
class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = [8, 12]
self.align_corners = False
# out_size is a 1-D tensor
class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = False
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.0
self.out_size = None
self.align_corners = False
self.scale_by_1Dtensor = True
class TestNearestInterpOpAPI_dy(unittest.TestCase):
def test_case(self):
import paddle
if core.is_compiled_with_npu():
place = core.NPUPlace(0)
else:
place = core.CPUPlace()
with fluid.dygraph.guard(place):
input_data = np.random.random((2, 3, 6, 6)).astype("float32")
scale_np = np.array([2, 2]).astype("int64")
input_x = paddle.to_tensor(input_data)
scale = paddle.to_tensor(scale_np)
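# A scale_factor of 2 on a 6x6 input yields a 12x12 output, matching
# the reference result computed next.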
expect_res = nearest_neighbor_interp_np(
input_data, out_h=12, out_w=12, align_corners=False)
out = interpolate(
x=input_x,
scale_factor=scale,
mode="nearest",
align_corners=False)
self.assertTrue(np.allclose(out.numpy(), expect_res))
if __name__ == "__main__":
unittest.main()