From db0ea0ce70fd9f701a17052255228bb3b1284682 Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Wed, 23 Nov 2022 18:38:21 +0800 Subject: [PATCH] add masked_select_grad kernel (#48137) * add stat tool * add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun * add masked_selected_grad kernel,test=kunlun --- .../fluid/platform/device/xpu/xpu2_op_list.h | 15 +- paddle/phi/kernels/xpu/concat_kernel.cc | 13 +- paddle/phi/kernels/xpu/conv_kernel.cc | 87 ++ .../kernels/xpu/masked_select_grad_kernel.cc | 57 ++ paddle/phi/kernels/xpu/sgd_kernel.cc | 102 ++- .../tests/unittests/xpu/test_conv3d_op_xpu.py | 742 ++++++++++++++++++ .../xpu/test_masked_select_op_xpu.py | 3 + .../tests/unittests/xpu/test_sgd_op_xpu.py | 73 ++ 8 files changed, 1071 insertions(+), 21 deletions(-) create mode 100644 paddle/phi/kernels/xpu/masked_select_grad_kernel.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 4862401f83a..62a4daf7275 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -123,13 +123,17 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::FP16, XPUPlace())})}, {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"conv3d", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"conv2d_transpose_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"conv2d_transpose", @@ -375,6 +379,12 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"masked_select_grad", + XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, @@ -502,6 +512,9 @@ XPUOpMap& get_kl2_ops() { {"sgd", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"sgd_dense_param_sparse_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"sigmoid_cross_entropy_with_logits_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sigmoid_cross_entropy_with_logits", diff --git a/paddle/phi/kernels/xpu/concat_kernel.cc b/paddle/phi/kernels/xpu/concat_kernel.cc index 50b323429b0..4e09f6ef852 100644 --- a/paddle/phi/kernels/xpu/concat_kernel.cc +++ b/paddle/phi/kernels/xpu/concat_kernel.cc @@ -50,6 +50,7 @@ void ConcatKernel(const Context& dev_ctx, x[0]->dims().size())); // If axis is 0, the lod of the output is not the same as inputs. 
+ if (axis == 0 && x[0]->lod().size() > 0) { size_t lod_size_0 = x[0]->lod().size(); size_t lod_size = lod_size_0; @@ -79,7 +80,9 @@ void ConcatKernel(const Context& dev_ctx, } } } + dev_ctx.template Alloc(out); + std::vector> xdims_list; std::vector ptrs; for (unsigned int i = 0; i < x.size(); ++i) { @@ -97,6 +100,7 @@ void ConcatKernel(const Context& dev_ctx, PADDLE_ENFORCE_GT(xdims_list.size(), 0, phi::errors::InvalidArgument("No tensor need concat")); + int r = xpu::concat(dev_ctx.x_context(), ptrs, reinterpret_cast(out->data()), @@ -107,5 +111,10 @@ void ConcatKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - concat, XPU, ALL_LAYOUT, phi::ConcatKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL(concat, + XPU, + ALL_LAYOUT, + phi::ConcatKernel, + float, + phi::dtype::float16, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/conv_kernel.cc b/paddle/phi/kernels/xpu/conv_kernel.cc index 8bbbdc2c16d..bca16c84a90 100644 --- a/paddle/phi/kernels/xpu/conv_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_kernel.cc @@ -131,9 +131,96 @@ void DepthwiseConvKernel(const Context& dev_ctx, out); } +template +void Conv3DKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const std::vector& strides, + const std::vector& paddings_t, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations_t, + const std::string& data_format, + DenseTensor* out) { + using XPUT = typename XPUTypeTrait::Type; + std::vector paddings = paddings_t; + std::vector dilations = dilations_t; + // The filter will be reshaped in the calculations, + // so here use an assignment operation, + // that avoids modifying the variable in the Scope. + dev_ctx.template Alloc(out); + + phi::DDim in_data_dims = + phi::slice_ddim(input.dims(), 2, input.dims().size()); + phi::DDim filter_data_dims = + phi::slice_ddim(filter.dims(), 2, filter.dims().size()); + std::vector ksize = phi::vectorize(filter_data_dims); + UpdatePaddingAndDilation( + &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); + + int batch_size = static_cast(input.dims()[0]); + int img_c = static_cast(input.dims()[1]); + int img_d = static_cast(input.dims()[2]); + int img_h = static_cast(input.dims()[3]); + int img_w = static_cast(input.dims()[4]); + int f = static_cast(filter.dims()[0]); + bool is_ncdhw = true; + if (data_format == "NDHWC") { + img_c = static_cast(input.dims()[4]); + img_d = static_cast(input.dims()[1]); + img_h = static_cast(input.dims()[2]); + img_w = static_cast(input.dims()[3]); + is_ncdhw = false; + } + + XPUT* output_data = reinterpret_cast(out->data()); + const XPUT* filter_data = reinterpret_cast(filter.data()); + const XPUT* input_data = reinterpret_cast(input.data()); + + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + + XPUT* filter_data_tmp; + const XPUT* filter_data_ptr = filter_data; + if (data_format == "NDHWC") { + filter_data_tmp = RAII_GUARD.alloc(filter.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); + std::vector filter_shape = phi::vectorize(filter.dims()); + int r = xpu::transpose(dev_ctx.x_context(), + filter_data, + filter_data_tmp, + filter_shape, + {0, 2, 3, 4, 1}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); + filter_data_ptr = reinterpret_cast(filter_data_tmp); + } + + int r = xpu::conv3d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + 
is_ncdhw); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d"); +} + } // namespace phi PD_REGISTER_KERNEL( conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL( depthwise_conv2d, XPU, ALL_LAYOUT, phi::DepthwiseConvKernel, float) {} +PD_REGISTER_KERNEL( + conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc b/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc new file mode 100644 index 00000000000..52a98c63f48 --- /dev/null +++ b/paddle/phi/kernels/xpu/masked_select_grad_kernel.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/masked_select_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MaskedSelectGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& mask, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + using XPUType = typename XPUTypeTrait::Type; + auto* mask_data = mask.data(); + auto* input_data = reinterpret_cast(out_grad.data()); + auto* out_data = + reinterpret_cast(dev_ctx.template Alloc(x_grad)); + + auto mask_shape = phi::vectorize(mask.dims()); + auto xshape = phi::vectorize(x_grad->dims()); + + int r = xpu::masked_select_grad(dev_ctx.x_context(), + input_data, + mask_data, + out_data, + xshape, + mask_shape, + 1); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "masked_select_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(masked_select_grad, + XPU, + ALL_LAYOUT, + phi::MaskedSelectGradKernel, + float, + phi::dtype::float16, + int, + bool, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/sgd_kernel.cc b/paddle/phi/kernels/xpu/sgd_kernel.cc index 510fddae3ba..1f821a8de28 100644 --- a/paddle/phi/kernels/xpu/sgd_kernel.cc +++ b/paddle/phi/kernels/xpu/sgd_kernel.cc @@ -20,14 +20,14 @@ namespace phi { template -void SGDDenseKernel(const Context &dev_ctx, - const DenseTensor ¶m, - const DenseTensor &learning_rate, - const DenseTensor &grad, - const paddle::optional &master_param, +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + const paddle::optional& master_param, bool multi_precision, - DenseTensor *param_out, - DenseTensor *master_param_out) { + DenseTensor* param_out, + DenseTensor* master_param_out) { using XPUType = typename XPUTypeTrait::Type; auto sz = param_out->numel(); PADDLE_ENFORCE_EQ( @@ -49,37 +49,103 @@ void SGDDenseKernel(const Context &dev_ctx, grad.numel(), sz)); - const T *lr_t = learning_rate.data(); + const T* lr_t = learning_rate.data(); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - const float *lr = nullptr; + const float* lr = nullptr; if (std::is_same::value) { - float *lr_float = RAII_GUARD.alloc_l3_or_gm(learning_rate.numel()); + float* lr_float = RAII_GUARD.alloc_l3_or_gm(learning_rate.numel()); 
int r = xpu::cast(dev_ctx.x_context(), - reinterpret_cast(lr_t), + reinterpret_cast(lr_t), lr_float, learning_rate.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); lr = lr_float; } else { - lr = reinterpret_cast(lr_t); + lr = reinterpret_cast(lr_t); } - const T *param_data = param.data(); - const T *grad_data = grad.data(); + const T* param_data = param.data(); + const T* grad_data = grad.data(); dev_ctx.template Alloc(param_out); - T *out_data = param_out->data(); + T* out_data = param_out->data(); int r = xpu::sgd(dev_ctx.x_context(), - reinterpret_cast(grad_data), - reinterpret_cast(param_data), + reinterpret_cast(grad_data), + reinterpret_cast(param_data), lr, - reinterpret_cast(out_data), + reinterpret_cast(out_data), sz); PADDLE_ENFORCE_XDNN_SUCCESS(r, "sgd"); } +template +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + const paddle::optional& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + using XPUType = typename XPUTypeTrait::Type; + dev_ctx.template Alloc(param_out); + + PADDLE_ENFORCE_EQ( + ¶m, + param_out, + phi::errors::InvalidArgument( + "The input tensor Param of SgdOp should be equal with ParamOut " + "if variable's type is SelectedRows.")); + + auto in_height = grad.height(); + auto out_dims = param_out->dims(); + PADDLE_ENFORCE_EQ(in_height, + out_dims[0], + phi::errors::InvalidArgument( + "The input tensor Grad's height of SgdOp should be " + "equal with ParamOut's dims. But received Grad's " + "height [%s] and ParamOut's dims [%s]", + in_height, + out_dims[0])); + + auto& in_value = grad.value(); + auto& in_rows = grad.rows(); + int64_t* in_rows_data = nullptr; + xpu::VectorParam in_rows_vec{ + in_rows.data(), static_cast(in_rows.size()), in_rows_data}; + + int64_t in_row_numel = in_value.numel() / in_rows.size(); + PADDLE_ENFORCE_EQ(in_row_numel, + param_out->numel() / in_height, + phi::errors::InvalidArgument( + "The in_row_numel of SgdOp should be equal with " + "param_out's numel / in_height.")); + + auto* in_data = in_value.data(); + auto* out_data = param_out->data(); + + int r = xpu::sparse_sgd( + dev_ctx.x_context(), + reinterpret_cast(in_data), + reinterpret_cast(param.data()), + learning_rate.data(), + in_rows_vec, + reinterpret_cast(out_data), + in_row_numel, + in_rows.size()); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sparse_sgd"); +} + } // namespace phi PD_REGISTER_KERNEL( sgd, XPU, ALL_LAYOUT, phi::SGDDenseKernel, phi::dtype::float16, float) {} +PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, + XPU, + ALL_LAYOUT, + phi::SGDDenseParamSparseGradKernel, + phi::dtype::float16, + float) {} diff --git a/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py new file mode 100644 index 00000000000..f949d7eeef8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_conv3d_op_xpu.py @@ -0,0 +1,742 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +sys.path.append("..") +import unittest +import numpy as np + +from op_test_xpu import XPUOpTest +import paddle.fluid as fluid +import paddle +from xpu.get_test_cover_info import ( + create_test_class, + XPUOpTestWrapper, +) + + +def conv3d_forward_naive( + input, + filter, + group, + conv_param, + padding_algorithm='EXPLICIT', + data_format="NCDHW", +): + + if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]: + raise ValueError( + "Unknown Attr(padding_algorithm): '%s'. " + "It can only be 'SAME' or 'VALID'." % str(padding_algorithm) + ) + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError( + "Unknown Attr(data_format): '%s' ." + "It can only be 'NCDHW' or 'NDHWC'." % str(data_format) + ) + + channel_last = data_format == "NDHWC" + if channel_last: + input = np.transpose(input, [0, 4, 1, 2, 3]) + + in_n, in_c, in_d, in_h, in_w = input.shape + + f_n, f_c, f_d, f_h, f_w = filter.shape + out_n = in_n + out_c = f_n + assert f_c * group == in_c + assert np.mod(out_c, group) == 0 + sub_out_c = out_c // group + sub_f_n = f_n // group + + stride, pad, dilation = ( + conv_param['stride'], + conv_param['pad'], + conv_param['dilations'], + ) + + # update pad and dilation + def _get_padding_with_SAME(input_shape, pool_size, pool_stride): + padding = [] + for input_size, filter_size, stride_size in zip( + input_shape, pool_size, pool_stride + ): + out_size = int((input_size + stride_size - 1) / stride_size) + pad_sum = np.max( + ((out_size - 1) * stride_size + filter_size - input_size, 0) + ) + pad_0 = int(pad_sum / 2) + pad_1 = int(pad_sum - pad_0) + padding.append(pad_0) + padding.append(pad_1) + return padding + + ksize = filter.shape[2:5] + if padding_algorithm == "VALID": + pad = [0, 0, 0, 0, 0, 0] + elif padding_algorithm == "SAME": + dilation = [1, 1, 1] + input_data_shape = input.shape[2:5] + pad = _get_padding_with_SAME(input_data_shape, ksize, stride) + + pad_d_0, pad_d_1 = pad[0], pad[0] + pad_h_0, pad_h_1 = pad[1], pad[1] + pad_w_0, pad_w_1 = pad[2], pad[2] + if len(pad) == 6: + pad_d_0, pad_d_1 = pad[0], pad[1] + pad_h_0, pad_h_1 = pad[2], pad[3] + pad_w_0, pad_w_1 = pad[4], pad[5] + + out_d = ( + 1 + + (in_d + pad_d_0 + pad_d_1 - (dilation[0] * (f_d - 1) + 1)) + // stride[0] + ) + out_h = ( + 1 + + (in_h + pad_h_0 + pad_h_1 - (dilation[1] * (f_h - 1) + 1)) + // stride[1] + ) + out_w = ( + 1 + + (in_w + pad_w_0 + pad_w_1 - (dilation[2] * (f_w - 1) + 1)) + // stride[2] + ) + + out = np.zeros((in_n, out_c, out_d, out_h, out_w)) + + d_bolck_d = dilation[0] * (f_d - 1) + 1 + d_bolck_h = dilation[1] * (f_h - 1) + 1 + d_bolck_w = dilation[2] * (f_w - 1) + 1 + + input_pad = np.pad( + input, + ( + (0, 0), + (0, 0), + (pad_d_0, pad_d_1), + (pad_h_0, pad_h_1), + (pad_w_0, pad_w_1), + ), + mode='constant', + constant_values=0, + ) + + filter_dilation = np.zeros((f_n, f_c, d_bolck_d, d_bolck_h, d_bolck_w)) + filter_dilation[ + :, + :, + 0 : d_bolck_d : dilation[0], + 0 : d_bolck_h : dilation[1], + 0 : d_bolck_w : dilation[2], + ] = filter + + for d in range(out_d): + for i in range(out_h): + for j in range(out_w): + for g in range(group): + input_pad_masked = input_pad[ + :, + g * f_c : (g + 1) * f_c, + d * stride[0] : d * stride[0] + d_bolck_d, + i * stride[1] : i * stride[1] + d_bolck_h, + j * stride[2] : j * stride[2] + d_bolck_w, + ] + + f_sub = filter_dilation[ + g * sub_f_n : (g + 1) * sub_f_n, :, :, :, : + ] + for k in range(sub_out_c): + out[:, g * 
sub_out_c + k, d, i, j] = np.sum( + input_pad_masked * f_sub[k, :, :, :, :], + axis=(1, 2, 3, 4), + ) + if channel_last: + out = np.transpose(out, [0, 2, 3, 4, 1]) + return out + + +def create_test_padding_SAME_class(parent): + class TestPaddingSMAECase(parent): + def init_paddings(self): + self.pad = [0, 0, 0] + self.padding_algorithm = "SAME" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingSAMEOp") + TestPaddingSMAECase.__name__ = cls_name + globals()[cls_name] = TestPaddingSMAECase + + +def create_test_padding_VALID_class(parent): + class TestPaddingVALIDCase(parent): + def init_paddings(self): + self.pad = [1, 1, 1] + self.padding_algorithm = "VALID" + + cls_name = "{0}_{1}".format(parent.__name__, "PaddingVALIDOp") + TestPaddingVALIDCase.__name__ = cls_name + globals()[cls_name] = TestPaddingVALIDCase + + +def create_test_channel_last_class(parent): + class TestChannelLastCase(parent): + def init_data_format(self): + self.data_format = "NDHWC" + + def init_test_case_2(self): + N, C, D, H, W = self.input_size + self.input_size = [N, D, H, W, C] + + cls_name = "{0}_{1}".format(parent.__name__, "ChannelLast") + TestChannelLastCase.__name__ = cls_name + globals()[cls_name] = TestChannelLastCase + + +paddle.enable_static() + + +class XPUTestConv3DOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'conv3d' + self.use_dynamic_create_class = False + + class TestConv3DOp(XPUOpTest): + def setUp(self): + self.dtype = self.in_type + self.op_type = "conv3d" + self.use_cudnn = False + self.use_mkldnn = False + self.data_format = "AnyLayout" + self.init_kernel_type() + self.init_group() + self.init_dilation() + self.init_test_case() + + conv3d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilations': self.dilations, + } + + np.random.seed(100) + input = np.random.random(self.input_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) + output = conv3d_forward_naive( + input, + filter, + self.groups, + conv3d_param, + ).astype(self.dtype) + + self.inputs = { + 'Input': XPUOpTest.np_dtype_to_fluid_dtype(input), + 'Filter': XPUOpTest.np_dtype_to_fluid_dtype(filter), + } + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format, + } + self.outputs = {'Output': output} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_test_case_2(self): + pass + + def init_dilation(self): + self.dilations = [1, 1, 1] + + def init_group(self): + self.groups = 1 + + def init_kernel_type(self): + pass + + class TestCase1(TestConv3DOp): + def init_test_case(self): + self.pad = [1, 1, 1] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + class TestWithGroup1(TestConv3DOp): + def init_group(self): + self.groups = 3 + + class TestWithGroup2(TestCase1): + def init_group(self): + self.groups = 3 + + class TestWith1x1(TestConv3DOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 
4, 4, 4] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [120, f_c, 1, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1, 1] + + def init_group(self): + self.groups = 3 + + class TestWithInput1x1Filter1x1(TestConv3DOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [40, 3, 1, 1, 1] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [120, f_c, 1, 1, 1] + + def init_dilation(self): + self.dilations = [1, 1, 1] + + def init_group(self): + self.groups = 3 + + class TestWithDilation(TestConv3DOp): + def init_test_case(self): + self.pad = [0, 0, 0] + self.stride = [1, 1, 1] + self.input_size = [2, 3, 6, 6, 6] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [24, f_c, 2, 2, 2] + + def init_dilation(self): + self.dilations = [2, 2, 2] + + def init_group(self): + self.groups = 3 + + +# ---- test asymmetric padding ---- +class XPUTestConv3DOp_v2(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'conv3d' + self.use_dynamic_create_class = False + + class TestConv3DOp_2(XPUOpTest): + def setUp(self): + self.dtype = self.in_type + self.op_type = "conv3d" + self.use_cudnn = False + self.use_mkldnn = False + self.data_format = "NCDHW" + self.init_kernel_type() + self.init_group() + self.init_dilation() + self.init_data_format() + self.init_test_case() + self.init_paddings() + + self.init_test_case_2() + + conv3d_param = { + 'stride': self.stride, + 'pad': self.pad, + 'dilations': self.dilations, + } + + np.random.seed(100) + input = np.random.random(self.input_size).astype(self.dtype) + filter = np.random.random(self.filter_size).astype(self.dtype) + output = conv3d_forward_naive( + input, + filter, + self.groups, + conv3d_param, + self.padding_algorithm, + self.data_format, + ).astype(self.dtype) + + self.inputs = { + 'Input': XPUOpTest.np_dtype_to_fluid_dtype(input), + 'Filter': XPUOpTest.np_dtype_to_fluid_dtype(filter), + } + self.attrs = { + 'strides': self.stride, + 'paddings': self.pad, + 'padding_algorithm': self.padding_algorithm, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format, + } + self.outputs = {'Output': output} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_test_case_2(self): + pass + + def init_dilation(self): + self.dilations = [1, 1, 1] + + def init_group(self): + self.groups = 1 + + def init_kernel_type(self): + pass + + def init_paddings(self): + self.pad = [0, 0, 0] + self.padding_algorithm = "EXPLICIT" + + def init_data_format(self): + self.data_format = "NCDHW" + + class TestConv3DOp_AsyPadding(TestConv3DOp_2): + def init_test_case(self): + self.stride = [1, 1, 2] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_paddings(self): + self.pad = [1, 0, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + class TestConv3DOp_DiffDataInDiffDim(TestConv3DOp_2): + def 
init_test_case(self): + self.stride = [1, 1, 2] + self.input_size = [2, 3, 4, 5, 5] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 4, 3] + + def init_paddings(self): + self.pad = [1, 0, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + class TestCase1_AsyPadding(TestConv3DOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_paddings(self): + self.pad = [0, 0, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + class TestWithGroup1_AsyPadding(TestConv3DOp_2): + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [1, 1, 1, 0, 0, 2] + self.padding_algorithm = "EXPLICIT" + + class TestWithGroup2_AsyPadding(TestConv3DOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 4, 4, 4] # NCDHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3, 3] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [1, 1, 0, 1, 0, 2] + self.padding_algorithm = "EXPLICIT" + + class TestWithDilation_AsyPadding(TestConv3DOp_2): + def init_test_case(self): + self.stride = [1, 1, 1] + self.input_size = [2, 3, 6, 6, 6] + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [24, f_c, 2, 2, 2] + + def init_dilation(self): + self.dilations = [2, 2, 2] + + def init_group(self): + self.groups = 3 + + def init_paddings(self): + self.pad = [0, 0, 1, 0, 1, 0] + self.padding_algorithm = "EXPLICIT" + + +# --------- test python API --------------- +class TestConv3DAPI(unittest.TestCase): + def test_api(self): + + input_NDHWC = fluid.layers.data( + name="input_NDHWC", + shape=[2, 5, 5, 5, 3], + append_batch_size=False, + dtype="float32", + ) + + input_NCDHW = fluid.layers.data( + name="input_NCDHW", + shape=[2, 3, 5, 5, 3], + append_batch_size=False, + dtype="float32", + ) + + fluid.layers.conv3d( + input=input_NDHWC, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=0, + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW", + ) + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[1, 2, 1, 0, 1, 0], + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW", + ) + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[[0, 0], [0, 0], [1, 1], [1, 1], [1, 1]], + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW", + ) + + fluid.layers.conv3d( + input=input_NDHWC, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=[[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]], + dilation=[1, 1, 1], + groups=1, + data_format="NDHWC", + ) + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding="SAME", + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW", + ) + + fluid.layers.conv3d( + input=input_NCDHW, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding="VALID", + dilation=[1, 1, 1], + groups=1, + data_format="NCDHW", + ) + + +class TestConv3DAPI_Error(unittest.TestCase): + def test_api(self): + input = fluid.layers.data( + name="input", + shape=[2, 5, 5, 5, 4], + 
append_batch_size=False, + dtype="float32", + ) + + # ValueError: cudnn + def run_1(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding=0, + dilation=1, + groups=1, + use_cudnn=[0], + data_format="NCDHW", + ) + + self.assertRaises(ValueError, run_1) + + # ValueError: data_format + def run_2(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=[3, 3, 3], + stride=[1, 1, 1], + padding=0, + dilation=[1, 1, 1], + groups=1, + use_cudnn=False, + data_format="NCHWC", + ) + + self.assertRaises(ValueError, run_2) + + # ValueError: padding + def run_3(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding="SAMEE", + dilation=1, + groups=1, + use_cudnn=False, + data_format="NCDHW", + ) + + self.assertRaises(ValueError, run_3) + + def run_4(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding=[[0, 1], [0, 0], [0, 1], [0, 1], [0, 1]], + dilation=1, + groups=1, + use_cudnn=False, + data_format="NCDHW", + ) + + self.assertRaises(ValueError, run_4) + + def run_5(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=0, + stride=0, + padding=[[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]], + dilation=1, + groups=1, + use_cudnn=False, + data_format="NDHWC", + ) + + self.assertRaises(ValueError, run_5) + + # ValueError: channel dimmention + x = fluid.layers.data( + name="x", + shape=[2, 5, 5, 5, -1], + append_batch_size=False, + dtype="float32", + ) + + def run_6(): + fluid.layers.conv3d( + input=x, + num_filters=3, + filter_size=3, + stride=1, + padding=0, + dilation=1, + groups=1, + use_cudnn=False, + data_format="NDHWC", + ) + + self.assertRaises(ValueError, run_6) + + # ValueError: groups + def run_7(): + fluid.layers.conv3d( + input=input, + num_filters=3, + filter_size=3, + stride=1, + padding=0, + dilation=1, + groups=3, + use_cudnn=False, + data_format="NDHWC", + ) + + self.assertRaises(ValueError, run_7) + + # ValueError: filter num + def run_8(): + fluid.layers.conv3d( + input=input, + num_filters=0, + filter_size=0, + stride=0, + padding=0, + dilation=0, + groups=1, + use_cudnn=False, + data_format="NDHWC", + ) + + self.assertRaises(ValueError, run_8) + + +for stype in ["float32"]: + create_test_class(globals(), XPUTestConv3DOp, stype) + create_test_class(globals(), XPUTestConv3DOp_v2, stype) +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py index 6a2976ccbb5..f596f22dd49 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_masked_select_op_xpu.py @@ -58,6 +58,9 @@ class XPUTestMaskedSelectOp(XPUOpTestWrapper): def test_check_output(self): self.check_output_with_place(self.place) + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Y') + def init(self): self.shape = (50, 3) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py index 2c7bb941410..7929b0f3fc3 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_sgd_op_xpu.py @@ -19,6 +19,8 @@ import sys sys.path.append("..") import paddle import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.op import Operator from op_test_xpu import XPUOpTest from 
xpu.get_test_cover_info import ( @@ -83,6 +85,77 @@ class TestSGDOpWithLargeInput(unittest.TestCase): result = exe.run(fluid.default_main_program(), fetch_list=[avg_cost]) +class TestSparseSGDOp(unittest.TestCase): + def check_with_place(self, place): + scope = core.Scope() + + # create and initialize Grad Variable + height = 10 + rows = [0, 4, 7] + self.conf() + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), self.row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + # create and initialize Param Variable + param = scope.var('Param').get_tensor() + param_array = np.full((height, self.row_numel), 5.0).astype("float32") + param.set(param_array, place) + + # create and initialize LeraningRate Variable + lr = scope.var('LearningRate').get_tensor() + lr_array = np.full((1), 2.0).astype("float32") + lr.set(lr_array, place) + + # create and run sgd operator + sgd_op = Operator( + "sgd", + Param='Param', + Grad='Grad', + ParamOut='Param', + LearningRate='LearningRate', + ) + sgd_op.run(scope, place) + + # get and compare result + result_array = np.array(param) + + # rows[0] = 0, 5.0 - 2.0 * 2.0 + self.assertAlmostEqual(1.0, result_array[rows[0], 0]) + # rows[0] = 0, 5.0 - 2.0 * 1.0 + self.assertAlmostEqual(3.0, result_array[rows[0], 2]) + # 5.0 - 2.0 * 0.0 + self.assertAlmostEqual(5.0, result_array[1, 0]) + # rows[1] = 4, 5.0 - 2.0 * 1.0 + self.assertAlmostEqual(3.0, result_array[rows[1], 10]) + # 5.0 - 2.0 * 0.0 + self.assertAlmostEqual(5.0, result_array[5, 8]) + # rows[2] = 7, 5.0 - 2.0 * 1.0 + self.assertAlmostEqual(3.0, result_array[rows[2], 1]) + # rows[2] = 7, 5.0 - 2.0 * 4.0 + self.assertAlmostEqual(-3.0, result_array[rows[2], 8]) + + def test_sparse_sgd(self): + places = [core.XPUPlace(0)] + for place in places: + self.check_with_place(place) + + def conf(self): + self.row_numel = 12 + + +class TestSparseSGDOpCase8X(TestSparseSGDOp): + def conf(self): + self.row_numel = 16 + + if __name__ == "__main__": paddle.enable_static() unittest.main() -- GitLab
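
For reference only (not part of the patch above): a minimal, hedged sketch of how the newly registered masked_select forward/backward path can be exercised from Python dynamic graph. It assumes a Paddle build with XPU support and falls back to CPU otherwise; the shapes and threshold are illustrative.

# Illustrative sketch only -- not part of the patch above.
# Exercises paddle.masked_select forward and, via backward(), the
# masked_select_grad kernel added in this PR. Assumes a Paddle build
# with XPU support; falls back to CPU otherwise.
import numpy as np
import paddle

paddle.set_device('xpu' if paddle.is_compiled_with_xpu() else 'cpu')

x = paddle.to_tensor(
    np.random.rand(50, 3).astype('float32'), stop_gradient=False)
mask = paddle.to_tensor(np.random.rand(50, 3) > 0.5)

y = paddle.masked_select(x, mask)  # forward: gathers elements where mask is True
y.sum().backward()                 # backward: scatters gradients back into x's shape

# x.grad should be 1.0 where mask is True and 0.0 elsewhere,
# so its sum equals the number of selected elements.
print(y.shape, float(x.grad.numpy().sum()) == float(mask.numpy().sum()))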