Unverified commit f138371c authored by fwenguang, committed by GitHub

[MLU] support adaptive pooling (#39500)

Parent a7d4ddc4
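
For context: this commit wires cnnlAdaptivePoolingForward/cnnlAdaptivePoolingBackward into the MLU pool2d kernels and adds an adaptive branch to the reference gradient used by the Python tests. Adaptive pooling derives each window from the output index instead of a fixed kernel and stride: output bin i over an input extent of size in_size conventionally covers [floor(i * in_size / out_size), ceil((i + 1) * in_size / out_size)), which is the split the adaptive_start_index/adaptive_end_index helpers imported by the test are expected to implement. A minimal NumPy sketch of that partitioning (illustrative only, not MLU code; the helper name is hypothetical):

```python
import numpy as np

def adaptive_window(i, in_size, out_size):
    """Bounds of the i-th adaptive pooling window (conventional floor/ceil split)."""
    start = int(np.floor(i * in_size / out_size))
    end = int(np.ceil((i + 1) * in_size / out_size))
    return start, end

# Example: adaptively average-pool a 1-D signal of length 7 down to 3 bins.
x = np.arange(7, dtype=np.float32)
out = np.array([x[s:e].mean()
                for s, e in (adaptive_window(i, 7, 3) for i in range(3))])
print(out)  # windows [0:3), [2:5), [4:7) -> [1.0, 3.0, 5.0]
```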
......@@ -1151,6 +1151,18 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
output_desc, output, workspace_ptr, workspace_size));
}
/* static */ void MLUCnnl::AdaptivePoolingForward(
const ExecutionContext& ctx, cnnlPoolingMode_t pool_mode,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t output_desc, void* output,
const cnnlTensorDescriptor_t index_desc, void* index) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlAdaptivePoolingForward(handle, input_desc, input, pool_mode,
output_desc, output, index_desc, index));
}
/* static */ void MLUCnnl::Pool3D(
const ExecutionContext& ctx, cnnlPoolingMode_t pool_mode,
const std::vector<int64_t>& output_shape,
......@@ -1802,6 +1814,17 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
y, diff_y_desc, diff_y, x_desc, x, beta, diff_x_desc, diff_x));
}
/* static */ void MLUCnnl::AdaptivePoolingBackward(
const ExecutionContext& ctx, const cnnlPoolingMode_t pool_mode,
const cnnlTensorDescriptor_t y_desc, const void* y,
const cnnlTensorDescriptor_t index_desc, const void* index,
const cnnlTensorDescriptor_t diff_x_desc, void* diff_x) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlAdaptivePoolingBackward(
handle, y_desc, y, index_desc, index, pool_mode, diff_x_desc, diff_x));
}
/* static */ void MLUCnnl::NonMaxSuppression(
const ExecutionContext& ctx, const cnnlNmsDescriptor_t nms_desc,
const cnnlTensorDescriptor_t boxes_desc, const void* boxes,
......
......@@ -649,6 +649,12 @@ class MLUCnnl {
const void* input, const void* beta, const void* extra_input_ptr,
const cnnlTensorDescriptor_t output_desc, void* output);
static void AdaptivePoolingForward(
const ExecutionContext& ctx, cnnlPoolingMode_t pool_mode,
const cnnlTensorDescriptor_t input_desc, const void* input,
const cnnlTensorDescriptor_t output_desc, void* output,
const cnnlTensorDescriptor_t index_desc, void* index);
static void Pool3D(const ExecutionContext& ctx, cnnlPoolingMode_t pool_mode,
const std::vector<int64_t>& output_shape,
cnnlPoolingDescriptor_t pooling_desc, const void* alpha,
......@@ -958,6 +964,12 @@ class MLUCnnl {
const cnnlTensorDescriptor_t x_desc, const void* x, const void* beta,
const cnnlTensorDescriptor_t diff_x_desc, void* diff_x);
static void AdaptivePoolingBackward(
const ExecutionContext& ctx, const cnnlPoolingMode_t pool_mode,
const cnnlTensorDescriptor_t y_desc, const void* y,
const cnnlTensorDescriptor_t index_desc, const void* index,
const cnnlTensorDescriptor_t diff_x_desc, void* diff_x);
static void PoolingIndex(const ExecutionContext& ctx,
const cnnlPoolingDescriptor_t pooling_desc,
const cnnlTensorDescriptor_t x_desc, const void* x,
......
......@@ -21,12 +21,12 @@ namespace operators {
namespace {
cnnlPoolingMode_t ToCnnlPoolingMode(const std::string &pooling_type,
bool exclusive) {
bool exclusive, bool adaptive) {
cnnlPoolingMode_t pooling_mode;
if (pooling_type == "max") {
pooling_mode = CNNL_POOLING_MAX;
} else if (pooling_type == "avg") {
if (exclusive) {
if (exclusive && !adaptive) {
pooling_mode = CNNL_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
} else {
pooling_mode = CNNL_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
......@@ -64,10 +64,7 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument(
"Only support 4-dims for mlu pool2d kernel."));
PADDLE_ENFORCE_EQ(adaptive, false,
platform::errors::InvalidArgument(
"Not support adaptive for mlu pool2d kernel."));
const bool channel_last = data_format == "NHWC";
// default
cnnlTensorLayout_t cnnl_layout = CNNL_LAYOUT_NCHW;
auto out_dims = out->dims();
......@@ -77,7 +74,6 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
framework::DDim data_dims =
framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
const bool channel_last = data_format == "NHWC";
if (channel_last) {
cnnl_layout = CNNL_LAYOUT_NHWC;
out_h = out_dims[1];
......@@ -94,42 +90,74 @@ class MLUPoolOpKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc in_x_desc(*in_x, cnnl_layout, ToCnnlDataType<T>());
MLUCnnlTensorDesc out_desc(*out, cnnl_layout, ToCnnlDataType<T>());
cnnlPoolingMode_t pool_mode = ToCnnlPoolingMode(pooling_type, exclusive);
MLUCnnlPoolingDesc pool_desc(
pool_mode, CNNL_NOT_PROPAGATE_NAN, ksize[0], ksize[1], paddings[0],
paddings[1], paddings[2], paddings[3], strides[0], strides[1],
1 /*row_dilation*/, 1 /*col_dilation*/, ceil_mode);
cnnlPoolingMode_t pool_mode =
ToCnnlPoolingMode(pooling_type, exclusive, adaptive);
if (!adaptive) {
MLUCnnlPoolingDesc pool_desc(
pool_mode, CNNL_NOT_PROPAGATE_NAN, ksize[0], ksize[1], paddings[0],
paddings[1], paddings[2], paddings[3], strides[0], strides[1],
1 /*row_dilation*/, 1 /*col_dilation*/, ceil_mode);
size_t extra_input_size = 0;
cnnlHandle_t handle =
ctx.template device_context<MLUDeviceContext>().cnnl_handle();
cnnlGetPoolingExtraInputSize(handle, pool_mode, out_w, out_h,
&extra_input_size);
size_t extra_input_size = 0;
cnnlHandle_t handle =
ctx.template device_context<MLUDeviceContext>().cnnl_handle();
cnnlGetPoolingExtraInputSize(handle, pool_mode, out_w, out_h,
&extra_input_size);
if (extra_input_size > 0) {
paddle::platform::CPUDeviceContext cpu_ctx;
framework::Tensor extra_host_tensor =
ctx.AllocateTmpTensor<int8_t, platform::CPUDeviceContext>(
{static_cast<int64_t>(extra_input_size)}, cpu_ctx);
cnnlInitPoolingExtraInput(handle, pool_desc.get(), in_x_desc.get(),
out_desc.get(), GetBasePtr(&extra_host_tensor));
framework::Tensor extra_device_tensor =
ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(extra_input_size)}, dev_ctx);
// TODO(fwg): use Async copy, and add a callback to stream that frees host
// memory.
framework::TensorCopySync(extra_host_tensor, ctx.GetPlace(),
&extra_device_tensor);
MLUCnnl::PoolingForward(
ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
GetBasePtr(&extra_device_tensor) /*params_shape_ptr*/, out_desc.get(),
GetBasePtr(out));
if (extra_input_size > 0) {
paddle::platform::CPUDeviceContext cpu_ctx;
framework::Tensor extra_host_tensor =
ctx.AllocateTmpTensor<int8_t, platform::CPUDeviceContext>(
{static_cast<int64_t>(extra_input_size)}, cpu_ctx);
cnnlInitPoolingExtraInput(handle, pool_desc.get(), in_x_desc.get(),
out_desc.get(),
GetBasePtr(&extra_host_tensor));
framework::Tensor extra_device_tensor =
ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(extra_input_size)}, dev_ctx);
// TODO(fwg): use Async copy, and add a callback to stream that frees host
// memory.
framework::TensorCopySync(extra_host_tensor, ctx.GetPlace(),
&extra_device_tensor);
MLUCnnl::PoolingForward(
ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
GetBasePtr(&extra_device_tensor) /*params_shape_ptr*/,
out_desc.get(), GetBasePtr(out));
} else {
MLUCnnl::PoolingForward(
ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
nullptr /*params_shape_ptr*/, out_desc.get(), GetBasePtr(out));
}
} else {
MLUCnnl::PoolingForward(
ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
nullptr /*params_shape_ptr*/, out_desc.get(), GetBasePtr(out));
// cnnl adaptive pooling only supports the NHWC layout
framework::Tensor trans_in_x;
framework::Tensor trans_out;
if (channel_last) {
trans_in_x = *in_x;
trans_out = *out;
} else {
std::vector<int> perm{0, 2, 3, 1};
TransposeFromMLUTensor<T>(ctx, perm, in_x, &trans_in_x,
true /*need_reshape_or_alloc*/);
trans_out = ctx.AllocateTmpTensor<T, MLUDeviceContext>(
{out_dims[0], out_dims[2], out_dims[3], out_dims[1]}, dev_ctx);
}
MLUCnnlTensorDesc trans_in_x_desc(trans_in_x, CNNL_LAYOUT_NHWC,
ToCnnlDataType<T>());
MLUCnnlTensorDesc trans_out_desc(trans_out, CNNL_LAYOUT_NHWC,
ToCnnlDataType<T>());
MLUCnnl::AdaptivePoolingForward(
ctx, pool_mode, trans_in_x_desc.get(), GetBasePtr(&trans_in_x),
trans_out_desc.get(), GetBasePtr(&trans_out), nullptr, nullptr);
if (!channel_last) {
std::vector<int> perm{0, 3, 1, 2};
TransposeFromMLUTensor<T>(ctx, perm, &trans_out, out,
false /*need_reshape_or_alloc*/);
}
}
}
};
......@@ -204,7 +232,8 @@ class MLUPoolGradOpKernel : public framework::OpKernel<T> {
MLUCnnlTensorDesc trans_in_x_grad_desc(trans_in_x_grad, CNNL_LAYOUT_NHWC,
ToCnnlDataType<T>());
cnnlPoolingMode_t pool_mode = ToCnnlPoolingMode(pooling_type, exclusive);
cnnlPoolingMode_t pool_mode =
ToCnnlPoolingMode(pooling_type, exclusive, adaptive);
MLUCnnlPoolingDesc pool_desc(
pool_mode, CNNL_NOT_PROPAGATE_NAN, ksize[0], ksize[1], paddings[0],
paddings[1], paddings[2], paddings[3], strides[0], strides[1],
......@@ -219,18 +248,34 @@ class MLUPoolGradOpKernel : public framework::OpKernel<T> {
MLUCnnl::PoolingIndex(ctx, pool_desc.get(), trans_in_x_desc.get(),
GetBasePtr(&trans_in_x), index_tensor_desc.get(),
GetBasePtr(&index_tensor));
MLUCnnl::PoolingBackward(
ctx, pool_desc.get(), nullptr /*alpha*/, index_tensor_desc.get(),
GetBasePtr(&index_tensor), trans_out_grad_desc.get(),
GetBasePtr(&trans_out_grad), trans_in_x_desc.get(),
GetBasePtr(&trans_in_x), nullptr /*beta*/, trans_in_x_grad_desc.get(),
GetBasePtr(&trans_in_x_grad));
if (adaptive) {
MLUCnnl::AdaptivePoolingBackward(
ctx, pool_mode, trans_out_grad_desc.get(),
GetBasePtr(&trans_out_grad), index_tensor_desc.get(),
GetBasePtr(&index_tensor), trans_in_x_grad_desc.get(),
GetBasePtr(&trans_in_x_grad));
} else {
MLUCnnl::PoolingBackward(
ctx, pool_desc.get(), nullptr /*alpha*/, index_tensor_desc.get(),
GetBasePtr(&index_tensor), trans_out_grad_desc.get(),
GetBasePtr(&trans_out_grad), trans_in_x_desc.get(),
GetBasePtr(&trans_in_x), nullptr /*beta*/,
trans_in_x_grad_desc.get(), GetBasePtr(&trans_in_x_grad));
}
} else {
MLUCnnl::PoolingBackward(ctx, pool_desc.get(), nullptr /*alpha*/, nullptr,
nullptr, trans_out_grad_desc.get(),
GetBasePtr(&trans_out_grad), nullptr, nullptr,
nullptr /*beta*/, trans_in_x_grad_desc.get(),
GetBasePtr(&trans_in_x_grad));
if (adaptive) {
MLUCnnl::AdaptivePoolingBackward(
ctx, pool_mode, trans_out_grad_desc.get(),
GetBasePtr(&trans_out_grad), nullptr /*index_tensor_desc.get()*/,
nullptr /*GetBasePtr(&index_tensor)*/, trans_in_x_grad_desc.get(),
GetBasePtr(&trans_in_x_grad));
} else {
MLUCnnl::PoolingBackward(ctx, pool_desc.get(), nullptr /*alpha*/,
nullptr, nullptr, trans_out_grad_desc.get(),
GetBasePtr(&trans_out_grad), nullptr, nullptr,
nullptr /*beta*/, trans_in_x_grad_desc.get(),
GetBasePtr(&trans_in_x_grad));
}
}
if (!channel_last) {
std::vector<int> perm{0, 3, 1, 2};
......
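
In the kernel above, the adaptive branch only feeds NHWC tensors to cnnlAdaptivePoolingForward: NCHW inputs are transposed with perm {0, 2, 3, 1} first, and the pooled result is transposed back with {0, 3, 1, 2}. A NumPy sketch of the equivalent computation on an NHWC array (a description of the semantics only, not of the cnnl implementation):

```python
import numpy as np

def adaptive_avg_pool2d_nhwc(x, out_hw):
    """Reference adaptive average pooling over an NHWC array (illustrative only)."""
    n, h, w, c = x.shape
    oh, ow = out_hw
    y = np.zeros((n, oh, ow, c), dtype=x.dtype)
    for i in range(oh):
        hs, he = int(np.floor(i * h / oh)), int(np.ceil((i + 1) * h / oh))
        for j in range(ow):
            ws, we = int(np.floor(j * w / ow)), int(np.ceil((j + 1) * w / ow))
            y[:, i, j, :] = x[:, hs:he, ws:we, :].mean(axis=(1, 2))
    return y

# NCHW callers transpose to NHWC first and back afterwards, mirroring the
# perm {0, 2, 3, 1} / {0, 3, 1, 2} transposes in the kernel above.
x_nchw = np.random.rand(2, 3, 7, 7).astype(np.float32)
y_nhwc = adaptive_avg_pool2d_nhwc(x_nchw.transpose(0, 2, 3, 1), (3, 3))
y_nchw = y_nhwc.transpose(0, 3, 1, 2)
```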
......@@ -25,7 +25,125 @@ from paddle.fluid import Program, program_guard
import sys
sys.path.append('..')
from op_test import OpTest
from test_pool2d_op import pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive
from test_pool2d_op import pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive, adaptive_start_index, adaptive_end_index
def pool2d_backward_navie(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
data_format='NCHW',
pool_type="max",
padding_algorithm="EXPLICIT"):
# update paddings
def _get_padding_with_SAME(input_shape, pool_size, pool_stride):
padding = []
for input_size, filter_size, stride_size in zip(input_shape, pool_size,
pool_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max((
(out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
if isinstance(padding_algorithm, str):
padding_algorithm = padding_algorithm.upper()
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." %
str(padding_algorithm))
if padding_algorithm == "VALID":
paddings = [0, 0, 0, 0]
if ceil_mode != False:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", Attr(ceil_mode)"
" must be False. "
"Received ceil_mode: True.")
elif padding_algorithm == "SAME":
input_data_shape = []
if data_format == "NCHW":
input_data_shape = x.shape[2:4]
elif data_format == "NHWC":
input_data_shape = x.shape[1:3]
paddings = _get_padding_with_SAME(input_data_shape, ksize, strides)
assert len(paddings) == 2 or len(paddings) == 4
is_sys = True if len(paddings) == 2 else False
if data_format == "NHWC":
x = x.transpose([0, 3, 1, 2])
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
paddings = [0 for _ in range(len(paddings))]
pad_h_up = paddings[0] if is_sys else paddings[0]
pad_h_down = paddings[0] if is_sys else paddings[1]
pad_w_left = paddings[1] if is_sys else paddings[2]
pad_w_right = paddings[1] if is_sys else paddings[3]
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + pad_h_up + pad_h_down + strides[0] - 1) // strides[0] + 1 \
if ceil_mode else (H - ksize[0] + pad_h_up + pad_h_down) // strides[0] + 1
W_out = (W - ksize[1] + pad_w_left + pad_w_right + strides[1] - 1) // strides[1] + 1 \
if ceil_mode else (W - ksize[1] + pad_w_left + pad_w_right) // strides[1] + 1
x_grad = np.zeros_like(x)
for i in range(H_out):
if adaptive:
in_h_start = adaptive_start_index(i, H, ksize[0])
in_h_end = adaptive_end_index(i, H, ksize[0])
else:
in_h_start = np.max((i * strides[0] - pad_h_up, 0))
in_h_end = np.min((i * strides[0] + ksize[0] - pad_h_up, H))
for j in range(W_out):
if adaptive:
in_w_start = adaptive_start_index(j, W, ksize[1])
in_w_end = adaptive_end_index(j, W, ksize[1])
else:
in_h_start = i * strides[0] - pad_h_up
in_w_start = j * strides[1] - pad_w_left
in_h_end = i * strides[0] + ksize[0] - pad_h_up
in_w_end = j * strides[1] + ksize[1] - pad_w_left
field_size = (in_h_end - in_h_start) * (in_w_end - in_w_start)
in_h_start = np.max((in_h_start, 0))
in_w_start = np.max((in_w_start, 0))
in_h_end = np.min((in_h_end, H))
in_w_end = np.min((in_w_end, W))
if pool_type == 'avg':
if (exclusive or adaptive):
field_size = (in_h_end - in_h_start) * (
in_w_end - in_w_start)
x_grad[:, :, in_h_start:in_h_end, in_w_start:
in_w_end] += 1 / field_size
elif pool_type == 'max':
for n in range(N):
for c in range(C):
idx = np.argmax(x[n, c, in_h_start:in_h_end, in_w_start:
in_w_end].flatten())
idx_h = idx // (in_w_end - in_w_start)
idx_w = idx % (in_w_end - in_w_start)
x_grad[n, c, in_h_start + idx_h, in_w_start +
idx_w] += 1
if data_format == "NHWC":
x_grad = x_grad.transpose([0, 2, 3, 1])
return x_grad
class TestPool2D_Op_Mixin(object):
......@@ -71,12 +189,25 @@ class TestPool2D_Op_Mixin(object):
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16:
return
if self.pool_type != "max":
self.check_grad_with_place(
self.place, set(['X']), 'Out', max_relative_error=0.07)
x_grad = pool2d_backward_navie(
self.inputs["X"],
ksize=self.ksize,
strides=self.strides,
paddings=self.paddings,
global_pool=self.global_pool,
ceil_mode=False,
exclusive=self.exclusive,
adaptive=self.adaptive,
data_format=self.data_format,
pool_type=self.pool_type,
padding_algorithm=self.padding_algorithm)
x_grad = x_grad / np.prod(self.outputs['Out'].shape)
self.check_grad_with_place(
self.place,
set(['X']),
'Out',
max_relative_error=0.06,
user_defined_grads=[x_grad])
def init_data_format(self):
self.data_format = "NCHW"
......@@ -108,7 +239,6 @@ class TestPool2D_Op_Mixin(object):
def init_exclusive(self):
self.exclusive = True
# Not support adaptive pooling currently
def init_adaptive(self):
self.adaptive = False
......@@ -173,7 +303,7 @@ class TestCase5(TestCase2):
self.pool2D_forward_naive = max_pool2D_forward_naive
def create_test_fp16_class(parent, check_grad=True):
def create_test_fp16_class(parent):
class TestFp16Case(parent):
def init_data_type(self):
self.dtype = np.float16
......@@ -182,19 +312,13 @@ def create_test_fp16_class(parent, check_grad=True):
place = core.MLUPlace(0)
self.check_output_with_place(place, atol=1e-3)
def test_check_grad(self):
place = core.MLUPlace(0)
if self.pool_type != "max" and check_grad:
self.check_grad_with_place(
place, set(['X']), 'Out', max_relative_error=0.07)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
TestFp16Case.__name__ = cls_name
globals()[cls_name] = TestFp16Case
create_test_fp16_class(TestPool2D_Op)
create_test_fp16_class(TestCase1, check_grad=False)
create_test_fp16_class(TestCase1)
create_test_fp16_class(TestCase2)
create_test_fp16_class(TestCase3)
create_test_fp16_class(TestCase4)
......@@ -222,6 +346,24 @@ class TestAvgInclude(TestCase2):
self.exclusive = False
class TestAvgPoolAdaptive(TestCase1):
def init_adaptive(self):
self.adaptive = True
class TestAvgPoolAdaptiveAsyOutSize(TestCase1):
def init_adaptive(self):
self.adaptive = True
def init_shape(self):
self.shape = [8, 3, 6, 6]
def init_test_case(self):
self.ksize = [2, 3]
self.strides = [1, 1]
self.paddings = [0, 0, 0, 0]
#-------test pool2d with asymmetric padding-----
......@@ -302,6 +444,19 @@ class TestAvgInclude_AsyPadding(TestCase2):
self.shape = [2, 3, 7, 7]
class TestAvgPoolAdaptive_AsyPadding(TestCase1):
def init_adaptive(self):
self.adaptive = True
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1, 0, 2]
def init_shape(self):
self.shape = [2, 3, 7, 7]
#----------- test channel_last --------------
class TestPool2D_channel_last(TestPool2D_Op):
def init_data_format(self):
......@@ -359,14 +514,6 @@ class TestCase5_Max(TestCase2):
def init_pool_type(self):
self.pool_type = "max"
def test_check_grad(self):
if self.dtype == np.float16:
return
place = core.MLUPlace(0)
if self.pool_type == "max":
self.check_grad_with_place(
place, set(['X']), 'Out', max_relative_error=1.00)
class TestCase5_channel_last_Max(TestCase5_Max):
def init_data_format(self):
......@@ -381,6 +528,11 @@ class TestAvgInclude_channel_last(TestCase2_channel_last):
self.exclusive = False
class TestAvgPoolAdaptive_channel_last(TestCase1_channel_last):
def init_adaptive(self):
self.adaptive = True
class TestPool2D_AsyPadding_channel_last(TestPool2D_AsyPadding):
def init_data_format(self):
self.data_format = "NHWC"
......
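
For average pooling, the test's pool2d_backward_navie reference spreads a gradient of 1/field_size over every input cell of each (possibly adaptive) window, while max pooling routes the whole gradient to the argmax. A small finite-difference check of the average-pooling rule, reduced to 1-D for brevity (illustrative only; the function names here are hypothetical):

```python
import numpy as np

def adaptive_avg_pool1d(x, out_size):
    """1-D adaptive average pooling with the floor/ceil window split."""
    n = x.shape[0]
    return np.array([
        x[int(np.floor(i * n / out_size)):int(np.ceil((i + 1) * n / out_size))].mean()
        for i in range(out_size)
    ])

def adaptive_avg_pool1d_grad(in_size, out_size):
    """Analytic d(sum(y))/dx: each input cell gets 1/window_size per covering window."""
    gx = np.zeros(in_size)
    for i in range(out_size):
        s = int(np.floor(i * in_size / out_size))
        e = int(np.ceil((i + 1) * in_size / out_size))
        gx[s:e] += 1.0 / (e - s)
    return gx

# Finite-difference check of the rule on a small example.
x = np.random.rand(7)
eps = 1e-6
num = np.array([
    (adaptive_avg_pool1d(x + eps * np.eye(7)[k], 3).sum()
     - adaptive_avg_pool1d(x - eps * np.eye(7)[k], 3).sum()) / (2 * eps)
    for k in range(7)
])
assert np.allclose(num, adaptive_avg_pool1d_grad(7, 3), atol=1e-4)
```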