/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// Forward XPU kernel for the `resnet_unit` op.
//
// Collects the X-branch tensors (and, depending on `has_shortcut` /
// `fuse_add`, the Z-branch tensors) into per-branch lists and forwards them
// to a single xpu::resnet_unit_fusion call.
//
// NOTE(review): the template argument lists in this file (`<T>`, `<float>`,
// `<int>`, `<Tensor>`, `<platform::XPUDeviceContext>`,
// `<T, T, T, int16_t>`) were reconstructed from the declared types of the
// receiving variables and from XDNN calling conventions -- please confirm
// against the XDNN header before merging.
template <typename T>
class ResNetUnitXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();
    PADDLE_ENFORCE_EQ(
        platform::is_xpu_place(place), true,
        platform::errors::PreconditionNotMet("It must use XPUPlace."));

    bool is_nchw = (ctx.Attr<std::string>("data_format") == "NCHW");

    // input x
    const Tensor *input_x = ctx.Input<Tensor>("X");
    const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
    const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
    const Tensor *bias_x = ctx.Input<Tensor>("BiasX");

    // output x
    Tensor *conv_out_x = ctx.Output<Tensor>("ConvX");
    Tensor *saved_mean_x = ctx.Output<Tensor>("SavedMeanX");
    Tensor *saved_invstd_x = ctx.Output<Tensor>("SavedInvstdX");
    Tensor *running_mean_x = ctx.Output<Tensor>("RunningMeanX");
    Tensor *running_var_x = ctx.Output<Tensor>("RunningVarX");

    Tensor *output = ctx.Output<Tensor>("Y");

    // attrs
    int padding = ctx.Attr<int>("padding");
    int stride = ctx.Attr<int>("stride");
    int stride_z = ctx.Attr<int>("stride_z");
    int dilation = ctx.Attr<int>("dilation");
    int group = ctx.Attr<int>("group");
    float eps = ctx.Attr<float>("epsilon");
    float momentum = ctx.Attr<float>("momentum");
    bool has_shortcut = ctx.Attr<bool>("has_shortcut");
    bool fuse_add = ctx.Attr<bool>("fuse_add");
    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
    bool is_test = ctx.Attr<bool>("is_test");
    // Training mode only when neither inference nor global statistics are
    // requested.
    bool is_train = !is_test && !use_global_stats;
    // NOTE(review): act_type is read but not otherwise used in this kernel;
    // the activation handed to XDNN below is always RELU.
    std::string act_type = ctx.Attr<std::string>("act_type");
    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();

    // Per-branch argument lists; index 0 is the X branch.
    std::vector<const T *> x_list = {input_x->data<T>()};
    std::vector<const T *> w_list = {filter_x->data<T>()};
    std::vector<T *> conv_y_list = {conv_out_x->mutable_data<T>(place)};

    std::vector<std::vector<int>> x_shape_list = {
        phi::vectorize<int>(input_x->dims())};

    auto filter_x_shape = phi::vectorize<int>(filter_x->dims());
    // Kernel size is read from filter dims [2], [3] for NCHW and from
    // dims [1], [2] otherwise.
    std::vector<int> ksize = {filter_x_shape[2], filter_x_shape[3]};
    if (!is_nchw) {
      ksize[0] = filter_x_shape[1];
      ksize[1] = filter_x_shape[2];
    }
    std::vector<int> strides = {stride, stride};
    std::vector<std::vector<int>> ksize_list = {ksize};
    std::vector<std::vector<int>> stride_list = {strides};
    std::vector<int> paddings = {padding, padding};
    std::vector<int> dilations = {dilation, dilation};
    std::vector<const float *> scale_list = {scale_x->data<float>()};
    std::vector<const float *> bias_list = {bias_x->data<float>()};
    std::vector<float *> batch_mean_list = {
        saved_mean_x->mutable_data<float>(place)};
    std::vector<float *> batch_invstd_list = {
        saved_invstd_x->mutable_data<float>(place)};
    std::vector<float *> global_mean_list = {
        running_mean_x->mutable_data<float>(place)};
    std::vector<float *> global_var_list = {
        running_var_x->mutable_data<float>(place)};

    // No max buffers are supplied; one nullptr entry per input / filter.
    std::vector<const float *> x_maxlist = {nullptr};
    std::vector<const float *> w_maxlist = {nullptr};

    if (has_shortcut) {
      // input z
      const Tensor *input_z = ctx.Input<Tensor>("Z");
      const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
      const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
      const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");

      Tensor *conv_out_z = ctx.Output<Tensor>("ConvZ");
      Tensor *saved_mean_z = ctx.Output<Tensor>("SavedMeanZ");
      Tensor *saved_invstd_z = ctx.Output<Tensor>("SavedInvstdZ");
      Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
      Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");

      x_list.push_back(input_z->data<T>());
      w_list.push_back(filter_z->data<T>());
      conv_y_list.push_back(conv_out_z->mutable_data<T>(place));

      x_shape_list.push_back(phi::vectorize<int>(input_z->dims()));

      auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
      std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
      if (!is_nchw) {
        ksize_z[0] = filter_z_shape[1];
        ksize_z[1] = filter_z_shape[2];
      }
      ksize_list.push_back(ksize_z);
      stride_list.push_back({stride_z, stride_z});
      scale_list.push_back(scale_z->data<float>());
      bias_list.push_back(bias_z->data<float>());
      batch_mean_list.push_back(saved_mean_z->mutable_data<float>(place));
      batch_invstd_list.push_back(saved_invstd_z->mutable_data<float>(place));
      global_mean_list.push_back(running_mean_z->mutable_data<float>(place));
      global_var_list.push_back(running_var_z->mutable_data<float>(place));
      x_maxlist.push_back(nullptr);
      w_maxlist.push_back(nullptr);
    } else if (fuse_add) {
      // Without a shortcut convolution, Z is only appended as an extra
      // input (data, shape and max entry); no filter or BN lists grow.
      const Tensor *input_z = ctx.Input<Tensor>("Z");
      auto input_z_shape = phi::vectorize<int>(input_z->dims());
      x_list.push_back(input_z->data<T>());
      x_shape_list.push_back(input_z_shape);
      x_maxlist.push_back(nullptr);
    }

    int r = xpu::resnet_unit_fusion<T, T, T, int16_t>(
        dev_ctx.x_context(), x_list, w_list, conv_y_list,
        output->mutable_data<T>(place), x_shape_list, filter_x_shape[0],
        ksize_list, stride_list, paddings, dilations, group, eps, momentum,
        x_maxlist, w_maxlist, scale_list, bias_list, batch_mean_list,
        batch_invstd_list, global_mean_list, global_var_list,
        xpu::Activation_t::RELU, is_nchw, has_shortcut, fuse_add, is_train);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_fusion");
  }
};

// Backward XPU kernel for the `resnet_unit` op.
//
// Mirrors the forward kernel: gathers the saved forward tensors and the
// gradient outputs into per-branch lists and forwards them to a single
// xpu::resnet_unit_grad_fusion call.
//
// NOTE(review): template argument lists were reconstructed in the same way
// as for ResNetUnitXPUKernel -- please confirm against the XDNN header.
template <typename T>
class ResNetUnitGradXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto place = ctx.GetPlace();
    PADDLE_ENFORCE_EQ(
        platform::is_xpu_place(place), true,
        platform::errors::PreconditionNotMet("It must use XPUPlace."));

    bool is_nchw = (ctx.Attr<std::string>("data_format") == "NCHW");

    // Forward tensors and the incoming gradient of Y.
    const Tensor *y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const Tensor *x = ctx.Input<Tensor>("X");
    const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
    const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
    const Tensor *saved_mean_x = ctx.Input<Tensor>("SavedMeanX");
    const Tensor *saved_invstd_x = ctx.Input<Tensor>("SavedInvstdX");
    const Tensor *conv_out_x = ctx.Input<Tensor>("ConvX");
    const Tensor *output = ctx.Input<Tensor>("Y");

    // Gradient outputs for the X branch.
    Tensor *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    Tensor *filter_x_grad =
        ctx.Output<Tensor>(framework::GradVarName("FilterX"));
    Tensor *scale_x_grad = ctx.Output<Tensor>(framework::GradVarName("ScaleX"));
    Tensor *bias_x_grad = ctx.Output<Tensor>(framework::GradVarName("BiasX"));

    int padding = ctx.Attr<int>("padding");
    int stride = ctx.Attr<int>("stride");
    int stride_z = ctx.Attr<int>("stride_z");
    int dilation = ctx.Attr<int>("dilation");
    int group = ctx.Attr<int>("group");
    float eps = ctx.Attr<float>("epsilon");
    bool has_shortcut = ctx.Attr<bool>("has_shortcut");
    bool fuse_add = ctx.Attr<bool>("fuse_add");
    // NOTE(review): act_type is read but not otherwise used in this kernel;
    // the activation handed to XDNN below is always RELU.
    std::string act_type = ctx.Attr<std::string>("act_type");

    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();

    // Per-branch argument lists; index 0 is the X branch.
    std::vector<const T *> x_list = {x->data<T>()};
    std::vector<const T *> w_list = {filter_x->data<T>()};
    std::vector<const T *> conv_y_list = {conv_out_x->data<T>()};
    std::vector<T *> dx_list = {x_grad->mutable_data<T>(place)};
    std::vector<T *> dw_list = {filter_x_grad->mutable_data<T>(place)};

    std::vector<std::vector<int>> x_shape_list = {
        phi::vectorize<int>(x->dims())};

    auto filter_x_shape = phi::vectorize<int>(filter_x->dims());
    // Kernel size is read from filter dims [2], [3] for NCHW and from
    // dims [1], [2] otherwise.
    std::vector<int> x_ksize = {filter_x_shape[2], filter_x_shape[3]};
    if (!is_nchw) {
      x_ksize[0] = filter_x_shape[1];
      x_ksize[1] = filter_x_shape[2];
    }
    std::vector<std::vector<int>> ksize_list = {x_ksize};
    std::vector<std::vector<int>> stride_list = {{stride, stride}};
    std::vector<int> paddings = {padding, padding};
    std::vector<int> dilations = {dilation, dilation};

    // No max buffers are supplied; one nullptr entry per input / filter.
    std::vector<const float *> x_maxlist = {nullptr};
    std::vector<const float *> w_maxlist = {nullptr};

    std::vector<const float *> scale_list = {scale_x->data<float>()};
    std::vector<const float *> batch_mean_list = {saved_mean_x->data<float>()};
    std::vector<const float *> batch_invstd_list = {
        saved_invstd_x->data<float>()};
    std::vector<float *> dscale_list = {
        scale_x_grad->mutable_data<float>(place)};
    std::vector<float *> dbias_list = {bias_x_grad->mutable_data<float>(place)};

    if (has_shortcut) {
      //       X                   Z
      //       |                   |
      //    NormConv            NormConv
      //       |                   |
      // BNStatsFinalize    BNStatsFinalize
      //       \                   /
      //          ScaleBiasAddRelu
      //                  |
      //                  Y
      const Tensor *z = ctx.Input<Tensor>("Z");
      const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
      const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
      const Tensor *saved_mean_z = ctx.Input<Tensor>("SavedMeanZ");
      const Tensor *saved_invstd_z = ctx.Input<Tensor>("SavedInvstdZ");
      const Tensor *conv_out_z = ctx.Input<Tensor>("ConvZ");

      Tensor *z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
      Tensor *filter_z_grad =
          ctx.Output<Tensor>(framework::GradVarName("FilterZ"));
      Tensor *scale_z_grad =
          ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
      Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));

      x_list.push_back(z->data<T>());
      w_list.push_back(filter_z->data<T>());
      conv_y_list.push_back(conv_out_z->data<T>());
      dx_list.push_back(z_grad->mutable_data<T>(place));
      dw_list.push_back(filter_z_grad->mutable_data<T>(place));
      x_shape_list.push_back(phi::vectorize<int>(z->dims()));

      auto filter_z_shape = phi::vectorize<int>(filter_z->dims());
      std::vector<int> ksize_z = {filter_z_shape[2], filter_z_shape[3]};
      if (!is_nchw) {
        ksize_z[0] = filter_z_shape[1];
        ksize_z[1] = filter_z_shape[2];
      }
      ksize_list.push_back(ksize_z);
      stride_list.push_back({stride_z, stride_z});
      x_maxlist.push_back(nullptr);
      w_maxlist.push_back(nullptr);

      scale_list.push_back(scale_z->data<float>());
      batch_mean_list.push_back(saved_mean_z->data<float>());
      batch_invstd_list.push_back(saved_invstd_z->data<float>());
      dscale_list.push_back(scale_z_grad->mutable_data<float>(place));
      dbias_list.push_back(bias_z_grad->mutable_data<float>(place));
    } else if (fuse_add) {
      // Without a shortcut convolution only the gradient buffer for Z is
      // appended; the input, filter and BN lists keep a single entry.
      auto z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
      dx_list.push_back(z_grad->mutable_data<T>(place));
    }

    int r = xpu::resnet_unit_grad_fusion<T, T, T, int16_t>(
        dev_ctx.x_context(), x_list, w_list, y_grad->data<T>(),
        output->data<T>(), conv_y_list, dx_list, dw_list, x_shape_list,
        filter_x_shape[0], ksize_list, stride_list, paddings, dilations, group,
        x_maxlist, w_maxlist, scale_list, batch_mean_list, batch_invstd_list,
        dscale_list, dbias_list, xpu::Activation_t::RELU, eps, is_nchw,
        has_shortcut, fuse_add);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "resnet_unit_grad_fusion");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// NOTE(review): the registered element type (`float`) was reconstructed;
// confirm which dtypes this kernel is meant to be registered for.
REGISTER_OP_XPU_KERNEL(resnet_unit, ops::ResNetUnitXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_unit_grad, ops::ResNetUnitGradXPUKernel<float>);