From 02e9453fa277b58007674f559e6b7c45561a41e7 Mon Sep 17 00:00:00 2001
From: QingshuChen
Date: Mon, 18 Jul 2022 15:02:11 +0800
Subject: [PATCH] add xpu resnet_unit (#44297)

* add xpu resnet_unit
*test=kunlun

* tmp
*test=kunlun
---
 cmake/external/xpu.cmake                      |   4 +-
 paddle/fluid/operators/fused/CMakeLists.txt   |   1 +
 .../fluid/operators/fused/resnet_unit_op.cc   |  26 +-
 .../operators/fused/resnet_unit_op_xpu.cc     | 333 ++++++++++++++++++
 .../fluid/platform/device/xpu/xpu2_op_list.h  |   3 +
 .../unittests/xpu/get_test_cover_info.py      |   4 +-
 .../paddle/incubate/operators/resnet_unit.py  |  22 +-
 7 files changed, 376 insertions(+), 17 deletions(-)
 create mode 100644 paddle/fluid/operators/fused/resnet_unit_op_xpu.cc

diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index c1f8eb0e33c..81128ccf3b6 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220712")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220718")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220712")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220718")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index dfbdaed8761..02a3f4d7a0e 100755
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -35,6 +35,7 @@ op_library(fusion_lstm_op)
 
 if(WITH_XPU)
   op_library(resnet_basic_block_op)
+  op_library(resnet_unit_op)
 endif()
 
 if(WITH_GPU OR WITH_ROCM)
diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc
index 4f4e0aa6ac2..5852a5c04bd 100644
--- a/paddle/fluid/operators/fused/resnet_unit_op.cc
+++ b/paddle/fluid/operators/fused/resnet_unit_op.cc
@@ -159,22 +159,28 @@ class ResNetUnitOp : public framework::OperatorWithKernel {
                           bn_param_dims,
                           bn_param_dims.size()));
     auto data_format = ctx->Attrs().Get<std::string>("data_format");
-    PADDLE_ENFORCE_EQ(
-        data_format,
-        "NHWC",
" - "But received: the data format " - "= [%s]", - data_format)); + bool is_nchw = (data_format == "NCHW"); // Calculate the dims of outputs int batch = x_dims[0]; int output_channel = w_dims[0]; int filter_size = w_dims[2]; int stride = ctx->Attrs().Get("stride"); int padding = ctx->Attrs().Get("padding"); - int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1; - int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1; - std::vector out_shape = {batch, out_h, out_w, output_channel}; + std::vector out_shape; + out_shape.push_back(batch); + if (is_nchw) { + int out_h = (x_dims[2] + padding * 2 - filter_size) / stride + 1; + int out_w = (x_dims[3] + padding * 2 - filter_size) / stride + 1; + out_shape.push_back(output_channel); + out_shape.push_back(out_h); + out_shape.push_back(out_w); + } else { + int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1; + int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1; + out_shape.push_back(out_h); + out_shape.push_back(out_w); + out_shape.push_back(output_channel); + } auto y_dims = phi::make_ddim(out_shape); auto bitmask_dims = GetBitmaskDims(out_shape); diff --git a/paddle/fluid/operators/fused/resnet_unit_op_xpu.cc b/paddle/fluid/operators/fused/resnet_unit_op_xpu.cc new file mode 100644 index 00000000000..cce506c67ab --- /dev/null +++ b/paddle/fluid/operators/fused/resnet_unit_op_xpu.cc @@ -0,0 +1,333 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class ResNetUnitXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto place = ctx.GetPlace();
+    PADDLE_ENFORCE_EQ(
+        platform::is_xpu_place(place),
+        true,
+        platform::errors::PreconditionNotMet("It must use XPUPlace."));
+
+    bool is_nchw = (ctx.Attr<std::string>("data_format") == "NCHW");
+    // input x
+    const Tensor *input_x = ctx.Input<Tensor>("X");
+    const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
+    const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
+    const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
+
+    // output x
+    Tensor *conv_out_x = ctx.Output<Tensor>("ConvX");
+    Tensor *saved_mean_x = ctx.Output<Tensor>("SavedMeanX");
+    Tensor *saved_invstd_x = ctx.Output<Tensor>("SavedInvstdX");
+    Tensor *running_mean_x = ctx.Output<Tensor>("RunningMeanX");
+    Tensor *running_var_x = ctx.Output<Tensor>("RunningVarX");
+
+    Tensor *output = ctx.Output<Tensor>("Y");
+
+    // attrs
+    int padding = ctx.Attr<int>("padding");
+    int stride = ctx.Attr<int>("stride");
+    int stride_z = ctx.Attr<int>("stride_z");
+    int dilation = ctx.Attr<int>("dilation");
+    int group = ctx.Attr<int>("group");
+    float eps = ctx.Attr<float>("epsilon");
+    float momentum = ctx.Attr<float>("momentum");
+    bool has_shortcut = ctx.Attr<bool>("has_shortcut");
+    bool fuse_add = ctx.Attr<bool>("fuse_add");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    bool is_test = ctx.Attr<bool>("is_test");
+    bool is_train = !is_test && !use_global_stats;
+    std::string act_type = ctx.Attr<std::string>("act_type");
+    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
+
+    std::vector<const T *> x_list = {input_x->data<T>()};
+    std::vector<const T *> w_list = {filter_x->data<T>()};
+    std::vector<T *> conv_y_list = {conv_out_x->mutable_data<T>(place)};
+
+    std::vector<std::vector<int>> x_shape_list = {
+        phi::vectorize<int>(input_x->dims())};
+
+    auto filter_x_shape = phi::vectorize<int>(filter_x->dims());
+    std::vector<int> ksize = {filter_x_shape[2], filter_x_shape[3]};
+    if (!is_nchw) {
+      ksize[0] = filter_x_shape[1];
+      ksize[1] = filter_x_shape[2];
+    }
+    std::vector<int> strides = {stride, stride};
+    std::vector<std::vector<int>> ksize_list = {ksize};
+    std::vector<std::vector<int>> stride_list = {strides};
+    std::vector<int> paddings = {padding, padding};
+    std::vector<int> dilations = {dilation, dilation};
+    std::vector<const float *> scale_list = {scale_x->data<float>()};
+    std::vector<const float *> bias_list = {bias_x->data<float>()};
+    std::vector<float *> batch_mean_list = {
+        saved_mean_x->mutable_data<float>(place)};
+    std::vector<float *> batch_invstd_list = {
+        saved_invstd_x->mutable_data<float>(place)};
+    std::vector<float *> global_mean_list = {
+        running_mean_x->mutable_data<float>(place)};
+    std::vector<float *> global_var_list = {
+        running_var_x->mutable_data<float>(place)};
+
+    std::vector<const float *> x_maxlist = {nullptr};
+    std::vector<const float *> w_maxlist = {nullptr};
+    if (has_shortcut) {
+      // input z
+      const Tensor *input_z = ctx.Input<Tensor>("Z");
+      const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
+      const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
+      const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
+
+      Tensor *conv_out_z = ctx.Output<Tensor>("ConvZ");
+      Tensor *saved_mean_z = ctx.Output<Tensor>("SavedMeanZ");
+      Tensor *saved_invstd_z = ctx.Output<Tensor>("SavedInvstdZ");
+      Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
+      Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
+
+      x_list.push_back(input_z->data<T>());
+      w_list.push_back(filter_z->data<T>());
+      conv_y_list.push_back(conv_out_z->mutable_data<T>(place));
+
+      x_shape_list.push_back(phi::vectorize<int>(input_z->dims()));
+
+      auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
+      std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
+      if (!is_nchw) {
+        ksize_z[0] = filter_z_shape[1];
+        ksize_z[1] = filter_z_shape[2];
+      }
+      ksize_list.push_back(ksize_z);
+      stride_list.push_back({stride_z, stride_z});
+      scale_list.push_back(scale_z->data<float>());
+      bias_list.push_back(bias_z->data<float>());
+      batch_mean_list.push_back(saved_mean_z->mutable_data<float>(place));
+      batch_invstd_list.push_back(saved_invstd_z->mutable_data<float>(place));
+      global_mean_list.push_back(running_mean_z->mutable_data<float>(place));
+      global_var_list.push_back(running_var_z->mutable_data<float>(place));
+      x_maxlist.push_back(nullptr);
+      w_maxlist.push_back(nullptr);
+    } else {
+      if (fuse_add) {
+        const Tensor *input_z = ctx.Input<Tensor>("Z");
+        auto input_z_shape = phi::vectorize<int>(input_z->dims());
+        x_list.push_back(input_z->data<T>());
+        x_shape_list.push_back(input_z_shape);
+        x_maxlist.push_back(nullptr);
+      }
+    }
+    int r = xpu::resnet_unit_fusion<T, T, T, int16_t>(
+        dev_ctx.x_context(),
+        x_list,
+        w_list,
+        conv_y_list,
+        output->mutable_data<T>(place),
+        x_shape_list,
+        filter_x_shape[0],
+        ksize_list,
+        stride_list,
+        paddings,
+        dilations,
+        group,
+        eps,
+        momentum,
+        x_maxlist,
+        w_maxlist,
+        scale_list,
+        bias_list,
+        batch_mean_list,
+        batch_invstd_list,
+        global_mean_list,
+        global_var_list,
+        xpu::Activation_t::RELU,
+        is_nchw,
+        has_shortcut,
+        fuse_add,
+        is_train);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_fusion");
+  }
+};
+
+template <typename T>
+class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto place = ctx.GetPlace();
+    PADDLE_ENFORCE_EQ(
+        platform::is_xpu_place(place),
+        true,
+        platform::errors::PreconditionNotMet("It must use XPUPlace."));
+
+    bool is_nchw = (ctx.Attr<std::string>("data_format") == "NCHW");
+    const Tensor *y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const Tensor *x = ctx.Input<Tensor>("X");
+    const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
+    const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
+    const Tensor *saved_mean_x = ctx.Input<Tensor>("SavedMeanX");
+    const Tensor *saved_invstd_x = ctx.Input<Tensor>("SavedInvstdX");
+    const Tensor *conv_out_x = ctx.Input<Tensor>("ConvX");
+    const Tensor *output = ctx.Input<Tensor>("Y");
+
+    Tensor *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    Tensor *filter_x_grad =
+        ctx.Output<Tensor>(framework::GradVarName("FilterX"));
+    Tensor *scale_x_grad =
+        ctx.Output<Tensor>(framework::GradVarName("ScaleX"));
+    Tensor *bias_x_grad = ctx.Output<Tensor>(framework::GradVarName("BiasX"));
+
+    int padding = ctx.Attr<int>("padding");
+    int stride = ctx.Attr<int>("stride");
+    int stride_z = ctx.Attr<int>("stride_z");
+    int dilation = ctx.Attr<int>("dilation");
+    int group = ctx.Attr<int>("group");
+    float eps = ctx.Attr<float>("epsilon");
+    bool has_shortcut = ctx.Attr<bool>("has_shortcut");
+    bool fuse_add = ctx.Attr<bool>("fuse_add");
+    std::string act_type = ctx.Attr<std::string>("act_type");
+
+    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
+
+    std::vector<const T *> x_list = {x->data<T>()};
+    std::vector<const T *> w_list = {filter_x->data<T>()};
+    std::vector<const T *> conv_y_list = {conv_out_x->data<T>()};
+    std::vector<T *> dx_list = {x_grad->mutable_data<T>(place)};
+    std::vector<T *> dw_list = {filter_x_grad->mutable_data<T>(place)};
+
+    std::vector<std::vector<int>> x_shape_list = {
+        phi::vectorize<int>(x->dims())};
+
+    auto filter_x_shape = phi::vectorize<int>(filter_x->dims());
+    std::vector<int> x_ksize = {filter_x_shape[2], filter_x_shape[3]};
+    if (!is_nchw) {
+      x_ksize[0] = filter_x_shape[1];
+      x_ksize[1] = filter_x_shape[2];
+    }
+    std::vector<std::vector<int>> ksize_list = {x_ksize};
+    std::vector<std::vector<int>> stride_list = {{stride, stride}};
+    std::vector<int> paddings = {padding, padding};
+    std::vector<int> dilations = {dilation, dilation};
+
+    std::vector<const float *> x_maxlist = {nullptr};
+    std::vector<const float *> w_maxlist = {nullptr};
+
+    std::vector<const float *> scale_list = {scale_x->data<float>()};
+    std::vector<const float *> batch_mean_list = {saved_mean_x->data<float>()};
+    std::vector<const float *> batch_invstd_list = {
+        saved_invstd_x->data<float>()};
+    std::vector<float *> dscale_list = {
+        scale_x_grad->mutable_data<float>(place)};
+    std::vector<float *> dbias_list = {bias_x_grad->mutable_data<float>(place)};
+
+    if (has_shortcut) {
+      //       X                   Z
+      //       |                   |
+      //    NormConv            NormConv
+      //       |                   |
+      // BNStatsFinalize    BNStatsFinalize
+      //       \                   /
+      //         ScaleBiasAddRelu
+      //               |
+      //               Y
+      const Tensor *z = ctx.Input<Tensor>("Z");
+      const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
+      const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
+      const Tensor *saved_mean_z = ctx.Input<Tensor>("SavedMeanZ");
+      const Tensor *saved_invstd_z = ctx.Input<Tensor>("SavedInvstdZ");
+      const Tensor *conv_out_z = ctx.Input<Tensor>("ConvZ");
+
+      Tensor *z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
+      Tensor *filter_z_grad =
+          ctx.Output<Tensor>(framework::GradVarName("FilterZ"));
+      Tensor *scale_z_grad =
+          ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
+      Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
+      x_list.push_back(z->data<T>());
+      w_list.push_back(filter_z->data<T>());
+      conv_y_list.push_back(conv_out_z->data<T>());
+      dx_list.push_back(z_grad->mutable_data<T>(place));
+      dw_list.push_back(filter_z_grad->mutable_data<T>(place));
+      x_shape_list.push_back(phi::vectorize<int>(z->dims()));
+
+      auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
+      std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
+      if (!is_nchw) {
+        ksize_z[0] = filter_z_shape[1];
+        ksize_z[1] = filter_z_shape[2];
+      }
+      ksize_list.push_back(ksize_z);
+      stride_list.push_back({stride_z, stride_z});
+      x_maxlist.push_back(nullptr);
+      w_maxlist.push_back(nullptr);
+
+      scale_list.push_back(scale_z->data<float>());
+      batch_mean_list.push_back(saved_mean_z->data<float>());
+      batch_invstd_list.push_back(saved_invstd_z->data<float>());
+      dscale_list.push_back(scale_z_grad->mutable_data<float>(place));
+      dbias_list.push_back(bias_z_grad->mutable_data<float>(place));
+    } else {
+      if (fuse_add) {
+        auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
+        dx_list.push_back(z_grad->mutable_data<T>(place));
+      }
+    }
+
+    int r = xpu::resnet_unit_grad_fusion<T, T, T, int16_t>(
+        dev_ctx.x_context(),
+        x_list,
+        w_list,
+        y_grad->data<T>(),
+        output->data<T>(),
+        conv_y_list,
+        dx_list,
+        dw_list,
+        x_shape_list,
+        filter_x_shape[0],
+        ksize_list,
+        stride_list,
+        paddings,
+        dilations,
+        group,
+        x_maxlist,
+        w_maxlist,
+        scale_list,
+        batch_mean_list,
+        batch_invstd_list,
+        dscale_list,
+        dbias_list,
+        xpu::Activation_t::RELU,
+        eps,
+        is_nchw,
+        has_shortcut,
+        fuse_add);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_XPU_KERNEL(resnet_unit, ops::ResNetUnitXPUKernel<float>);
+REGISTER_OP_XPU_KERNEL(resnet_unit_grad, ops::ResNetUnitGradXPUKernel<float>);
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index bd5957a1228..8cae8cfe534 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -374,6 +374,9 @@ XPUOpMap& get_kl2_ops() {
                     pOpKernelType(vartype::INT32, XPUPlace()),
                     pOpKernelType(vartype::BOOL, XPUPlace()),
                     pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"resnet_unit", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"resnet_unit_grad",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"rmsprop",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
diff --git a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
index f58c0d4cf07..bcaa8055b25 100644
--- a/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
+++ b/python/paddle/fluid/tests/unittests/xpu/get_test_cover_info.py
@@ -87,7 +87,9 @@ xpu_test_device_type_white_list = ['xpu1_float64']
 xpu_test_op_type_white_list = [
     'dropout_float16',
     'dropout_grad_float16',
-    "grad_add_float32"  # no api for grad_add, skip
+    "grad_add_float32",  # no api for grad_add, skip
+    "resnet_unit",
+    "resnet_unit_grad"
 ]
 xpu_test_device_op_white_list = []
 xpu_test_device_op_type_white_list = []
diff --git a/python/paddle/incubate/operators/resnet_unit.py b/python/paddle/incubate/operators/resnet_unit.py
index 6333ddafe10..70abe41f624 100644
--- a/python/paddle/incubate/operators/resnet_unit.py
+++ b/python/paddle/incubate/operators/resnet_unit.py
@@ -170,7 +170,7 @@ class ResNetUnit(Layer):
         self._is_test = is_test
 
         # check format
-        valid_format = {'NHWC'}
+        valid_format = {'NHWC', 'NCHW'}
         if data_format not in valid_format:
             raise ValueError(
                 "conv_format must be one of {}, but got conv_format='{}'".
@@ -181,11 +181,25 @@
             std = (2.0 / filter_elem_num)**0.5
             return I.Normal(0.0, std)
 
+        is_nchw = (data_format == 'NCHW')
         # initial filter
         bn_param_dtype = fluid.core.VarDesc.VarType.FP32
-        bn_param_shape = [1, 1, 1, num_filters]
-        filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x]
-        filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z]
+        if not is_nchw:
+            bn_param_shape = [1, 1, 1, num_filters]
+            filter_x_shape = [
+                num_filters, filter_size, filter_size, num_channels_x
+            ]
+            filter_z_shape = [
+                num_filters, filter_size, filter_size, num_channels_z
+            ]
+        else:
+            bn_param_shape = [1, num_filters, 1, 1]
+            filter_x_shape = [
+                num_filters, num_channels_x, filter_size, filter_size
+            ]
+            filter_z_shape = [
+                num_filters, num_channels_z, filter_size, filter_size
+            ]
 
         self.filter_x = self.create_parameter(
             shape=filter_x_shape,
--
GitLab
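
Usage note (a minimal sketch, not part of the patch): the Python layer this
change extends can now be constructed with data_format='NCHW'. The import path
and constructor arguments below follow python/paddle/incubate/operators/
resnet_unit.py as shown above; the device string, the tensor shapes, and
leaving the remaining constructor arguments at their defaults are assumptions.

    import paddle
    from paddle.incubate.operators.resnet_unit import ResNetUnit

    # The new XPU kernels are registered for FP32 only (see xpu2_op_list.h).
    paddle.set_device('xpu')

    # With this patch, filters for NCHW are laid out
    # [num_filters, num_channels, k, k] and BN params as [1, num_filters, 1, 1];
    # for NHWC they stay [num_filters, k, k, num_channels] and [1, 1, 1, num_filters].
    unit = ResNetUnit(num_channels_x=64,
                      num_filters=64,
                      filter_size=3,
                      data_format='NCHW')

    x = paddle.randn([8, 64, 56, 56], dtype='float32')  # NCHW input
    y = unit(x)
    # InferShape computes out_h = (in_h + 2 * padding - filter_size) / stride + 1,
    # so assuming the layer's default stride of 1 and same-style padding the
    # spatial dims are preserved.
    print(y.shape)  # expected: [8, 64, 56, 56]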