diff --git a/paddle/fluid/operators/fused/resnet_basic_block_op.cc b/paddle/fluid/operators/fused/resnet_basic_block_op.cc index d54a889f93aa6a654c207563b4a47bc4d0cb9353..5990db8147be42f3588dfd76bebc5e8e53274591 100644 --- a/paddle/fluid/operators/fused/resnet_basic_block_op.cc +++ b/paddle/fluid/operators/fused/resnet_basic_block_op.cc @@ -258,24 +258,25 @@ class ResNetBasicBlockOp : public framework::OperatorWithKernel { class ResNetBasicBlockOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() { - // has_shortcut = True: X else: X - // / / - // | | | | - // CONV1 | CONV1 | - // | | | | - // BN1 | BN1 | - // | | | | - // RELU1 | RELU1 | - // | | | | - // CONV2 CONV3 CONV2 | - // | | | | - // BN2 BN3 BN2 | - // \ / \ / - // ADD ADD - // | | - // RELU RELU - // | | - // Y Y + // has_shortcut = True: else: + // X X + // / / + // | | | | + // CONV1 | CONV1 | + // | | | | + // BN1 | BN1 | + // | | | | + // RELU1 | RELU1 | + // | | | | + // CONV2 CONV3 CONV2 | + // | | | | + // BN2 BN3 BN2 | + // \ / \ / + // ADD ADD + // | | + // RELU RELU + // | | + // Y Y AddInput("X", "Input tensor of conv 1"); AddInput("Filter1", "Filter tensor of conv 1"); AddInput("Scale1", "Scale tensor of bn 1"); diff --git a/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc b/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..c7a6620c75f8e83b438c6d8a1813511eab6490bd --- /dev/null +++ b/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc @@ -0,0 +1,970 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef PADDLE_WITH_XPU
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/conv_op.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
+#include "paddle/fluid/platform/device/xpu/xpu_header.h"
+#include "paddle/phi/api/all.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Snapshot of the attributes and tensor shapes the forward kernel needs.
+//
+// Everything is read once from the execution context so that the kernel body
+// can work with plain ints / vectors. The conv3_* members are only filled in
+// when `has_shortcut` is true; they stay empty / uninitialized otherwise and
+// must not be read in that case.
+class ResnetBasicBlockAttr {
+ public:
+  explicit ResnetBasicBlockAttr(const framework::ExecutionContext& ctx) {
+    padding1 = ctx.Attr<int>("padding1");
+    padding2 = ctx.Attr<int>("padding2");
+    padding3 = ctx.Attr<int>("padding3");
+    stride1 = ctx.Attr<int>("stride1");
+    stride2 = ctx.Attr<int>("stride2");
+    stride3 = ctx.Attr<int>("stride3");
+    dilation1 = ctx.Attr<int>("dilation1");
+    dilation2 = ctx.Attr<int>("dilation2");
+    dilation3 = ctx.Attr<int>("dilation3");
+    group = ctx.Attr<int>("group");
+
+    eps = static_cast<double>(ctx.Attr<float>("epsilon"));
+    momentum = static_cast<double>(ctx.Attr<float>("momentum"));
+    has_shortcut = ctx.Attr<bool>("has_shortcut");
+    find_max = ctx.Attr<bool>("find_conv_input_max");
+
+    const auto is_test = ctx.Attr<bool>("is_test");
+    const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
+    const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
+    // Running (global) statistics are used either when explicitly requested
+    // or when running in test mode without trainable statistics.
+    bool test_mode = is_test && (!trainable_stats);
+    global_stats = test_mode || use_global_stats;
+
+    // init shape
+    auto input1 = ctx.Input<Tensor>("X");
+    auto filter1 = ctx.Input<Tensor>("Filter1");
+    auto conv1_out = ctx.Output<Tensor>("Conv1");
+    auto filter2 = ctx.Input<Tensor>("Filter2");
+    auto conv2_out = ctx.Output<Tensor>("Conv2");
+    conv1_input_shape = phi::vectorize<int>(input1->dims());
+    conv1_output_shape = phi::vectorize<int>(conv1_out->dims());
+    conv1_filter_shape = phi::vectorize<int>(filter1->dims());
+    conv1_filter_numel = filter1->numel();
+    conv1_input_numel = input1->numel();
+    conv1_output_numel = conv1_out->numel();
+
+    // conv2 consumes the (bn + relu'ed) output of conv1, so its input shape
+    // is the output shape of conv1.
+    conv2_input_shape = phi::vectorize<int>(conv1_out->dims());
+    conv2_output_shape = phi::vectorize<int>(conv2_out->dims());
+    conv2_filter_shape = phi::vectorize<int>(filter2->dims());
+    conv2_filter_numel = filter2->numel();
+    conv2_input_numel = conv1_out->numel();
+    conv2_output_numel = conv2_out->numel();
+
+    if (has_shortcut) {
+      // conv3 is the shortcut branch; it reads the block input X directly.
+      auto filter3 = ctx.Input<Tensor>("Filter3");
+      auto conv3_out = ctx.Output<Tensor>("Conv3");
+      conv3_input_shape = phi::vectorize<int>(input1->dims());
+      conv3_output_shape = phi::vectorize<int>(conv3_out->dims());
+      conv3_filter_shape = phi::vectorize<int>(filter3->dims());
+      conv3_filter_numel = filter3->numel();
+      conv3_input_numel = input1->numel();
+      conv3_output_numel = conv3_out->numel();
+    }
+  }
+
+  int padding1;
+  int padding2;
+  int padding3;
+  int stride1;
+  int stride2;
+  int stride3;
+  int dilation1;
+  int dilation2;
+  int dilation3;
+  int group;
+
+  double eps;
+  double momentum;
+
+  bool has_shortcut;
+  bool find_max;
+  bool global_stats;
+
+  std::vector<int> conv1_input_shape;
+  std::vector<int> conv1_output_shape;
+  std::vector<int> conv1_filter_shape;
+  std::vector<int> conv2_input_shape;
+  std::vector<int> conv2_output_shape;
+  std::vector<int> conv2_filter_shape;
+  std::vector<int> conv3_input_shape;
+  std::vector<int> conv3_output_shape;
+  std::vector<int> conv3_filter_shape;
+
+  // NOTE(review): numel() is narrowed to int here; presumably tensors in this
+  // block never exceed INT_MAX elements - confirm.
+  int conv1_filter_numel;
+  int conv2_filter_numel;
+  int conv3_filter_numel;
+  int conv1_input_numel;
+  int conv2_input_numel;
+  int conv3_input_numel;
+  int conv1_output_numel;
+  int conv2_output_numel;
+  int conv3_output_numel;
+};
+
+// Snapshot of the attributes and tensor shapes the backward kernel needs.
+//
+// Mirrors ResnetBasicBlockAttr, except that Conv1/Conv2/Conv3 are inputs of
+// the grad op and no batch-norm hyper-parameters are read. The conv3_*
+// members are only valid when `has_shortcut` is true.
+class ResnetBasicBlockGradAttr {
+ public:
+  explicit ResnetBasicBlockGradAttr(const framework::ExecutionContext& ctx) {
+    padding1 = ctx.Attr<int>("padding1");
+    padding2 = ctx.Attr<int>("padding2");
+    padding3 = ctx.Attr<int>("padding3");
+    stride1 = ctx.Attr<int>("stride1");
+    stride2 = ctx.Attr<int>("stride2");
+    stride3 = ctx.Attr<int>("stride3");
+    dilation1 = ctx.Attr<int>("dilation1");
+    dilation2 = ctx.Attr<int>("dilation2");
+    dilation3 = ctx.Attr<int>("dilation3");
+    group = ctx.Attr<int>("group");
+
+    has_shortcut = ctx.Attr<bool>("has_shortcut");
+    find_max = ctx.Attr<bool>("find_conv_input_max");
+
+    // init shape
+    auto input1 = ctx.Input<Tensor>("X");
+    auto filter1 = ctx.Input<Tensor>("Filter1");
+    auto conv1_out = ctx.Input<Tensor>("Conv1");
+    auto filter2 = ctx.Input<Tensor>("Filter2");
+    auto conv2_out = ctx.Input<Tensor>("Conv2");
+    conv1_input_shape = phi::vectorize<int>(input1->dims());
+    conv1_output_shape = phi::vectorize<int>(conv1_out->dims());
+    conv1_filter_shape = phi::vectorize<int>(filter1->dims());
+    conv1_filter_numel = filter1->numel();
+    conv1_input_numel = input1->numel();
+    conv1_output_numel = conv1_out->numel();
+
+    conv2_input_shape = phi::vectorize<int>(conv1_out->dims());
+    conv2_output_shape = phi::vectorize<int>(conv2_out->dims());
+    conv2_filter_shape = phi::vectorize<int>(filter2->dims());
+    conv2_filter_numel = filter2->numel();
+    conv2_input_numel = conv1_out->numel();
+    conv2_output_numel = conv2_out->numel();
+
+    if (has_shortcut) {
+      auto filter3 = ctx.Input<Tensor>("Filter3");
+      auto conv3_out = ctx.Input<Tensor>("Conv3");
+      conv3_input_shape = phi::vectorize<int>(input1->dims());
+      conv3_output_shape = phi::vectorize<int>(conv3_out->dims());
+      conv3_filter_shape = phi::vectorize<int>(filter3->dims());
+      conv3_filter_numel = filter3->numel();
+      conv3_input_numel = input1->numel();
+      conv3_output_numel = conv3_out->numel();
+    }
+  }
+
+  int padding1;
+  int padding2;
+  int padding3;
+  int stride1;
+  int stride2;
+  int stride3;
+  int dilation1;
+  int dilation2;
+  int dilation3;
+  int group;
+
+  bool has_shortcut;
+  bool find_max;
+
+  std::vector<int> conv1_input_shape;
+  std::vector<int> conv1_output_shape;
+  std::vector<int> conv1_filter_shape;
+  std::vector<int> conv2_input_shape;
+  std::vector<int> conv2_output_shape;
+  std::vector<int> conv2_filter_shape;
+  std::vector<int> conv3_input_shape;
+  std::vector<int> conv3_output_shape;
+  std::vector<int> conv3_filter_shape;
+
+  int conv1_filter_numel;
+  int conv2_filter_numel;
+  int conv3_filter_numel;
+  int conv1_input_numel;
+  int conv2_input_numel;
+  int conv3_input_numel;
+  int conv1_output_numel;
+  int conv2_output_numel;
+  int conv3_output_numel;
+};
+
+// Thin wrapper over xpu::conv2d for square padding / stride / dilation.
+//
+// `input_shape` is indexed as [0..3] -> N, C, H, W and `filter_shape` as
+// [0] -> output channels, [2], [3] -> kernel height / width. The max pointers
+// may be nullptr. Raises through PADDLE_ENFORCE_XDNN_SUCCESS on failure.
+template <typename T>
+static inline void xpu_conv2d(xpu::Context* ctx,
+                              const T* input_data,
+                              const T* filter_data,
+                              T* output_data,
+                              float* input_max_data,
+                              float* filter_max_data,
+                              const std::vector<int>& input_shape,
+                              const std::vector<int>& filter_shape,
+                              int padding,
+                              int stride,
+                              int dilation,
+                              int group) {
+  std::vector<int> ksize{filter_shape[2], filter_shape[3]};
+  std::vector<int> stride_vec{stride, stride};
+  std::vector<int> dilation_vec{dilation, dilation};
+  std::vector<int> padding_vec{padding, padding};
+  int N = input_shape[0];
+  int C = input_shape[1];
+  int H = input_shape[2];
+  int W = input_shape[3];
+
+  int r = xpu::conv2d<T, T, T, int16_t>(ctx,
+                                        input_data,
+                                        filter_data,
+                                        output_data,
+                                        N,
+                                        C,
+                                        H,
+                                        W,
+                                        filter_shape[0],
+                                        ksize,
+                                        stride_vec,
+                                        padding_vec,
+                                        dilation_vec,
+                                        group,
+                                        input_max_data,
+                                        filter_max_data,
+                                        nullptr,
+                                        true);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d");
+}
+
+// Thin wrapper over xpu::conv2d_grad; same shape conventions as xpu_conv2d.
+//
+// Writes the input gradient to `input_grad_data` and the filter gradient to
+// `filter_grad_data`. Raises through PADDLE_ENFORCE_XDNN_SUCCESS on failure.
+template <typename T>
+static inline void xpu_conv2d_grad(xpu::Context* ctx,
+                                   const T* input_data,
+                                   const T* filter_data,
+                                   const T* output_grad_data,
+                                   T* input_grad_data,
+                                   T* filter_grad_data,
+                                   const float* input_max_data,
+                                   const float* filter_max_data,
+                                   const std::vector<int>& input_shape,
+                                   const std::vector<int>& filter_shape,
+                                   int padding,
+                                   int stride,
+                                   int dilation,
+                                   int group) {
+  std::vector<int> ksize{filter_shape[2], filter_shape[3]};
+  std::vector<int> stride_vec{stride, stride};
+  std::vector<int> dilation_vec{dilation, dilation};
+  std::vector<int> padding_vec{padding, padding};
+  int N = input_shape[0];
+  int C = input_shape[1];
+  int H = input_shape[2];
+  int W = input_shape[3];
+
+  int r = xpu::conv2d_grad<T, T, T, int16_t>(ctx,
+                                             input_data,
+                                             filter_data,
+                                             output_grad_data,
+                                             input_grad_data,
+                                             filter_grad_data,
+                                             N,
+                                             C,
+                                             H,
+                                             W,
+                                             filter_shape[0],
+                                             ksize,
+                                             stride_vec,
+                                             padding_vec,
+                                             dilation_vec,
+                                             group,
+                                             input_max_data,
+                                             filter_max_data,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             true);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad");
+}
+
+// Forward kernel of the fused ResNet basic block:
+//
+//   Y = relu(bn2(conv2(relu(bn1(conv1(X))))) + Z)
+//
+// where Z = bn3(conv3(X)) when `has_shortcut` is set and Z = X otherwise.
+template <typename T>
+class ResNetBasicBlockXPUKernel : public framework::OpKernel<T> {
+ public:
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_xpu_place(ctx.GetPlace()),
+        true,
+        platform::errors::PreconditionNotMet("It must use XPUPlace."));
+
+    // input
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* filter1 = ctx.Input<Tensor>("Filter1");
+    const Tensor* scale1 = ctx.Input<Tensor>("Scale1");
+    const Tensor* bias1 = ctx.Input<Tensor>("Bias1");
+    const Tensor* filter2 = ctx.Input<Tensor>("Filter2");
+    const Tensor* scale2 = ctx.Input<Tensor>("Scale2");
+    const Tensor* bias2 = ctx.Input<Tensor>("Bias2");
+
+    // output
+    Tensor* conv1_output = ctx.Output<Tensor>("Conv1");
+    Tensor* conv2_output = ctx.Output<Tensor>("Conv2");
+    Tensor* conv2_input = ctx.Output<Tensor>("Conv2Input");
+    Tensor* output = ctx.Output<Tensor>("Y");
+
+    auto place = ctx.GetPlace();
+    auto x_data = reinterpret_cast<const XPUT*>(x->data<T>());
+    auto conv1_filter_data = reinterpret_cast<const XPUT*>(filter1->data<T>());
+    auto conv2_filter_data = reinterpret_cast<const XPUT*>(filter2->data<T>());
+    auto conv1_output_data =
+        reinterpret_cast<XPUT*>(conv1_output->mutable_data<T>(place));
+    auto conv2_input_data =
+        reinterpret_cast<XPUT*>(conv2_input->mutable_data<T>(place));
+    auto conv2_output_data =
+        reinterpret_cast<XPUT*>(conv2_output->mutable_data<T>(place));
+    auto scale1_data = scale1->data<float>();
+    auto scale2_data = scale2->data<float>();
+    auto bias1_data = bias1->data<float>();
+    auto bias2_data = bias2->data<float>();
+    auto output_data = reinterpret_cast<XPUT*>(output->mutable_data<T>(place));
+
+    // Max buffers stay nullptr unless find_conv_input_max is enabled.
+    float* conv1_input_max_data = nullptr;
+    float* conv1_filter_max_data = nullptr;
+    float* conv2_input_max_data = nullptr;
+    float* conv2_filter_max_data = nullptr;
+    float* conv3_input_max_data = nullptr;
+    float* conv3_filter_max_data = nullptr;
+
+    ResnetBasicBlockAttr attr(ctx);
+
+    // init find max
+    if (attr.find_max) {
+      Tensor* max_input1 = ctx.Output<Tensor>("MaxInput1");
+      Tensor* max_filter1 = ctx.Output<Tensor>("MaxFilter1");
+      conv1_input_max_data = max_input1->mutable_data<float>(place);
+      conv1_filter_max_data = max_filter1->mutable_data<float>(place);
+
+      Tensor* max_input2 = ctx.Output<Tensor>("MaxInput2");
+      Tensor* max_filter2 = ctx.Output<Tensor>("MaxFilter2");
+      conv2_input_max_data = max_input2->mutable_data<float>(place);
+      conv2_filter_max_data = max_filter2->mutable_data<float>(place);
+
+      if (attr.has_shortcut) {
+        Tensor* max_input3 = ctx.Output<Tensor>("MaxInput3");
+        Tensor* max_filter3 = ctx.Output<Tensor>("MaxFilter3");
+        conv3_input_max_data = max_input3->mutable_data<float>(place);
+        conv3_filter_max_data = max_filter3->mutable_data<float>(place);
+      }
+    }
+
+    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    int r = XPU_SUCCESS;
+
+    // 1. short
+    const XPUT* z_out_data = nullptr;
+    if (attr.has_shortcut) {
+      Tensor* conv3_out = ctx.Output<Tensor>("Conv3");
+      const Tensor* filter3 = ctx.Input<Tensor>("Filter3");
+      auto conv3_filter_data =
+          reinterpret_cast<const XPUT*>(filter3->data<T>());
+      auto conv3_output_data =
+          reinterpret_cast<XPUT*>(conv3_out->mutable_data<T>(place));
+
+      XPUT* conv3_input_l3_data = nullptr;
+      // The L3 copy of the filter is only filled by findmax_copy_fusion
+      // below, so only request it when find_max is on.
+      XPUT* conv3_filter_l3_data =
+          attr.find_max ? RAII_GUARD.alloc_l3<XPUT>(attr.conv3_filter_numel)
+                        : nullptr;
+
+      if (attr.find_max) {
+        r = xpu::findmax_copy_fusion(dev_ctx.x_context(),
+                                     x_data,
+                                     conv3_input_max_data,
+                                     conv3_input_l3_data,
+                                     attr.conv3_input_numel);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+
+        r = xpu::findmax_copy_fusion(dev_ctx.x_context(),
+                                     conv3_filter_data,
+                                     conv3_filter_max_data,
+                                     conv3_filter_l3_data,
+                                     attr.conv3_filter_numel);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+      }
+
+      // Fall back to the original buffers whenever no L3 copy exists
+      // (find_max disabled, or the L3 allocation returned nullptr).
+      xpu_conv2d(dev_ctx.x_context(),
+                 conv3_input_l3_data != nullptr ? conv3_input_l3_data : x_data,
+                 conv3_filter_l3_data != nullptr ? conv3_filter_l3_data
+                                                 : conv3_filter_data,
+                 conv3_output_data,
+                 conv3_input_max_data,
+                 conv3_filter_max_data,
+                 attr.conv3_input_shape,
+                 attr.conv3_filter_shape,
+                 attr.padding3,
+                 attr.stride3,
+                 attr.dilation3,
+                 attr.group);
+
+      // bn3
+      const Tensor* scale3 = ctx.Input<Tensor>("Scale3");
+      const Tensor* bias3 = ctx.Input<Tensor>("Bias3");
+      auto bias3_data = bias3->data<float>();
+      auto scale3_data = scale3->data<float>();
+
+      auto bn3_output_data = RAII_GUARD.alloc<XPUT>(attr.conv3_output_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(bn3_output_data);
+
+      if (!attr.global_stats) {
+        Tensor* saved_mean3 = ctx.Output<Tensor>("SavedMean3");
+        Tensor* saved_invstd3 = ctx.Output<Tensor>("SavedInvstd3");
+        Tensor* running_mean3 = ctx.Output<Tensor>("Mean3Out");
+        Tensor* running_var3 = ctx.Output<Tensor>("Var3Out");
+
+        auto saved_mean3_data = saved_mean3->mutable_data<float>(place);
+        auto saved_invstd3_data = saved_invstd3->mutable_data<float>(place);
+        auto running_mean3_data = running_mean3->mutable_data<float>(place);
+        auto running_var3_data = running_var3->mutable_data<float>(place);
+
+        // Dimensions are passed as N, C, H, W, i.e. shape[0..3]; H is
+        // shape[2] exactly as in the bn1 / bn2 calls below.
+        r = xpu::batch_norm_fusion<XPUT>(dev_ctx.x_context(),
+                                         conv3_output_data,
+                                         bn3_output_data,
+                                         attr.conv3_output_shape[0],
+                                         attr.conv3_output_shape[1],
+                                         attr.conv3_output_shape[2],
+                                         attr.conv3_output_shape[3],
+                                         attr.eps,
+                                         attr.momentum,
+                                         scale3_data,
+                                         bias3_data,
+                                         saved_mean3_data,
+                                         saved_invstd3_data,
+                                         running_mean3_data,
+                                         running_var3_data,
+                                         true,
+                                         nullptr,
+                                         xpu::Activation_t::LINEAR,
+                                         nullptr,
+                                         0);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_fusion");
+      } else {
+        const auto* mean3 = ctx.Input<Tensor>("Mean3");
+        const auto* var3 = ctx.Input<Tensor>("Var3");
+        const auto* mean3_data = mean3->data<float>();
+        const auto* variance3_data = var3->data<float>();
+        r = xpu::batch_norm_infer<XPUT>(dev_ctx.x_context(),
+                                        conv3_output_data,
+                                        bn3_output_data,
+                                        attr.conv3_output_shape[0],
+                                        attr.conv3_output_shape[1],
+                                        attr.conv3_output_shape[2],
+                                        attr.conv3_output_shape[3],
+                                        attr.eps,
+                                        scale3_data,
+                                        bias3_data,
+                                        mean3_data,
+                                        variance3_data,
+                                        true);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer");
+      }
+      z_out_data = reinterpret_cast<const XPUT*>(bn3_output_data);
+    } else {
+      // Identity shortcut: the residual added before the final relu is X.
+      z_out_data = x_data;
+    }
+
+    // 2. conv1
+    XPUT* conv1_input_l3_data = nullptr;
+    XPUT* conv1_filter_l3_data =
+        attr.find_max ? RAII_GUARD.alloc_l3<XPUT>(attr.conv1_filter_numel)
+                      : nullptr;
+    if (attr.find_max) {
+      r = xpu::findmax_copy_fusion(dev_ctx.x_context(),
+                                   x_data,
+                                   conv1_input_max_data,
+                                   conv1_input_l3_data,
+                                   attr.conv1_input_numel);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+
+      r = xpu::findmax_copy_fusion(dev_ctx.x_context(),
+                                   conv1_filter_data,
+                                   conv1_filter_max_data,
+                                   conv1_filter_l3_data,
+                                   attr.conv1_filter_numel);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+    }
+    xpu_conv2d(dev_ctx.x_context(),
+               conv1_input_l3_data != nullptr ? conv1_input_l3_data : x_data,
+               conv1_filter_l3_data != nullptr ? conv1_filter_l3_data
+                                               : conv1_filter_data,
+               conv1_output_data,
+               conv1_input_max_data,
+               conv1_filter_max_data,
+               attr.conv1_input_shape,
+               attr.conv1_filter_shape,
+               attr.padding1,
+               attr.stride1,
+               attr.dilation1,
+               attr.group);
+
+    // 3. bn1 + relu
+    if (!attr.global_stats) {
+      Tensor* saved_mean1 = ctx.Output<Tensor>("SavedMean1");
+      Tensor* saved_invstd1 = ctx.Output<Tensor>("SavedInvstd1");
+      Tensor* running_mean1 = ctx.Output<Tensor>("Mean1Out");
+      Tensor* running_var1 = ctx.Output<Tensor>("Var1Out");
+
+      auto saved_mean1_data = saved_mean1->mutable_data<float>(place);
+      auto saved_invstd1_data = saved_invstd1->mutable_data<float>(place);
+      auto running_mean1_data = running_mean1->mutable_data<float>(place);
+      auto running_var1_data = running_var1->mutable_data<float>(place);
+
+      r = xpu::batch_norm_fusion<XPUT>(dev_ctx.x_context(),
+                                       conv1_output_data,
+                                       conv2_input_data,
+                                       attr.conv1_output_shape[0],
+                                       attr.conv1_output_shape[1],
+                                       attr.conv1_output_shape[2],
+                                       attr.conv1_output_shape[3],
+                                       attr.eps,
+                                       attr.momentum,
+                                       scale1_data,
+                                       bias1_data,
+                                       saved_mean1_data,
+                                       saved_invstd1_data,
+                                       running_mean1_data,
+                                       running_var1_data,
+                                       true,
+                                       nullptr,
+                                       xpu::Activation_t::RELU,
+                                       nullptr,
+                                       0);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_fusion");
+    } else {
+      // bn --> relu
+      auto bn1_output_data = RAII_GUARD.alloc<XPUT>(attr.conv1_output_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(bn1_output_data);
+
+      const auto* mean1 = ctx.Input<Tensor>("Mean1");
+      const auto* var1 = ctx.Input<Tensor>("Var1");
+      const auto* mean_data = mean1->data<float>();
+      const auto* variance_data = var1->data<float>();
+      r = xpu::batch_norm_infer<XPUT>(dev_ctx.x_context(),
+                                      conv1_output_data,
+                                      bn1_output_data,
+                                      attr.conv1_output_shape[0],
+                                      attr.conv1_output_shape[1],
+                                      attr.conv1_output_shape[2],
+                                      attr.conv1_output_shape[3],
+                                      attr.eps,
+                                      scale1_data,
+                                      bias1_data,
+                                      mean_data,
+                                      variance_data,
+                                      true);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer");
+
+      r = xpu::relu(dev_ctx.x_context(),
+                    bn1_output_data,
+                    conv2_input_data,
+                    attr.conv1_output_numel);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
+    }
+
+    // 4. conv2
+    // conv2_*_max_data were already bound to MaxInput2 / MaxFilter2 in the
+    // "init find max" step above.
+    XPUT* conv2_input_l3_data = nullptr;
+    XPUT* conv2_filter_l3_data =
+        attr.find_max ? RAII_GUARD.alloc_l3<XPUT>(attr.conv2_filter_numel)
+                      : nullptr;
+    if (attr.find_max) {
+      r = xpu::findmax_copy_fusion(dev_ctx.x_context(),
+                                   conv2_input_data,
+                                   conv2_input_max_data,
+                                   conv2_input_l3_data,
+                                   attr.conv2_input_numel);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+
+      r = xpu::findmax_copy_fusion(dev_ctx.x_context(),
+                                   conv2_filter_data,
+                                   conv2_filter_max_data,
+                                   conv2_filter_l3_data,
+                                   attr.conv2_filter_numel);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "findmax_copy_fusion");
+    }
+    xpu_conv2d(
+        dev_ctx.x_context(),
+        conv2_input_l3_data != nullptr ? conv2_input_l3_data : conv2_input_data,
+        conv2_filter_l3_data != nullptr ? conv2_filter_l3_data
+                                        : conv2_filter_data,
+        conv2_output_data,
+        conv2_input_max_data,
+        conv2_filter_max_data,
+        attr.conv2_input_shape,
+        attr.conv2_filter_shape,
+        attr.padding2,
+        attr.stride2,
+        attr.dilation2,
+        attr.group);
+
+    // 5. bn2 (+ residual add + relu)
+    if (!attr.global_stats) {
+      Tensor* saved_mean2 = ctx.Output<Tensor>("SavedMean2");
+      Tensor* saved_invstd2 = ctx.Output<Tensor>("SavedInvstd2");
+      Tensor* running_mean2 = ctx.Output<Tensor>("Mean2Out");
+      Tensor* running_var2 = ctx.Output<Tensor>("Var2Out");
+
+      auto saved_mean2_data = saved_mean2->mutable_data<float>(place);
+      auto saved_invstd2_data = saved_invstd2->mutable_data<float>(place);
+      auto running_mean2_data = running_mean2->mutable_data<float>(place);
+      auto running_var2_data = running_var2->mutable_data<float>(place);
+
+      // z_out_data (the shortcut branch) is handed to the fused call
+      // together with the RELU activation.
+      r = xpu::batch_norm_fusion<XPUT>(dev_ctx.x_context(),
+                                       conv2_output_data,
+                                       output_data,
+                                       attr.conv2_output_shape[0],
+                                       attr.conv2_output_shape[1],
+                                       attr.conv2_output_shape[2],
+                                       attr.conv2_output_shape[3],
+                                       attr.eps,
+                                       attr.momentum,
+                                       scale2_data,
+                                       bias2_data,
+                                       saved_mean2_data,
+                                       saved_invstd2_data,
+                                       running_mean2_data,
+                                       running_var2_data,
+                                       true,
+                                       z_out_data,
+                                       xpu::Activation_t::RELU,
+                                       nullptr,
+                                       0);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_fusion");
+    } else {
+      auto bn2_out_data = RAII_GUARD.alloc<XPUT>(attr.conv2_output_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(bn2_out_data);
+
+      const auto* mean2 = ctx.Input<Tensor>("Mean2");
+      const auto* var2 = ctx.Input<Tensor>("Var2");
+      const auto* mean_data = mean2->data<float>();
+      const auto* variance_data = var2->data<float>();
+      r = xpu::batch_norm_infer<XPUT>(dev_ctx.x_context(),
+                                      conv2_output_data,
+                                      bn2_out_data,
+                                      attr.conv2_output_shape[0],
+                                      attr.conv2_output_shape[1],
+                                      attr.conv2_output_shape[2],
+                                      attr.conv2_output_shape[3],
+                                      attr.eps,
+                                      scale2_data,
+                                      bias2_data,
+                                      mean_data,
+                                      variance_data,
+                                      true);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer");
+
+      r = xpu::add_activation_fusion(dev_ctx.x_context(),
+                                     bn2_out_data,
+                                     z_out_data,
+                                     output_data,
+                                     output->numel(),
+                                     nullptr,
+                                     nullptr,
+                                     nullptr,
+                                     xpu::Activation_t::RELU);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_activation_fusion");
+    }
+  }
+};
+
+// Backward kernel of the fused ResNet basic block.
+//
+// Walks the forward graph in reverse: bn2 (+relu, +residual) -> optional
+// bn3 / conv3 shortcut -> conv2 -> bn1 (+relu) -> conv1, and finally adds the
+// shortcut gradient into the gradient of X.
+template <typename T>
+class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> {
+ public:
+  using XPUT = typename XPUTypeTrait<T>::Type;
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_xpu_place(ctx.GetPlace()),
+        true,
+        platform::errors::PreconditionNotMet("It must use XPUPlace."));
+
+    const Tensor* y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const Tensor* y = ctx.Input<Tensor>("Y");
+
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* filter1 = ctx.Input<Tensor>("Filter1");
+    const Tensor* scale1 = ctx.Input<Tensor>("Scale1");
+    const Tensor* filter2 = ctx.Input<Tensor>("Filter2");
+    const Tensor* scale2 = ctx.Input<Tensor>("Scale2");
+    const Tensor* saved_mean1 = ctx.Input<Tensor>("SavedMean1");
+    const Tensor* saved_invstd1 = ctx.Input<Tensor>("SavedInvstd1");
+    const Tensor* saved_mean2 = ctx.Input<Tensor>("SavedMean2");
+    const Tensor* saved_invstd2 = ctx.Input<Tensor>("SavedInvstd2");
+    const Tensor* conv1_out = ctx.Input<Tensor>("Conv1");
+    const Tensor* conv2_out = ctx.Input<Tensor>("Conv2");
+    const Tensor* conv2_input = ctx.Input<Tensor>("Conv2Input");
+
+    // Shortcut-branch tensors; only dereferenced when has_shortcut is set.
+    const Tensor* filter3 = ctx.Input<Tensor>("Filter3");
+    const Tensor* conv3_out = ctx.Input<Tensor>("Conv3");
+    const Tensor* scale3 = ctx.Input<Tensor>("Scale3");
+    const Tensor* saved_mean3 = ctx.Input<Tensor>("SavedMean3");
+    const Tensor* saved_invstd3 = ctx.Input<Tensor>("SavedInvstd3");
+
+    // Max tensors; only dereferenced when find_max is set.
+    const Tensor* conv1_input_max = ctx.Input<Tensor>("MaxInput1");
+    const Tensor* conv1_filter_max = ctx.Input<Tensor>("MaxFilter1");
+    const Tensor* conv2_input_max = ctx.Input<Tensor>("MaxInput2");
+    const Tensor* conv2_filter_max = ctx.Input<Tensor>("MaxFilter2");
+    const Tensor* conv3_input_max = ctx.Input<Tensor>("MaxInput3");
+    const Tensor* conv3_filter_max = ctx.Input<Tensor>("MaxFilter3");
+
+    Tensor* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    Tensor* filter1_grad =
+        ctx.Output<Tensor>(framework::GradVarName("Filter1"));
+    Tensor* scale1_grad = ctx.Output<Tensor>(framework::GradVarName("Scale1"));
+    Tensor* bias1_grad = ctx.Output<Tensor>(framework::GradVarName("Bias1"));
+    Tensor* filter2_grad =
+        ctx.Output<Tensor>(framework::GradVarName("Filter2"));
+    Tensor* scale2_grad = ctx.Output<Tensor>(framework::GradVarName("Scale2"));
+    Tensor* bias2_grad = ctx.Output<Tensor>(framework::GradVarName("Bias2"));
+    Tensor* filter3_grad =
+        ctx.Output<Tensor>(framework::GradVarName("Filter3"));
+    Tensor* scale3_grad = ctx.Output<Tensor>(framework::GradVarName("Scale3"));
+    Tensor* bias3_grad = ctx.Output<Tensor>(framework::GradVarName("Bias3"));
+
+    // attrs
+    ResnetBasicBlockGradAttr attr(ctx);
+    auto place = ctx.GetPlace();
+
+    const auto* y_grad_data = reinterpret_cast<const XPUT*>(y_grad->data<T>());
+    const auto* y_data = reinterpret_cast<const XPUT*>(y->data<T>());
+    const auto* x_data = reinterpret_cast<const XPUT*>(x->data<T>());
+    const auto* conv1_output_data =
+        reinterpret_cast<const XPUT*>(conv1_out->data<T>());
+    const auto* conv1_filter_data =
+        reinterpret_cast<const XPUT*>(filter1->data<T>());
+    const auto* conv2_input_data =
+        reinterpret_cast<const XPUT*>(conv2_input->data<T>());
+    const auto* conv2_output_data =
+        reinterpret_cast<const XPUT*>(conv2_out->data<T>());
+    const auto* conv2_filter_data =
+        reinterpret_cast<const XPUT*>(filter2->data<T>());
+
+    const auto* scale2_data = scale2->data<float>();
+    const auto* saved_mean2_data = saved_mean2->data<float>();
+    const auto* saved_invstd2_data = saved_invstd2->data<float>();
+    const auto* scale1_data = scale1->data<float>();
+    const auto* saved_mean1_data = saved_mean1->data<float>();
+    const auto* saved_invstd1_data = saved_invstd1->data<float>();
+    auto* scale2_grad_data = scale2_grad->mutable_data<float>(place);
+    auto* bias2_grad_data = bias2_grad->mutable_data<float>(place);
+
+    const float* conv1_input_max_data = nullptr;
+    const float* conv1_filter_max_data = nullptr;
+    const float* conv2_input_max_data = nullptr;
+    const float* conv2_filter_max_data = nullptr;
+    const float* conv3_input_max_data = nullptr;
+    const float* conv3_filter_max_data = nullptr;
+    if (attr.find_max) {
+      conv1_input_max_data =
+          reinterpret_cast<const float*>(conv1_input_max->data<float>());
+      conv1_filter_max_data =
+          reinterpret_cast<const float*>(conv1_filter_max->data<float>());
+      conv2_input_max_data =
+          reinterpret_cast<const float*>(conv2_input_max->data<float>());
+      conv2_filter_max_data =
+          reinterpret_cast<const float*>(conv2_filter_max->data<float>());
+      if (attr.has_shortcut) {
+        conv3_input_max_data =
+            reinterpret_cast<const float*>(conv3_input_max->data<float>());
+        conv3_filter_max_data =
+            reinterpret_cast<const float*>(conv3_filter_max->data<float>());
+      }
+    }
+
+    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    int r = XPU_SUCCESS;
+
+    // 0. bn2, bn2_fusion grad
+    auto conv2_output_grad_data =
+        RAII_GUARD.alloc<XPUT>(attr.conv2_output_numel);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(conv2_output_grad_data);
+
+    // z_output_grad_data: gradient w.r.t. the shortcut branch output.
+    // z_grad_data:        gradient w.r.t. X contributed by the shortcut.
+    // Without conv3/bn3 the shortcut is the identity, so both are the same
+    // buffer.
+    XPUT* z_output_grad_data = nullptr;
+    XPUT* z_grad_data = nullptr;
+    if (!attr.has_shortcut) {
+      z_output_grad_data = RAII_GUARD.alloc<XPUT>(attr.conv1_input_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(z_output_grad_data);
+      z_grad_data = z_output_grad_data;
+    } else {
+      z_output_grad_data = RAII_GUARD.alloc<XPUT>(attr.conv3_output_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(z_output_grad_data);
+
+      z_grad_data = RAII_GUARD.alloc<XPUT>(attr.conv1_input_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(z_grad_data);
+    }
+
+    r = xpu::batch_norm_grad_fusion<XPUT>(dev_ctx.x_context(),
+                                          conv2_output_data,
+                                          y_data,
+                                          y_grad_data,
+                                          conv2_output_grad_data,
+                                          attr.conv2_output_shape[0],
+                                          attr.conv2_output_shape[1],
+                                          attr.conv2_output_shape[2],
+                                          attr.conv2_output_shape[3],
+                                          scale2_data,
+                                          saved_mean2_data,
+                                          saved_invstd2_data,
+                                          scale2_grad_data,
+                                          bias2_grad_data,
+                                          true,
+                                          z_output_grad_data,
+                                          xpu::Activation_t::RELU,
+                                          nullptr,
+                                          0);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad_fusion");
+
+    // 1. shortcut branch grad (bn3 -> conv3)
+    if (attr.has_shortcut) {
+      // bn3 grad
+      const auto* conv3_output_data =
+          reinterpret_cast<const XPUT*>(conv3_out->data<T>());
+      const auto* scale3_data = scale3->data<float>();
+      const auto* saved_mean3_data = saved_mean3->data<float>();
+      const auto* saved_invstd3_data = saved_invstd3->data<float>();
+      auto* scale3_grad_data = scale3_grad->mutable_data<float>(place);
+      auto* bias3_grad_data = bias3_grad->mutable_data<float>(place);
+      auto* conv3_output_grad_data =
+          RAII_GUARD.alloc<XPUT>(attr.conv3_output_numel);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(conv3_output_grad_data);
+
+      r = xpu::batch_norm_grad<XPUT>(dev_ctx.x_context(),
+                                     conv3_output_data,
+                                     z_output_grad_data,
+                                     conv3_output_grad_data,
+                                     attr.conv3_output_shape[0],
+                                     attr.conv3_output_shape[1],
+                                     attr.conv3_output_shape[2],
+                                     attr.conv3_output_shape[3],
+                                     scale3_data,
+                                     saved_mean3_data,
+                                     saved_invstd3_data,
+                                     scale3_grad_data,
+                                     bias3_grad_data,
+                                     true);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad");
+
+      // conv3 grad
+      auto* conv3_filter_grad_data =
+          reinterpret_cast<XPUT*>(filter3_grad->mutable_data<T>(place));
+      auto* conv3_filter_data =
+          reinterpret_cast<const XPUT*>(filter3->data<T>());
+      xpu_conv2d_grad(dev_ctx.x_context(),
+                      x_data,
+                      conv3_filter_data,
+                      conv3_output_grad_data,
+                      z_grad_data,
+                      conv3_filter_grad_data,
+                      conv3_input_max_data,
+                      conv3_filter_max_data,
+                      attr.conv3_input_shape,
+                      attr.conv3_filter_shape,
+                      attr.padding3,
+                      attr.stride3,
+                      attr.dilation3,
+                      attr.group);
+    }
+
+    // 2. conv2_grad
+    auto* conv2_filter_grad_data =
+        reinterpret_cast<XPUT*>(filter2_grad->mutable_data<T>(place));
+    auto* conv2_input_grad_data =
+        RAII_GUARD.alloc<XPUT>(attr.conv2_input_numel);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(conv2_input_grad_data);
+    xpu_conv2d_grad(dev_ctx.x_context(),
+                    conv2_input_data,
+                    conv2_filter_data,
+                    conv2_output_grad_data,
+                    conv2_input_grad_data,
+                    conv2_filter_grad_data,
+                    conv2_input_max_data,
+                    conv2_filter_max_data,
+                    attr.conv2_input_shape,
+                    attr.conv2_filter_shape,
+                    attr.padding2,
+                    attr.stride2,
+                    attr.dilation2,
+                    attr.group);
+
+    // 3. bn1 grad
+    auto* conv1_output_grad_data =
+        RAII_GUARD.alloc<XPUT>(attr.conv1_output_numel);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(conv1_output_grad_data);
+    auto* scale1_grad_data = scale1_grad->mutable_data<float>(ctx.GetPlace());
+    auto* bias1_grad_data = bias1_grad->mutable_data<float>(ctx.GetPlace());
+    r = xpu::batch_norm_grad_fusion<XPUT>(dev_ctx.x_context(),
+                                          conv1_output_data,
+                                          conv2_input_data,
+                                          conv2_input_grad_data,
+                                          conv1_output_grad_data,
+                                          attr.conv1_output_shape[0],
+                                          attr.conv1_output_shape[1],
+                                          attr.conv1_output_shape[2],
+                                          attr.conv1_output_shape[3],
+                                          scale1_data,
+                                          saved_mean1_data,
+                                          saved_invstd1_data,
+                                          scale1_grad_data,
+                                          bias1_grad_data,
+                                          true,
+                                          nullptr,
+                                          xpu::Activation_t::RELU,
+                                          nullptr,
+                                          0);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad_fusion");
+
+    // 4. conv1_grad
+    auto* x_grad_data = reinterpret_cast<XPUT*>(x_grad->mutable_data<T>(place));
+    auto* conv1_filter_grad_data =
+        reinterpret_cast<XPUT*>(filter1_grad->mutable_data<T>(place));
+    xpu_conv2d_grad(dev_ctx.x_context(),
+                    x_data,
+                    conv1_filter_data,
+                    conv1_output_grad_data,
+                    x_grad_data,
+                    conv1_filter_grad_data,
+                    conv1_input_max_data,
+                    conv1_filter_max_data,
+                    attr.conv1_input_shape,
+                    attr.conv1_filter_shape,
+                    attr.padding1,
+                    attr.stride1,
+                    attr.dilation1,
+                    attr.group);
+
+    // add z_grad to x_grad
+    r = xpu::add<XPUT>(
+        dev_ctx.x_context(), x_grad_data, z_grad_data, x_grad_data, x->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_XPU_KERNEL(
+    resnet_basic_block,
+    ops::ResNetBasicBlockXPUKernel<float>,
+    ops::ResNetBasicBlockXPUKernel<paddle::platform::float16>);
+REGISTER_OP_XPU_KERNEL(
+    resnet_basic_block_grad,
+    ops::ResNetBasicBlockGradXPUKernel<float>,
+    ops::ResNetBasicBlockGradXPUKernel<paddle::platform::float16>);
+#endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 2b80396cc3138f47d15026cd5c9c167e3988470b..204cb0015048df354e32fab385f73c25f8c82a79 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -505,6 +505,14 @@ XPUOpMap& get_kl2_ops() {
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"sequence_conv_grad",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+
+      // Fused op
+      {"resnet_basic_block_grad",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"resnet_basic_block",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
   };
 
   return s_xpu2_kernels;