From 4922376ccdd0ea1de8b9b9e408fdece92c484735 Mon Sep 17 00:00:00 2001 From: dongfangshenzhu <102794151+dongfangshenzhu@users.noreply.github.com> Date: Thu, 4 Aug 2022 15:24:24 +0800 Subject: [PATCH] [XPU] add merged_momentum including fp32 and fp16 (#44824) * add merged_momentum *test=kunlun * add merged_momentum *test=kunlun * add fp16 to merged_momentum,*test=kunlun --- cmake/external/xpu.cmake | 4 +- .../optimizers/merged_momentum_op_xpu.cc | 141 ++++++++++++++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 3 + 3 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 395efda6c6..87f686a8ab 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") if(NOT DEFINED XPU_BASE_URL) set(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220731") + set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220802") else() set(XPU_BASE_URL "${XPU_BASE_URL}") endif() @@ -19,7 +19,7 @@ endif() if(NOT DEFINED XPU_XDNN_BASE_URL) set(XPU_XDNN_BASE_URL_WITHOUT_DATE "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") - set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220731") + set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220802") else() set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") endif() diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc new file mode 100644 index 0000000000..3993a46add --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_momentum_op_xpu.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifdef PADDLE_WITH_XPU +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/device_wrapper.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" +#include "paddle/phi/kernels/impl/momentum_kernel_impl.h" +namespace paddle { +namespace operators { + +template +class MergedMomentumOpXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + T mu = static_cast(ctx.Attr("mu")); + auto params = ctx.MultiInput("Param"); + auto params_out = ctx.MultiOutput("ParamOut"); + auto lr = ctx.Input("LearningRate"); + int op_num = params.size(); + auto velocity = ctx.MultiInput("Velocity"); + auto grad = ctx.MultiInput("Grad"); + auto velocity_out = ctx.MultiOutput("VelocityOut"); + auto use_nesterov = ctx.Attr("use_nesterov"); + auto regularization_method = + ctx.Attr>("regularization_method"); + auto regularization_coeff = + ctx.Attr>("regularization_coeff"); + std::vector param_list(op_num); + std::vector velocity_list(op_num); + std::vector grad_list(op_num); + std::vector velocity_out_list(op_num); + std::vector param_out_list(op_num); + std::vector sizes(op_num); + std::vector l2_weight_decay(op_num); + if (op_num > 0) { + for (int j = 0; j < op_num; j++) { + param_list[j] = + reinterpret_cast(const_cast(params[j]->data())); + velocity_list[j] = + reinterpret_cast(const_cast(velocity[j]->data())); + grad_list[j] = + reinterpret_cast(const_cast(grad[j]->data())); + param_out_list[j] = + reinterpret_cast(params_out[j]->data()); + velocity_out_list[j] = + reinterpret_cast(velocity_out[j]->data()); + sizes[j] = static_cast(params[j]->numel()); + if (regularization_method[j] != "l2_decay") { + l2_weight_decay[j] = 0.0f; + } else { + l2_weight_decay[j] = static_cast(regularization_coeff[j]); + } + PADDLE_ENFORCE_EQ(params[j], + params_out[j], + platform::errors::InvalidArgument( + "The size of Input(Param) and Output(ParamOut) " + "must be the same Tensors.")); + PADDLE_ENFORCE_EQ( + velocity[j], + velocity_out[j], + platform::errors::InvalidArgument( + "The size of Input(velocity) and Output(velocity) " + "must be the same Tensors.")); + } + } else { + return; + } + auto& dev_ctx = ctx.template device_context(); + PADDLE_ENFORCE_EQ(op_num, + params_out.size(), + platform::errors::InvalidArgument( + "The size of Output(ParamOut) must be equal to " + "Input(Param), but got the size of Output(ParamOut) " + "is %d, the size of Input(Param) is %d.", + params_out.size(), + op_num)); + PADDLE_ENFORCE_EQ(op_num, + velocity.size(), + platform::errors::InvalidArgument( + "The size of Output(Velocity) must be equal to " + "Input(Param), but got the size of Output(Velocity) " + "is %d, the size of Input(Param) is %d.", + velocity.size(), + op_num)); + PADDLE_ENFORCE_EQ( + op_num, + velocity_out.size(), + platform::errors::InvalidArgument( + "The size of Output(VelocityOut) must be equal to " + "Input(Param), but got the size of Output(VelocityOut) " + "is %d, the size of Input(Param) is %d.", + velocity_out.size(), + op_num)); + PADDLE_ENFORCE_EQ( + op_num, + grad.size(), + platform::errors::InvalidArgument( + "The size of Input(Grad) must be equal to Input(Param), but got " + "the size of Input(Grad) is %d, the size of Input(Param) is %d.", + grad.size(), + op_num)); + int r = xpu::merged_momentum(dev_ctx.x_context(), + param_list, + velocity_list, + grad_list, + param_out_list, + velocity_out_list, + l2_weight_decay, + sizes, + lr->data(), + mu, + use_nesterov); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_momentum"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + merged_momentum, + ops::MergedMomentumOpXPUKernel, + ops::MergedMomentumOpXPUKernel); +#endif diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index c36dd6425c..b3752a9ec8 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -312,6 +312,9 @@ XPUOpMap& get_kl2_ops() { {"mean", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"merged_momentum", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"mul", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), -- GitLab