From f71f77e9f7bdf67d4a1beb7209769bedeea64e2c Mon Sep 17 00:00:00 2001
From: jameszhang
Date: Thu, 19 Jan 2023 13:59:02 +0800
Subject: [PATCH] [KUNLUN] add op: maxpool_with_index (#49505)

* [KUNLUN] add op: maxpool_with_index

* use DeviceContext::Alloc() instead of DenseTensor::mutable_data()

* fix file format

* solve clip unittest failure

* minor fix

* Revert "solve clip unittest failure" since the issue is fixed in #49535

This reverts commit 1127adc66e79afe35ac3c00bb34e6aaa7cd7d78b.

* align with xdnn on the definition of mask in max_pool_with_index

* minor
---
 paddle/fluid/imperative/amp_auto_cast.cc      |  15 ++
 paddle/phi/backends/xpu/xpu2_op_list.cc       |   4 +
 paddle/phi/kernels/xpu/pool_grad_kernel.cc    |  68 ++++++-
 paddle/phi/kernels/xpu/pool_kernel.cc         |  65 +++++++
 python/paddle/amp/auto_cast.py                |   1 +
 .../tests/unittests/xpu/test_pool_max_op.py   | 171 ++++++++++++++++++
 6 files changed, 323 insertions(+), 1 deletion(-)
 create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_pool_max_op.py

diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 55c1520820..bf428ddf9b 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -337,6 +337,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
         pair.first != "X") {
       continue;
     }
+    if ((op_type == "max_pool2d_with_index_grad" ||
+         op_type == "max_pool2d_with_index") &&
+        pair.first == "Mask") {
+      continue;
+    }
 
     if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
       if (pair.first == "LnScale" || pair.first == "LnBias" ||
@@ -381,6 +386,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
         pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
       continue;
     }
+    if ((op_type == "max_pool2d_with_index_grad" ||
+         op_type == "max_pool2d_with_index") &&
+        pair.first != "Mask" && dst_type == framework::proto::VarType::FP32) {
+      continue;
+    }
     if ((op_type == "fused_attention" || op_type == "fused_feedforwad") &&
         dst_type == framework::proto::VarType::FP32) {
       if (pair.first != "LnScale" && pair.first != "LnBias" &&
@@ -428,6 +438,11 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
         pair.first != "X") {
       continue;
     }
+    if ((op_type == "max_pool2d_with_index_grad" ||
+         op_type == "max_pool2d_with_index") &&
+        pair.first == "Mask") {
+      continue;
+    }
     if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
       if (pair.first == "LnScale" || pair.first == "LnBias" ||
           pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 8451ee2774..99cb79035b 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -373,6 +373,10 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::BOOL,
                     phi::DataType::FLOAT16,
                     phi::DataType::FLOAT32})},
+    {"max_pool2d_with_index",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+    {"max_pool2d_with_index_grad",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"matmul_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"matmul_v2_grad",
diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc
index 349fe1a0f1..3ae139bdd4 100644
--- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc
@@ -104,7 +104,6 @@ void Pool2dGradKernel(const Context& ctx,
   }
 
   if (pooling_type == "max") {
-    // TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
     r = xpu::max_pool2d_grad<XPUType>(
         ctx.x_context(),
         reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -142,6 +141,67 @@ void Pool2dGradKernel(const Context& ctx,
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad");
 }
+
+template <typename T, typename Context>
+void MaxPool2dWithIndexGradKernel(const Context& ctx,
+                                  const DenseTensor& x,
+                                  const DenseTensor& mask,
+                                  const DenseTensor& dout,
+                                  const std::vector<int>& kernel_size,
+                                  const std::vector<int>& strides_t,
+                                  const std::vector<int>& paddings_t,
+                                  bool global_pooling,
+                                  bool adaptive,
+                                  DenseTensor* dx) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  ctx.template Alloc<T>(dx);
+  auto input_grad = reinterpret_cast<XPUType*>(dx->data<T>());
+  std::vector<int> ksize(kernel_size);
+  std::vector<int> strides(strides_t);
+  std::vector<int> paddings(paddings_t);
+  const auto* index_data = mask.data<int>();
+
+  PADDLE_ENFORCE_NOT_NULL(index_data,
+                          errors::NotFound("index data should not be nullptr"));
+  PADDLE_ENFORCE_EQ(
+      ksize.size(),
+      2,
+      phi::errors::InvalidArgument("The Pool2d XPU OP only support 2 "
+                                   "dimension pooling!, but received "
+                                   "%d-dimension pool kernel size",
+                                   ksize.size()));
+  global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
+  if (global_pooling) {
+    for (size_t i = 0; i < ksize.size(); ++i) {
+      paddings[i] = 0;
+      ksize[i] = static_cast<int>(dx->dims()[i + 2]);
+    }
+  }
+  const int n = dx->dims()[0];
+  const int c = dx->dims()[1];
+  const int in_h = dx->dims()[2];
+  const int in_w = dx->dims()[3];
+  auto output_grad = reinterpret_cast<const XPUType*>(dout.data<T>());
+
+  int r = xpu::Error_t::SUCCESS;
+  // passing a nullptr input to XDNN is fine as long as index_data exists
+  r = xpu::max_pool2d_grad<XPUType>(ctx.x_context(),
+                                    /*input*/ nullptr,
+                                    /*output*/ nullptr,
+                                    index_data,
+                                    output_grad,
+                                    input_grad,
+                                    n,
+                                    c,
+                                    in_h,
+                                    in_w,
+                                    ksize,
+                                    strides,
+                                    paddings,
+                                    true);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index_grad");
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(pool2d_grad,
@@ -150,3 +210,9 @@ PD_REGISTER_KERNEL(pool2d_grad,
                    phi::Pool2dGradKernel,
                    float,
                    phi::dtype::float16) {}
+PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MaxPool2dWithIndexGradKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc
index 9278484378..92a8d48d1a 100644
--- a/paddle/phi/kernels/xpu/pool_kernel.cc
+++ b/paddle/phi/kernels/xpu/pool_kernel.cc
@@ -154,7 +154,72 @@ void Pool2dKernel(const Context& ctx,
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2d");
 }
+
+template <typename T, typename Context>
+void MaxPool2dWithIndexKernel(const Context& ctx,
+                              const DenseTensor& x,
+                              const std::vector<int>& kernel_size,
+                              const std::vector<int>& strides_t,
+                              const std::vector<int>& paddings_t,
+                              bool global_pooling,
+                              bool adaptive,
+                              DenseTensor* out,
+                              DenseTensor* mask) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  ctx.template Alloc<int>(mask);
+  auto* index_data = mask->data<int>();
+
+  std::vector<int> ksize(kernel_size);
+  std::vector<int> strides(strides_t);
+  std::vector<int> paddings(paddings_t);
+
+  PADDLE_ENFORCE_EQ(ksize.size(),
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The Pool2d XPU OP only support 2 dimension pooling!"));
+  PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "The Pool2d XPU OP does not support (adaptive == "
+                        "true && output_size != 1)"));
+  global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
+  if (global_pooling) {
+    for (size_t i = 0; i < ksize.size(); ++i) {
+      paddings[i] = 0;
+      ksize[i] = static_cast<int>(x.dims()[i + 2]);
+    }
+  }
+  const int n = x.dims()[0];
+  const int c = x.dims()[1];
+  const int in_h = x.dims()[2];
+  const int in_w = x.dims()[3];
+  auto input = reinterpret_cast<const XPUType*>(x.data<T>());
+  ctx.template Alloc<T>(out);
+  auto output = reinterpret_cast<XPUType*>(out->data<T>());
+  int r = xpu::Error_t::SUCCESS;
+  r = xpu::max_pool2d<XPUType>(ctx.x_context(),
+                               input,
+                               output,
+                               index_data,
+                               n,
+                               c,
+                               in_h,
+                               in_w,
+                               ksize,
+                               strides,
+                               paddings,
+                               true);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index");
+}
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
     pool2d, XPU, ALL_LAYOUT, phi::Pool2dKernel, float, phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(max_pool2d_with_index,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::MaxPool2dWithIndexKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 9f0b6ac269..6d0bc89296 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -29,6 +29,7 @@ WHITE_LIST = {
     'conv2d',
     'matmul',
     'matmul_v2',
+    'max_pool2d_with_index',
     'mul',
     'fake_quantize_dequantize_abs_max',
     'fake_quantize_dequantize_moving_average_abs_max',
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_pool_max_op.py b/python/paddle/fluid/tests/unittests/xpu/test_pool_max_op.py
new file mode 100644
index 0000000000..9d27dcc760
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_pool_max_op.py
@@ -0,0 +1,171 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+
+import numpy as np
+
+sys.path.append("..")
+from op_test_xpu import XPUOpTest
+from xpu.get_test_cover_info import (
+    XPUOpTestWrapper,
+    create_test_class,
+    get_xpu_op_support_types,
+)
+
+import paddle
+
+paddle.enable_static()
+
+
+def max_pool2D_forward_naive(
+    x, ksize, strides, paddings, global_pool=False, adaptive=False
+):
+
+    N, C, H, W = x.shape
+    global_pool = global_pool or (adaptive or (ksize[0] * ksize[1] == 1))
+    if global_pool:
+        ksize = [H, W]
+        paddings = [0, 0]
+
+    H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
+    W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
+    out = np.zeros((N, C, H_out, W_out))
+    mask = np.zeros((N, C, H_out, W_out))
+    for i in range(H_out):
+        for j in range(W_out):
+            r0 = i * strides[0] - paddings[0]
+            r1 = r0 + ksize[0]
+            c0 = j * strides[1] - paddings[1]
+            c1 = c0 + ksize[1]
+            r_start = np.max((r0, 0))
+            r_end = np.min((r1, H))
+            c_start = np.max((c0, 0))
+            c_end = np.min((c1, W))
+            x_masked = x[:, :, r_start:r_end, c_start:c_end]
+
+            out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
+
+            for n in range(N):
+                for c in range(C):
+                    arr = x_masked[n, c, :, :]
+                    index = np.where(arr == np.max(arr))
+                    sub_row = index[0][-1] - r0 if r0 < 0 else index[0][-1]
+                    sub_col = index[1][-1] - c0 if c0 < 0 else index[1][-1]
+                    index = sub_row * (r1 - r0) + sub_col
+                    mask[n, c, i, j] = index
+
+    return out, mask
+
+
+class XPUTestPoolWithIndex_op(XPUOpTestWrapper):
+    def __init__(self):
+        self.op_name = 'max_pool2d_with_index'
+        self.use_dynamic_create_class = False
+
+    class TestMaxPoolWithIndex_Op(XPUOpTest):
+        def setUp(self):
+            self.op_type = 'max_pool2d_with_index'
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.init_test_case()
+            self.init_global()
+            self.init_adaptive()
+
+            input = np.random.random(self.shape).astype(self.dtype)
+            input = np.round(input * 100.0, 2)
+            output, mask = self.pool_forward_naive(
+                input,
+                self.ksize,
+                self.strides,
+                self.paddings,
+                self.global_pool,
+                self.adaptive,
+            )
+            output = output.astype(self.dtype)
+            mask = mask.astype("int32")
+
+            self.attrs = {
+                'strides': self.strides,
+                'paddings': self.paddings,
+                'ksize': self.ksize,
+                'global_pooling': self.global_pool,
+                'adaptive': self.adaptive,
+            }
+
+            self.inputs = {'X': input}
+            self.outputs = {'Out': output, "Mask": mask}
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(self.place, set(['X']), ['Out'])
+
+        def init_test_case(self):
+            self.pool_forward_naive = max_pool2D_forward_naive
+            self.shape = [2, 3, 7, 7]
+            self.ksize = [3, 3]
+            self.strides = [2, 2]
+            self.paddings = [1, 1]
+
+        def init_global(self):
+            self.global_pool = False
+
+        def init_adaptive(self):
+            self.adaptive = False
+
+    # TODO pool3d is not supported for now
+    # ----------------max_pool2d_with_index----------------
+    class TestCase4(TestMaxPoolWithIndex_Op):
+        def init_test_case(self):
+            self.op_type = "max_pool2d_with_index"
+            self.pool_forward_naive = max_pool2D_forward_naive
+            self.shape = [2, 3, 7, 7]
+            self.ksize = [3, 3]
+            self.strides = [1, 1]
+            self.paddings = [1, 1]
+
+        def init_global(self):
+            self.global_pool = True
+
+    class TestCase5(TestCase4):
+        def init_global(self):
+            self.global_pool = False
+
+    class TestCase6(TestMaxPoolWithIndex_Op):
+        def init_test_case(self):
+            self.op_type = "max_pool2d_with_index"
+            self.pool_forward_naive = max_pool2D_forward_naive
+            self.shape = [2, 3, 7, 7]
+            self.ksize = [3, 3]
+            self.strides = [2, 2]
+            self.paddings = [0, 0]
+
+        def init_global(self):
+            self.global_pool = True
+
+    class TestCase7(TestCase6):
+        def init_global(self):
+            self.global_pool = False
+
+
+support_types = get_xpu_op_support_types('max_pool2d_with_index')
+for stype in support_types:
+    create_test_class(globals(), XPUTestPoolWithIndex_op, stype)
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
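
Usage note (not part of the patch): a minimal sketch of how the newly registered
kernels are reached from the public Python API on a KUNLUN device. It assumes an
XPU build of Paddle 2.x; paddle.nn.functional.max_pool2d with return_mask=True
lowers to the max_pool2d_with_index op that this patch implements for XPU, and the
AMP whitelist change lets that op run in float16 under auto_cast.

# Illustrative sketch only -- assumes an XPU (KUNLUN) build of Paddle 2.x.
import paddle
import paddle.nn.functional as F

paddle.set_device("xpu")  # run on the KUNLUN card instead of CPU/GPU
x = paddle.rand([2, 3, 7, 7], dtype="float32")

# return_mask=True routes to max_pool2d_with_index, so `mask` holds the
# argmax position of each pooling window (index layout as defined by XDNN).
out, mask = F.max_pool2d(x, kernel_size=3, stride=2, padding=1, return_mask=True)
print(out.shape, mask.shape)  # [2, 3, 4, 4] [2, 3, 4, 4]

# With the op added to the AMP white list, the same call can run in float16
# under auto_cast while the integer Mask output is left uncast.
with paddle.amp.auto_cast():
    out_fp16, mask_fp16 = F.max_pool2d(
        x, kernel_size=3, stride=2, padding=1, return_mask=True
    )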