Unverified commit f71f77e9 authored by jameszhang, committed by GitHub

[KUNLUN] add op: maxpool_with_index (#49505)

* [KUNLUN] add op: maxpool_with_index

* use DeviceContext::Alloc() instead of DenseTensor::mutable_data()

* fix file format

* solve clip unittest failure

* minor fix

* Revert "solve clip unittest failure" since the issue is fixed
in #49535

This reverts commit 1127adc66e79afe35ac3c00bb34e6aaa7cd7d78b.

* align with xdnn on the definition of mask in max_pool_with_index

* minor
Parent f65ca8ca
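For context, max_pool2d_with_index is the op behind the return_mask=True path of paddle.nn.functional.max_pool2d. A minimal usage sketch (assuming a Paddle build with XPU/KUNLUN support, so that paddle.set_device('xpu') is available):

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')  # assumes an XPU (KUNLUN) build of Paddle

x = paddle.rand([2, 3, 7, 7], dtype='float32')
# return_mask=True dispatches to max_pool2d_with_index and returns both
# the pooled values and the integer mask of argmax positions
out, mask = F.max_pool2d(x, kernel_size=3, stride=2, padding=1, return_mask=True)
print(out.shape, mask.shape)  # [2, 3, 4, 4] [2, 3, 4, 4]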
@@ -337,6 +337,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
pair.first != "X") {
continue;
}
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first == "Mask") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
@@ -381,6 +386,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first != "Mask" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward") &&
dst_type == framework::proto::VarType::FP32) {
if (pair.first != "LnScale" && pair.first != "LnBias" &&
@@ -428,6 +438,11 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
pair.first != "X") {
continue;
}
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first == "Mask") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
......
@@ -373,6 +373,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"max_pool2d_with_index",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"max_pool2d_with_index_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"matmul_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"matmul_v2_grad",
......
@@ -104,7 +104,6 @@ void Pool2dGradKernel(const Context& ctx,
}
if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
r = xpu::max_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -142,6 +141,67 @@
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad");
}
template <typename T, typename Context>
void MaxPool2dWithIndexGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& dout,
const std::vector<int>& kernel_size,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool global_pooling,
bool adaptive,
DenseTensor* dx) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<T>(dx);
auto input_grad = reinterpret_cast<XPUType*>(dx->data<T>());
std::vector<int> ksize(kernel_size);
std::vector<int> strides(strides_t);
std::vector<int> paddings(paddings_t);
const auto* index_data = mask.data<int>();
PADDLE_ENFORCE_NOT_NULL(index_data,
errors::NotFound("index data should not be nullptr"));
PADDLE_ENFORCE_EQ(
ksize.size(),
2,
phi::errors::InvalidArgument("The Pool2d XPU OP only support 2 "
"dimension pooling!, but received "
"%d-dimension pool kernel size",
ksize.size()));
global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(dx->dims()[i + 2]);
}
}
const int n = dx->dims()[0];
const int c = dx->dims()[1];
const int in_h = dx->dims()[2];
const int in_w = dx->dims()[3];
auto output_grad = reinterpret_cast<const XPUType*>(dout.data<T>());
int r = xpu::Error_t::SUCCESS;
// passing nullptr for the input/output tensors is fine for XDNN as long as index_data exists
r = xpu::max_pool2d_grad<XPUType>(ctx.x_context(),
/*input*/ nullptr,
/*output*/ nullptr,
index_data,
output_grad,
input_grad,
n,
c,
in_h,
in_w,
ksize,
strides,
paddings,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(pool2d_grad,
@@ -150,3 +210,9 @@ PD_REGISTER_KERNEL(pool2d_grad,
phi::Pool2dGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
XPU,
ALL_LAYOUT,
phi::MaxPool2dWithIndexGradKernel,
float,
phi::dtype::float16) {}
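Note that the grad kernel above passes nullptr for both input and output: XDNN reconstructs dx from the mask and dout alone. A NumPy sketch of the scatter this amounts to (an illustration only, assuming the window-relative int32 mask convention and the square windows used in the tests below):

import numpy as np

def max_pool2d_grad_naive(mask, dout, ksize, strides, paddings, in_shape):
    # Each dout[n, c, i, j] flows back to the argmax position recorded in
    # mask[n, c, i, j], a row-major index inside the k x k pooling window.
    dx = np.zeros(in_shape, dtype=dout.dtype)
    H_out, W_out = dout.shape[2:]
    for n in range(in_shape[0]):
        for c in range(in_shape[1]):
            for i in range(H_out):
                for j in range(W_out):
                    idx = int(mask[n, c, i, j])
                    # decode with the window height as row stride, matching
                    # the naive encoder in the test (square windows assumed)
                    r = i * strides[0] - paddings[0] + idx // ksize[0]
                    w = j * strides[1] - paddings[1] + idx % ksize[0]
                    dx[n, c, r, w] += dout[n, c, i, j]
    return dx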
@@ -154,7 +154,72 @@ void Pool2dKernel(const Context& ctx,
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2d");
}
template <typename T, typename Context>
void MaxPool2dWithIndexKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool global_pooling,
bool adaptive,
DenseTensor* out,
DenseTensor* mask) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<int>(mask);
auto* index_data = mask->data<int>();
std::vector<int> ksize(kernel_size);
std::vector<int> strides(strides_t);
std::vector<int> paddings(paddings_t);
PADDLE_ENFORCE_EQ(ksize.size(),
2,
phi::errors::InvalidArgument(
"The Pool2d XPU OP only supports 2-dimensional pooling!"));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
true,
phi::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(x.dims()[i + 2]);
}
}
const int n = x.dims()[0];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
auto input = reinterpret_cast<const XPUType*>(x.data<T>());
ctx.template Alloc<T>(out);
auto output = reinterpret_cast<XPUType*>(out->data<T>());
int r = xpu::Error_t::SUCCESS;
r = xpu::max_pool2d<XPUType>(ctx.x_context(),
input,
output,
index_data,
n,
c,
in_h,
in_w,
ksize,
strides,
paddings,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index");
}
} // namespace phi
PD_REGISTER_KERNEL(
pool2d, XPU, ALL_LAYOUT, phi::Pool2dKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index,
XPU,
ALL_LAYOUT,
phi::MaxPool2dWithIndexKernel,
float,
phi::dtype::float16) {}
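The commit message notes that the mask definition was aligned with xdnn: the index stored in Mask is relative to the pooling window rather than to the whole input plane. A toy NumPy comparison of the two conventions (an illustration only; the window-relative form is what the naive test reference below encodes):

import numpy as np

x = np.arange(16, dtype=np.float32).reshape(4, 4)
k, (r0, c0) = 2, (0, 0)          # 2x2 window anchored at the top-left corner
win = x[r0:r0 + k, c0:c0 + k]
r, w = np.unravel_index(np.argmax(win), win.shape)
plane_wide = (r0 + r) * x.shape[1] + (c0 + w)  # index into the whole H*W plane
window_rel = r * k + w                         # index inside the k*k window
print(plane_wide, window_rel)  # 5 3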
@@ -29,6 +29,7 @@ WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'max_pool2d_with_index',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
paddle.enable_static()
def max_pool2D_forward_naive(
x, ksize, strides, paddings, global_pool=False, adaptive=False
):
N, C, H, W = x.shape
global_pool = global_pool or (adaptive or (ksize[0] * ksize[1] == 1))
if global_pool:
ksize = [H, W]
paddings = [0, 0]
H_out = (H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
mask = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
r0 = i * strides[0] - paddings[0]
r1 = r0 + ksize[0]
c0 = j * strides[1] - paddings[1]
c1 = c0 + ksize[1]
r_start = np.max((r0, 0))
r_end = np.min((r1, H))
c_start = np.max((c0, 0))
c_end = np.min((c1, W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
for n in range(N):
for c in range(C):
arr = x_masked[n, c, :, :]
index = np.where(arr == np.max(arr))
sub_row = index[0][-1] - r0 if r0 < 0 else index[0][-1]
sub_col = index[1][-1] - c0 if c0 < 0 else index[1][-1]
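# encode the argmax as a window-relative, row-major index; the row
# stride (r1 - r0) is the window height, which equals the window
# width for the square kernels used in these tests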
index = sub_row * (r1 - r0) + sub_col
mask[n, c, i, j] = index
return out, mask
class XPUTestPoolWithIndex_op(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'max_pool2d_with_index'
self.use_dynamic_create_class = False
class TestMaxPoolWithIndex_Op(XPUOpTest):
def setUp(self):
self.op_type = 'max_pool2d_with_index'
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.init_test_case()
self.init_global()
self.init_adaptive()
input = np.random.random(self.shape).astype(self.dtype)
input = np.round(input * 100.0, 2)
output, mask = self.pool_forward_naive(
input,
self.ksize,
self.strides,
self.paddings,
self.global_pool,
self.adaptive,
)
output = output.astype(self.dtype)
mask = mask.astype("int32")
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'global_pooling': self.global_pool,
'adaptive': self.adaptive,
}
self.inputs = {'X': input}
self.outputs = {'Out': output, "Mask": mask}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, set(['X']), ['Out'])
def init_test_case(self):
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [1, 1]
def init_global(self):
self.global_pool = False
def init_adaptive(self):
self.adaptive = False
# TODO: pool3d is not supported for now
# ----------------max_pool2d_with_index----------------
class TestCase4(TestMaxPoolWithIndex_Op):
def init_test_case(self):
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
def init_global(self):
self.global_pool = True
class TestCase5(TestCase4):
def init_global(self):
self.global_pool = False
class TestCase6(TestMaxPoolWithIndex_Op):
def init_test_case(self):
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
def init_global(self):
self.global_pool = True
class TestCase7(TestCase6):
def init_global(self):
self.global_pool = False
support_types = get_xpu_op_support_types('max_pool2d_with_index')
for stype in support_types:
create_test_class(globals(), XPUTestPoolWithIndex_op, stype)
if __name__ == '__main__':
unittest.main()