From da261732416e628f173fad386652c11c99d07af5 Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Tue, 24 Aug 2021 10:29:53 +0800 Subject: [PATCH] [NPU] add pool2 op and tests (#34770) * add pool2d_op_npu and test * update * update pool2d_backward_navie * clean headers --- paddle/fluid/operators/pool_op_npu.cc | 294 ++++++++ .../tests/unittests/npu/test_pool2d_op_npu.py | 686 ++++++++++++++++++ 2 files changed, 980 insertions(+) create mode 100644 paddle/fluid/operators/pool_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_pool2d_op_npu.py diff --git a/paddle/fluid/operators/pool_op_npu.cc b/paddle/fluid/operators/pool_op_npu.cc new file mode 100644 index 00000000000..b5eb8ae6178 --- /dev/null +++ b/paddle/fluid/operators/pool_op_npu.cc @@ -0,0 +1,294 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/pool_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +template +class NPUPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &dev_ctx = ctx.template device_context(); + const Tensor *in_x = ctx.Input("X"); + Tensor *out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + std::string pooling_type = ctx.Attr("pooling_type"); + std::vector ksize = ctx.Attr>("ksize"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::string data_format = ctx.Attr("data_format"); + + bool global_pooling = ctx.Attr("global_pooling"); + bool ceil_mode = ctx.Attr("ceil_mode"); + bool exclusive = ctx.Attr("exclusive"); + bool adaptive = ctx.Attr("adaptive"); + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + + const bool channel_last = data_format == "NHWC"; + + auto in_x_dims = in_x->dims(); + auto out_dims = out->dims(); + framework::DDim data_dims; + framework::DDim out_data_dims; + + Tensor in_x_tensor, out_tensor; + in_x_tensor.ShareDataWith(*in_x); + out_tensor.ShareDataWith(*out); + std::vector ksize_vec(4, 1); + std::vector strides_vec(4, 1); + + if (channel_last) { + data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1); + out_data_dims = framework::slice_ddim(out_dims, 1, out_dims.size() - 1); + ksize_vec[1] = ksize[0]; + ksize_vec[2] = ksize[1]; + strides_vec[1] = strides[0]; + strides_vec[2] = strides[1]; + in_x_tensor.set_layout(DataLayout::kNHWC); + out_tensor.set_layout(DataLayout::kNHWC); + } else { + data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size()); + out_data_dims = framework::slice_ddim(out_dims, 2, out_dims.size()); + ksize_vec[2] = ksize[0]; + ksize_vec[3] = ksize[1]; + strides_vec[2] = strides[0]; + strides_vec[3] = strides[1]; + } + UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm, + data_dims, strides, ksize); + PADDLE_ENFORCE_LT( + std::max(paddings[0], paddings[1]), ksize[0], + platform::errors::InvalidArgument( + "Paddings should be less than %d, but max(pads[0], pads[1]) is %d.", + ksize[0], std::max(paddings[0], paddings[1]))); + PADDLE_ENFORCE_LT( + std::max(paddings[2], paddings[3]), ksize[1], + platform::errors::InvalidArgument( + "Paddings should be less than %d, but max(pads[2], pads[3]) is %d.", + ksize[1], std::max(paddings[2], paddings[3]))); + + if (adaptive) { + std::string pooling_mode = "AdaptiveAvgPool2d"; + if (pooling_type == "max") { + pooling_mode = "AdaptiveMaxPool2d"; + } + + // AdaptiveAvgPool2d only support NCHW + Tensor transformed_input, transformed_output; + if (pooling_type == "avg" && channel_last) { + transformed_input.mutable_data( + framework::make_dim(in_x_dims[0], in_x_dims[3], in_x_dims[1], + in_x_dims[2]), + ctx.GetPlace()); + transformed_output.mutable_data( + framework::make_dim(out_dims[0], out_dims[3], out_dims[1], + out_dims[2]), + ctx.GetPlace()); + + const auto &trans_runner = + NpuOpRunner("TransData", {in_x_tensor}, {transformed_input}, + {{"src_format", std::string("NHWC")}, + {"dst_format", std::string("NCHW")}}); + trans_runner.Run(dev_ctx.stream()); + } else { + transformed_input.ShareDataWith(in_x_tensor); + transformed_output.ShareDataWith(out_tensor); + } + + const auto &runner = NpuOpRunner( + pooling_mode, {transformed_input}, {transformed_output}, + {{"output_size", framework::vectorize(out_data_dims)}}); + runner.Run(dev_ctx.stream()); + + if (pooling_type == "avg" && channel_last) { + const auto &trans_runner = + NpuOpRunner("TransData", {transformed_output}, {out_tensor}, + {{"src_format", std::string("NCHW")}, + {"dst_format", std::string("NHWC")}}); + trans_runner.Run(dev_ctx.stream()); + } + } else { + std::string pooling_mode = "AvgPoolV2"; + if (pooling_type == "max") { + PADDLE_ENFORCE_EQ( + exclusive, true, + platform::errors::InvalidArgument( + "MaxPool only support exclusive=false, but got true")); + pooling_mode = "MaxPoolV3"; + } + + const auto &runner = + NpuOpRunner(pooling_mode, {in_x_tensor}, {out_tensor}, + {{"ksize", ksize_vec}, + {"strides", strides_vec}, + {"padding_mode", std::string("CALCULATED")}, + {"pads", paddings}, + {"data_format", data_format}, + {"global_pooling", global_pooling}, + {"ceil_mode", ceil_mode}, + {"exclusive", exclusive}}); + runner.Run(dev_ctx.stream()); + } + } +}; + +template +class NPUPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &dev_ctx = ctx.template device_context(); + const Tensor *in_x = ctx.Input("X"); + const Tensor *out = ctx.Input("Out"); + const Tensor *out_grad = ctx.Input(framework::GradVarName("Out")); + Tensor *in_x_grad = ctx.Output(framework::GradVarName("X")); + in_x_grad->mutable_data(ctx.GetPlace()); + + std::string pooling_type = ctx.Attr("pooling_type"); + std::vector ksize = ctx.Attr>("ksize"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + bool ceil_mode = ctx.Attr("ceil_mode"); + bool exclusive = ctx.Attr("exclusive"); + bool adaptive = ctx.Attr("adaptive"); + std::string data_format = ctx.Attr("data_format"); + bool global_pooling = ctx.Attr("global_pooling"); + std::string padding_algorithm = ctx.Attr("padding_algorithm"); + + const bool channel_last = data_format == "NHWC"; + + // update paddings + auto in_x_dims = in_x->dims(); + auto out_dims = out->dims(); + framework::DDim data_dims; + framework::DDim out_data_dims; + std::vector ksize_vec(4, 1); + std::vector strides_vec(4, 1); + + Tensor in_x_tensor, out_tensor, out_grad_tensor, in_x_grad_tensor; + in_x_tensor.ShareDataWith(*in_x); + out_tensor.ShareDataWith(*out); + out_grad_tensor.ShareDataWith(*out_grad); + in_x_grad_tensor.ShareDataWith(*in_x_grad); + if (channel_last) { + data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1); + out_data_dims = framework::slice_ddim(out_dims, 1, out_dims.size() - 1); + ksize_vec[1] = ksize[0]; + ksize_vec[2] = ksize[1]; + strides_vec[1] = strides[0]; + strides_vec[2] = strides[1]; + in_x_tensor.set_layout(DataLayout::kNHWC); + out_tensor.set_layout(DataLayout::kNHWC); + out_grad_tensor.set_layout(DataLayout::kNHWC); + in_x_grad_tensor.set_layout(DataLayout::kNHWC); + } else { + data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size()); + out_data_dims = framework::slice_ddim(out_dims, 2, out_dims.size()); + ksize_vec[2] = ksize[0]; + ksize_vec[3] = ksize[1]; + strides_vec[2] = strides[0]; + strides_vec[3] = strides[1]; + } + UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm, + data_dims, strides, ksize); + + PADDLE_ENFORCE_LT( + std::max(paddings[0], paddings[1]), ksize[0], + platform::errors::InvalidArgument( + "Paddings should be less than %d, but max(pads[0], pads[1]) is %d.", + ksize[0], std::max(paddings[0], paddings[1]))); + PADDLE_ENFORCE_LT( + std::max(paddings[2], paddings[3]), ksize[1], + platform::errors::InvalidArgument( + "Paddings should be less than %d, but max(pads[2], pads[3]) is %d.", + ksize[1], std::max(paddings[2], paddings[3]))); + + if (adaptive || (global_pooling && pooling_type == "max")) { + PADDLE_ENFORCE_EQ(data_dims[0] % out_data_dims[0], 0, + platform::errors::InvalidArgument( + "When adaptive = True, H and W must be divisible, " + "but input dims is %s, output dims is %s", + data_dims, out_data_dims)); + PADDLE_ENFORCE_EQ(data_dims[1] % out_data_dims[1], 0, + platform::errors::InvalidArgument( + "When adaptive = True, H and W must be divisible, " + "but input dims is %s, output dims is %s", + data_dims, out_data_dims)); + if (channel_last) { + strides_vec[1] = data_dims[0] / out_data_dims[0]; + strides_vec[2] = data_dims[1] / out_data_dims[1]; + ksize_vec[1] = strides_vec[1]; + ksize_vec[2] = strides_vec[2]; + } else { + strides_vec[2] = data_dims[0] / out_data_dims[0]; + strides_vec[3] = data_dims[1] / out_data_dims[1]; + ksize_vec[2] = strides_vec[2]; + ksize_vec[3] = strides_vec[3]; + } + } + + NPUAttributeMap attrs = {{"ksize", ksize_vec}, + {"strides", strides_vec}, + {"padding_mode", std::string("CALCULATED")}, + {"pads", paddings}, + {"data_format", data_format}, + {"global_pooling", global_pooling}, + {"ceil_mode", ceil_mode}, + {"exclusive", exclusive}}; + + if (pooling_type == "max") { + if (global_pooling) { + for (auto &s : strides_vec) { + s = 1; + } + PADDLE_ENFORCE_LT(std::max(data_dims[0], data_dims[1]), 255, + platform::errors::InvalidArgument( + "MaxPoolGrad H, W must be less than 255 when " + "global_pooling = True, but got %s", + data_dims)); + attrs["global_pooling"] = false; + } + + const auto &runner = NpuOpRunner( + "MaxPoolV3Grad", {in_x_tensor, out_tensor, out_grad_tensor}, + {in_x_grad_tensor}, attrs); // 0: floor, 1: ceil + runner.Run(dev_ctx.stream()); + } else if (pooling_type == "avg") { + PADDLE_ENFORCE(strides[0] == strides[1], + platform::errors::InvalidArgument( + "AvgPoolGrad dose not support Asymmetric strides. but " + "strides = (%d, %d)", + strides[0], strides[1])); + + NpuOpRunner runner; + runner.SetType("AvgPoolV2Grad"); + runner.AddInput(framework::vectorize(in_x->dims())); + runner.AddInput(out_grad_tensor); + runner.AddOutput(in_x_grad_tensor); + runner.AddAttrs(attrs); + runner.Run(dev_ctx.stream()); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_NPU_KERNEL(pool2d, ops::NPUPoolOpKernel, + ops::NPUPoolOpKernel); +REGISTER_OP_NPU_KERNEL(pool2d_grad, ops::NPUPoolGradOpKernel, + ops::NPUPoolGradOpKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_pool2d_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_pool2d_op_npu.py new file mode 100644 index 00000000000..2b8550a88de --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_pool2d_op_npu.py @@ -0,0 +1,686 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +import unittest +import numpy as np +sys.path.append("..") + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from op_test import OpTest +from test_pool2d_op import pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive, adaptive_start_index, adaptive_end_index +from paddle.nn.functional import avg_pool2d, max_pool2d + +paddle.enable_static() + + +def create_test_padding_SAME_class(parent): + class TestPaddingSMAECase(parent): + def init_paddings(self): + self.paddings = [0, 0] + self.padding_algorithm = "SAME" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp") + TestPaddingSMAECase.__name__ = cls_name + globals()[cls_name] = TestPaddingSMAECase + + +def create_test_use_ceil_class(parent): + class TestPool2DUseCeilCase(parent): + def init_ceil_mode(self): + self.ceil_mode = True + + cls_name = "{0}_{1}".format(parent.__name__, "CeilModeCast") + TestPool2DUseCeilCase.__name__ = cls_name + globals()[cls_name] = TestPool2DUseCeilCase + + +def create_test_padding_VALID_class(parent): + class TestPaddingVALIDCase(parent): + def init_paddings(self): + self.paddings = [1, 1] + self.padding_algorithm = "VALID" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp") + TestPaddingVALIDCase.__name__ = cls_name + globals()[cls_name] = TestPaddingVALIDCase + + +def create_test_fp16_class(parent): + class TestFp16Case(parent): + def init_kernel_type(self): + self.use_cudnn = False + self.dtype = np.float16 + + def test_check_grad(self): + return + + cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op") + TestFp16Case.__name__ = cls_name + globals()[cls_name] = TestFp16Case + + +def pool2d_backward_navie(x, + ksize, + strides, + paddings, + global_pool=0, + ceil_mode=False, + exclusive=True, + adaptive=False, + data_format='NCHW', + pool_type="max", + padding_algorithm="EXPLICIT"): + # update paddings + def _get_padding_with_SAME(input_shape, pool_size, pool_stride): + padding = [] + for input_size, filter_size, stride_size in zip(input_shape, pool_size, + pool_stride): + out_size = int((input_size + stride_size - 1) / stride_size) + pad_sum = np.max(( + (out_size - 1) * stride_size + filter_size - input_size, 0)) + pad_0 = int(pad_sum / 2) + pad_1 = int(pad_sum - pad_0) + padding.append(pad_0) + padding.append(pad_1) + return padding + + if isinstance(padding_algorithm, str): + padding_algorithm = padding_algorithm.upper() + if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]: + raise ValueError("Unknown Attr(padding_algorithm): '%s'. " + "It can only be 'SAME' or 'VALID'." % + str(padding_algorithm)) + + if padding_algorithm == "VALID": + paddings = [0, 0, 0, 0] + if ceil_mode != False: + raise ValueError( + "When Attr(pool_padding) is \"VALID\", Attr(ceil_mode)" + " must be False. " + "Received ceil_mode: True.") + elif padding_algorithm == "SAME": + input_data_shape = [] + if data_format == "NCHW": + input_data_shape = x.shape[2:4] + elif data_format == "NHWC": + input_data_shape = x.shape[1:3] + paddings = _get_padding_with_SAME(input_data_shape, ksize, strides) + + assert len(paddings) == 2 or len(paddings) == 4 + is_sys = True if len(paddings) == 2 else False + + if data_format == "NHWC": + x = x.transpose([0, 3, 1, 2]) + + N, C, H, W = x.shape + + if global_pool == 1: + ksize = [H, W] + paddings = [0 for _ in range(len(paddings))] + + pad_h_up = paddings[0] if is_sys else paddings[0] + pad_h_down = paddings[0] if is_sys else paddings[1] + pad_w_left = paddings[1] if is_sys else paddings[2] + pad_w_right = paddings[1] if is_sys else paddings[3] + + if adaptive: + H_out, W_out = ksize + else: + H_out = (H - ksize[0] + pad_h_up + pad_h_down + strides[0] - 1) // strides[0] + 1 \ + if ceil_mode else (H - ksize[0] + pad_h_up + pad_h_down) // strides[0] + 1 + W_out = (W - ksize[1] + pad_w_left + pad_w_right + strides[1] - 1) // strides[1] + 1 \ + if ceil_mode else (W - ksize[1] + pad_w_left + pad_w_right) // strides[1] + 1 + + x_grad = np.zeros_like(x) + for i in range(H_out): + if adaptive: + in_h_start = adaptive_start_index(i, H, ksize[0]) + in_h_end = adaptive_end_index(i, H, ksize[0]) + else: + in_h_start = np.max((i * strides[0] - pad_h_up, 0)) + in_h_end = np.min((i * strides[0] + ksize[0] - pad_h_up, H)) + + for j in range(W_out): + if adaptive: + in_w_start = adaptive_start_index(j, W, ksize[1]) + in_w_end = adaptive_end_index(j, W, ksize[1]) + else: + in_h_start = i * strides[0] - pad_h_up + in_w_start = j * strides[1] - pad_w_left + in_h_end = i * strides[0] + ksize[0] - pad_h_up + in_w_end = j * strides[1] + ksize[1] - pad_w_left + + field_size = (in_h_end - in_h_start) * (in_w_end - in_w_start) + in_h_start = np.max((in_h_start, 0)) + in_w_start = np.max((in_w_start, 0)) + in_h_end = np.min((in_h_end, H)) + in_w_end = np.min((in_w_end, W)) + + if pool_type == 'avg': + if (exclusive or adaptive): + field_size = (in_h_end - in_h_start) * ( + in_w_end - in_w_start) + x_grad[:, :, in_h_start:in_h_end, in_w_start: + in_w_end] += 1 / field_size + elif pool_type == 'max': + for n in range(N): + for c in range(C): + idx = np.argmax(x[n, c, in_h_start:in_h_end, in_w_start: + in_w_end].flatten()) + idx_h = idx // (in_w_end - in_w_start) + idx_w = idx % (in_w_end - in_w_start) + x_grad[n, c, in_h_start + idx_h, in_w_start + + idx_w] += 1 + + if data_format == "NHWC": + x_grad = x_grad.transpose([0, 2, 3, 1]) + return x_grad + + +class TestPool2D_Op(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "pool2d" + self.init_kernel_type() + self.init_data_type() + self.init_test_case() + self.padding_algorithm = "EXPLICIT" + self.init_paddings() + self.init_global_pool() + self.init_kernel_type() + self.init_pool_type() + self.init_ceil_mode() + self.init_exclusive() + self.init_adaptive() + self.init_data_format() + self.init_shape() + + input = np.random.random(self.shape).astype(self.dtype) + if self.pool_type == "max": + input = np.array([x for x in range(np.prod(self.shape))]).reshape( + self.shape).astype(self.dtype) + output = pool2D_forward_naive( + input, self.ksize, self.strides, self.paddings, self.global_pool, + self.ceil_mode, self.exclusive, self.adaptive, self.data_format, + self.pool_type, self.padding_algorithm).astype(self.dtype) + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)} + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'pooling_type': self.pool_type, + 'global_pooling': self.global_pool, + 'use_cudnn': False, + 'use_mkldnn': False, + 'ceil_mode': self.ceil_mode, + 'data_format': self.data_format, + 'exclusive': self.exclusive, + 'adaptive': self.adaptive, + "padding_algorithm": self.padding_algorithm, + } + + self.outputs = {'Out': output} + + def init_data_format(self): + self.data_format = "NCHW" + + def init_shape(self): + self.shape = [2, 3, 5, 5] + + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + + def init_paddings(self): + self.paddings = [0, 0] + self.padding_algorithm = "EXPLICIT" + + def init_kernel_type(self): + self.use_cudnn = False + + def init_data_type(self): + self.dtype = np.float32 + + def init_pool_type(self): + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = True + + def init_ceil_mode(self): + self.ceil_mode = False + + def init_exclusive(self): + self.exclusive = True + + def init_adaptive(self): + self.adaptive = False + + def set_npu(self): + self.__class__.use_npu = True + + def test_check_output(self): + self.check_output_with_place(fluid.NPUPlace(0), atol=1e-3) + + def test_check_grad(self): + x_grad = pool2d_backward_navie( + self.inputs["X"], + ksize=self.ksize, + strides=self.strides, + paddings=self.paddings, + global_pool=self.global_pool, + ceil_mode=False, + exclusive=self.exclusive, + adaptive=self.adaptive, + data_format=self.data_format, + pool_type=self.pool_type, + padding_algorithm=self.padding_algorithm) + x_grad = x_grad / np.prod(self.outputs['Out'].shape) + self.check_grad_with_place( + fluid.NPUPlace(0), + set(['X']), + 'Out', + max_relative_error=0.06, + user_defined_grads=[x_grad]) + + +class TestCase1(TestPool2D_Op): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + + def init_paddings(self): + self.paddings = [0, 0] + + def init_pool_type(self): + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = False + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase2(TestPool2D_Op): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + + def init_paddings(self): + self.paddings = [1, 1] + + def init_pool_type(self): + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + + def init_global_pool(self): + self.global_pool = False + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase3(TestPool2D_Op): + def init_pool_type(self): + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + + +class TestCase4(TestCase1): + def init_pool_type(self): + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + + +class TestCase5(TestCase2): + def init_pool_type(self): + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + + +class TestAvgInclude(TestCase2): + def init_exclusive(self): + self.exclusive = False + + +class TestAvgPoolAdaptive(TestCase1): + def init_adaptive(self): + self.adaptive = True + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + def init_test_case(self): + self.ksize = [7, 7] + self.strides = [7, 7] + self.paddings = [0, 0, 0, 0] + + +class TestAvgPoolAdaptiveAsyOutSize(TestCase1): + def init_adaptive(self): + self.adaptive = True + + def init_shape(self): + self.shape = [2, 3, 8, 8] + + def init_test_case(self): + self.ksize = [2, 4] + # fixme: CANN AvgPoolGradV3 dose not support asymmetric strides + # self.strides = [2, 4] + self.strides = [4, 4] + self.paddings = [0, 0, 0, 0] + + +#-------test pool2d with asymmetric padding----- +class TestPool2D_AsyPadding(TestPool2D_Op): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 5, 5] + + +class TestCase1_AsyPadding(TestCase1): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 0] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase2_AsyPadding(TestCase2): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 2, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase3_AsyPadding(TestCase3): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 5, 5] + + +class TestCase4_AsyPadding(TestCase4): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 0, 1, 0] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestCase5_AsyPadding((TestCase5)): + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [2, 2, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestAvgInclude_AsyPadding(TestCase2): + def init_exclusive(self): + self.exclusive = False + + def init_test_case(self): + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 2, 1, 2] + + def init_shape(self): + self.shape = [2, 3, 7, 7] + + +class TestAvgPoolAdaptive_AsyPadding(TestCase1): + def init_adaptive(self): + self.adaptive = True + + def init_test_case(self): + self.ksize = [2, 2] + self.strides = [2, 2] + self.paddings = [1, 1, 0, 2] + + def init_shape(self): + self.shape = [2, 3, 8, 8] + + +#----------- test channel_last -------------- +class TestPool2D_channel_last(TestPool2D_Op): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 5, 5, 3] + + +class TestCase1_channel_last(TestCase1): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase2_channel_last(TestCase2): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase3_channel_last(TestCase3): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 5, 5, 3] + + +class TestCase4_channel_last(TestCase4): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase5_channel_last(TestCase5): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase5_Max(TestCase2): + def init_pool_type(self): + self.pool_type = "max" + + +class TestCase5_channel_last_Max(TestCase5_Max): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestAvgInclude_channel_last(TestCase2_channel_last): + def init_exclusive(self): + self.exclusive = False + + +class TestAvgPoolAdaptive_channel_last(TestCase1_channel_last): + def init_adaptive(self): + self.adaptive = True + + def init_shape(self): + self.shape = [2, 8, 8, 3] + + def init_test_case(self): + self.ksize = [2, 2] + self.strides = [2, 2] + + +class TestPool2D_AsyPadding_channel_last(TestPool2D_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 5, 5, 3] + + +class TestCase1_AsyPadding_channel_last(TestCase1_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase2_AsyPadding_channel_last(TestCase2_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase3_AsyPadding_channel_last(TestCase3_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 5, 5, 3] + + +class TestCase4_AsyPadding_channel_last(TestCase4_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestCase5_AsyPadding_channel_last(TestCase5_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestAvgInclude_AsyPadding_channel_last(TestAvgInclude_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 7, 7, 3] + + +class TestAvgPoolAdaptive_AsyPadding_channel_last( + TestAvgPoolAdaptive_AsyPadding): + def init_data_format(self): + self.data_format = "NHWC" + + def init_shape(self): + self.shape = [2, 8, 8, 3] + + +class TestCase1_strides(TestCase1): + def init_test_case(self): + self.ksize = [3, 3] + # fixme: CANN AvgPoolGradV3 dose not support asymmetric strides + # self.strides = [1, 2] + self.strides = [2, 2] + + def init_shape(self): + self.shape = [2, 3, 4, 5] + + +create_test_padding_SAME_class(TestPool2D_Op) +create_test_padding_SAME_class(TestCase1) +create_test_padding_SAME_class(TestCase2) +create_test_padding_SAME_class(TestCase3) +create_test_padding_SAME_class(TestCase4) +create_test_padding_SAME_class(TestCase5) +create_test_padding_SAME_class(TestPool2D_channel_last) +create_test_padding_SAME_class(TestCase1_channel_last) +create_test_padding_SAME_class(TestCase2_channel_last) +create_test_padding_SAME_class(TestCase3_channel_last) +create_test_padding_SAME_class(TestCase4_channel_last) +create_test_padding_SAME_class(TestCase5_channel_last) +create_test_padding_SAME_class(TestCase1_strides) + +create_test_padding_VALID_class(TestPool2D_Op) +create_test_padding_VALID_class(TestCase1) +create_test_padding_VALID_class(TestCase2) +create_test_padding_VALID_class(TestCase3) +create_test_padding_VALID_class(TestCase4) +create_test_padding_VALID_class(TestCase5) +create_test_padding_VALID_class(TestPool2D_channel_last) +create_test_padding_VALID_class(TestCase1_channel_last) +create_test_padding_VALID_class(TestCase2_channel_last) +create_test_padding_VALID_class(TestCase3_channel_last) +create_test_padding_VALID_class(TestCase4_channel_last) +create_test_padding_VALID_class(TestCase5_channel_last) + +create_test_use_ceil_class(TestCase1) +create_test_use_ceil_class(TestCase2) +create_test_use_ceil_class(TestCase1_AsyPadding) +create_test_use_ceil_class(TestCase2_AsyPadding) +create_test_use_ceil_class(TestCase1_channel_last) +create_test_use_ceil_class(TestCase2_channel_last) +create_test_use_ceil_class(TestCase1_AsyPadding_channel_last) +create_test_use_ceil_class(TestCase2_AsyPadding_channel_last) + +create_test_fp16_class(TestPool2D_Op) +create_test_fp16_class(TestCase1) +create_test_fp16_class(TestCase2) +create_test_fp16_class(TestCase3) +create_test_fp16_class(TestCase4) +create_test_fp16_class(TestCase5) +create_test_fp16_class(TestPool2D_channel_last) +create_test_fp16_class(TestCase1_channel_last) +create_test_fp16_class(TestCase2_channel_last) +create_test_fp16_class(TestCase3_channel_last) +create_test_fp16_class(TestCase4_channel_last) +create_test_fp16_class(TestCase5_channel_last) + +if __name__ == "__main__": + unittest.main() -- GitLab