From b0c2ee2661d04750cc44c472a0a33fd3dbbc809d Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Tue, 29 Aug 2023 10:55:54 +0800 Subject: [PATCH] [Fluid] move lars_momentum to phi (#55798) * [Fluid] move lars_momentum to phi * add sig * fix optional Output * off check_dygraph * fix input * fix operator[] * fix * try fix AllocateTmpTensor * fix * fix type * Update paddle/phi/kernels/gpu/lars_momentum_kernel.cu * fix type * rollback * Add Registration * try fix win * try fix win * try use double * try use operator *(float,const Derived &) * try auto * fix * fix * fix * fix dtype * fix type * fix index --- .../operators/optimizers/lars_momentum_op.cc | 6 +- .../operators/optimizers/lars_momentum_op.h | 74 --- .../optimizers/lars_momentum_op_xpu.cc | 2 +- .../phi/kernels/cpu/lars_momentum_kernel.cc | 78 +++ .../kernels/gpu/lars_momentum_kernel.cu} | 449 +++++++++--------- paddle/phi/kernels/lars_momentum_kernel.h | 40 ++ paddle/phi/ops/compat/lars_momentum_sig.cc | 35 ++ test/legacy_test/test_momentum_op.py | 2 +- 8 files changed, 381 insertions(+), 305 deletions(-) delete mode 100644 paddle/fluid/operators/optimizers/lars_momentum_op.h create mode 100644 paddle/phi/kernels/cpu/lars_momentum_kernel.cc rename paddle/{fluid/operators/optimizers/lars_momentum_op.cu => phi/kernels/gpu/lars_momentum_kernel.cu} (64%) create mode 100644 paddle/phi/kernels/lars_momentum_kernel.h create mode 100644 paddle/phi/ops/compat/lars_momentum_sig.cc diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc index b6f067e1a2c..e6b04a8a3ca 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/lars_momentum_op.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -233,6 +234,3 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::LarsMomentumOpVarTypeInference); - -PD_REGISTER_STRUCT_KERNEL( - lars_momentum, CPU, ALL_LAYOUT, ops::LarsMomentumOpKernel, float, double) {} diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h deleted file mode 100644 index 70bf0c9186b..00000000000 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -template -class LarsMomentumOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.MultiOutput("ParamOut"); - auto velocity_out = ctx.MultiOutput("VelocityOut"); - auto param = ctx.MultiInput("Param"); - auto velocity = ctx.MultiInput("Velocity"); - auto learning_rate = ctx.MultiInput("LearningRate"); - auto grad = ctx.MultiInput("Grad"); - auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); - T mu = static_cast(ctx.Attr("mu")); - T lars_coeff = ctx.Attr("lars_coeff"); - T epsilon = ctx.Attr("epsilon"); - T rescale_grad = ctx.Attr("rescale_grad"); - - int op_num = param.size(); - for (int i = 0; i < op_num; ++i) { - auto* lr = learning_rate[i]->data(); - T lars_weight_decay = weight_decay_arr[i]; - param_out[i]->mutable_data(ctx.GetPlace()); - velocity_out[i]->mutable_data(ctx.GetPlace()); - - auto p_out = framework::EigenVector::Flatten(*(param_out[i])); - auto v_out = framework::EigenVector::Flatten(*(velocity_out[i])); - auto p = framework::EigenVector::Flatten(*(param[i])); - auto v = framework::EigenVector::Flatten(*(velocity[i])); - auto g = framework::EigenVector::Flatten(*(grad[i])); - auto rescale_g = rescale_grad * g; - - phi::DenseTensor p_norm_t, g_norm_t; - p_norm_t.Resize({1}); - g_norm_t.Resize({1}); - p_norm_t.mutable_data(ctx.GetPlace()); - g_norm_t.mutable_data(ctx.GetPlace()); - auto ep_norm = framework::EigenScalar::From(p_norm_t); - auto eg_norm = framework::EigenScalar::From(g_norm_t); - ep_norm = p.square().sum().sqrt(); - eg_norm = rescale_g.square().sum().sqrt(); - - T local_lr = lr[0]; - if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { - local_lr = lr[0] * lars_coeff * ep_norm(0) / - (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); - } - v_out = v * mu + local_lr * (rescale_g + lars_weight_decay * p); - p_out = p - v_out; - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc index 52b57252b0a..266ce2e57ca 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/lars_momentum_op_xpu.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU +#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/optimizers/lars_momentum_op.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" namespace paddle { diff --git a/paddle/phi/kernels/cpu/lars_momentum_kernel.cc b/paddle/phi/kernels/cpu/lars_momentum_kernel.cc new file mode 100644 index 00000000000..c1f8a0a5fee --- /dev/null +++ b/paddle/phi/kernels/cpu/lars_momentum_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lars_momentum_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void LarsMomentumKernel( + const Context& dev_ctx, + const std::vector& param, + const std::vector& velocity, + const std::vector& learning_rate, + const std::vector& grad, + const paddle::optional>& master_param, + const std::vector& weight_decay_arr, + float mu, + float lars_coeff, + float epsilon, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out) { + int op_num = param.size(); + T mu_ = static_cast(mu); + for (int i = 0; i < op_num; ++i) { + auto* lr = learning_rate[i]->data(); + T lars_weight_decay = weight_decay_arr[i]; + dev_ctx.template Alloc(param_out[i]); + dev_ctx.template Alloc(velocity_out[i]); + + auto p_out = phi::EigenVector::Flatten(*(param_out[i])); + auto v_out = phi::EigenVector::Flatten(*(velocity_out[i])); + auto p = phi::EigenVector::Flatten(*(param[i])); + auto v = phi::EigenVector::Flatten(*(velocity[i])); + Eigen::TensorMap> g = + phi::EigenVector::Flatten(*(grad[i])); + auto rescale_g = static_cast(rescale_grad) * g; + + phi::DenseTensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + dev_ctx.template Alloc(&p_norm_t); + dev_ctx.template Alloc(&g_norm_t); + auto ep_norm = phi::EigenScalar::From(p_norm_t); + auto eg_norm = phi::EigenScalar::From(g_norm_t); + ep_norm = p.square().sum().sqrt(); + eg_norm = rescale_g.square().sum().sqrt(); + + T local_lr = lr[0]; + if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) { + local_lr = lr[0] * lars_coeff * ep_norm(0) / + (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon); + } + v_out = v * mu_ + local_lr * (rescale_g + lars_weight_decay * p); + p_out = p - v_out; + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + lars_momentum, CPU, ALL_LAYOUT, phi::LarsMomentumKernel, float, double) {} diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/phi/kernels/gpu/lars_momentum_kernel.cu similarity index 64% rename from paddle/fluid/operators/optimizers/lars_momentum_op.cu rename to paddle/phi/kernels/gpu/lars_momentum_kernel.cu index 20769747a4a..14b7f1ca328 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/phi/kernels/gpu/lars_momentum_kernel.cu @@ -1,22 +1,25 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/optimizers/lars_momentum_op.h" -#include "paddle/fluid/framework/op_registry.h" +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lars_momentum_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_cuda_utils.h" +#include "paddle/utils/optional.h" #if CUDA_VERSION >= 11000 #include @@ -30,8 +33,7 @@ limitations under the License. */ #define LARS_MAX_MERGED_OPS 60 -namespace paddle { -namespace operators { +namespace phi { template using MultiPrecisionType = typename phi::dtype::MPTypeTrait::Type; @@ -253,7 +255,7 @@ __forceinline__ __device__ void MomentumUpdate( master_param_out); } else { if (std::is_same::value || - std::is_same::value) { + std::is_same::value) { /* TODO(limingshu): pointer cast may damage memory accessing for fp16 */ VectorizeLarsUpdate( grad, @@ -419,7 +421,7 @@ __global__ void MomentumLarsKernel(const T* param, } template -inline void SeparatedLarsMomentumOpCUDAKernel(const phi::GPUContext& cuda_ctx, +inline void SeparatedLarsMomentumOpCUDAKernel(const GPUContext& cuda_ctx, const T* param_data, T* param_out_data, const MT* velocity_data, @@ -474,216 +476,213 @@ inline void SeparatedLarsMomentumOpCUDAKernel(const phi::GPUContext& cuda_ctx, is_amp); } -template -class LarsMomentumOpCUDAKernel : public framework::OpKernel { +template +void LarsMomentumKernel( + const Context& dev_ctx, + const std::vector& param, + const std::vector& velocity, + const std::vector& learning_rate, + const std::vector& grad, + const paddle::optional>& master_param, + const std::vector& weight_decay_arr, + float mu, + float lars_coeff, + float epsilon, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out) { using MT = MultiPrecisionType; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - int num_blocks_per_sm = 0; - bool multi_precision = ctx.Attr("multi_precision"); - auto& cuda_ctx = ctx.template device_context(); - int sm_num = cuda_ctx.GetSMCount(); - phi::DenseTensor tmp_buffer_t = ctx.AllocateTmpTensor( - {LARS_BLOCK_SIZE << 1}, cuda_ctx); - auto* p_buffer = tmp_buffer_t.mutable_data(ctx.GetPlace()); - auto* g_buffer = p_buffer + LARS_BLOCK_SIZE; - - MT mu = static_cast(ctx.Attr("mu")); - MT lars_coeff = static_cast(ctx.Attr("lars_coeff")); - MT epsilon = static_cast(ctx.Attr("epsilon")); - MT rescale_grad = static_cast(ctx.Attr("rescale_grad")); - - auto weight_decay_arr = ctx.Attr>("lars_weight_decay"); - auto grad = ctx.MultiInput("Grad"); - auto param = ctx.MultiInput("Param"); - auto velocity = ctx.MultiInput("Velocity"); - auto param_out = ctx.MultiOutput("ParamOut"); - auto velocity_out = ctx.MultiOutput("VelocityOut"); - auto learning_rate = ctx.MultiInput("LearningRate"); - auto master_param = ctx.MultiInput("MasterParam"); - auto master_param_out = ctx.MultiOutput("MasterParamOut"); - - int op_num = grad.size(); + int num_blocks_per_sm = 0; + int sm_num = dev_ctx.GetSMCount(); + // phi::DenseTensor 
tmp_buffer_t = ctx.AllocateTmpTensor( + // {LARS_BLOCK_SIZE << 1}, cuda_ctx); + phi::DenseTensor tmp_buffer_t; + tmp_buffer_t.Resize({LARS_BLOCK_SIZE << 1}); + MT* p_buffer = dev_ctx.template Alloc(&tmp_buffer_t); + MT* g_buffer = p_buffer + LARS_BLOCK_SIZE; + + MT mu_ = static_cast(mu); + MT lars_coeff_ = static_cast(lars_coeff); + MT epsilon_ = static_cast(epsilon); + MT rescale_grad_ = static_cast(rescale_grad); + + int op_num = grad.size(); #if CUDA_VERSION >= 11000 - if (op_num > 1) { - LarsParamWarpper lars_warpper; - PADDLE_ENFORCE_LT( - op_num, - LARS_MAX_MERGED_OPS, - platform::errors::InvalidArgument( - "The maximum number of merged-ops supported is (%d), but" - "lars op required for trainning this model is (%d)\n", - LARS_MAX_MERGED_OPS, - op_num)); - - /* Implementation of lars optimizer consists of following two steps: - 1. Figure out the L2 norm statistic result of grad data and param data. - 2. Update param and velocity with usage of L2 norm statistic result. - Step1 and step2 can be merged with api provided by nvida - cudaLaunchCooperativeKernel: - - The thread quantity shall less than pyhsical SM limited threads - - Launche as thread-block can synchronizlly execute. */ - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, - MergedMomentumLarsKernel, - LARS_BLOCK_SIZE, - sizeof(MT) << 1); - - size_t total_numel = 0; - for (int i = 0; i < op_num; ++i) { - size_t temp_numel = param[i]->numel(); - total_numel += temp_numel; - lars_warpper.numel_arr[i] = temp_numel; - lars_warpper.g_arr[i] = grad[i]->data(); - lars_warpper.lr_arr[i] = learning_rate[i]->data(); - lars_warpper.p_out_arr[i] = - param_out[i]->mutable_data(ctx.GetPlace()); - lars_warpper.v_out_arr[i] = - velocity_out[i]->mutable_data(ctx.GetPlace()); - lars_warpper.weight_decay_arr[i] = static_cast(weight_decay_arr[i]); - PADDLE_ENFORCE_EQ( - param[i]->data(), - lars_warpper.p_out_arr[i], - platform::errors::InvalidArgument( - "Input(Param) and Output(ParamOut) must be the same Tensors.")); - PADDLE_ENFORCE_EQ(velocity[i]->data(), - lars_warpper.v_out_arr[i], - platform::errors::InvalidArgument( - "Input(Velocity) and Output(VelocityOut) must be " - "the same Tensors.")); - } - int64_t avg_numel = total_numel / op_num; - LarsThreadConfig lars_thread_config( - avg_numel, sm_num, num_blocks_per_sm); + if (op_num > 1) { + LarsParamWarpper lars_warpper; + PADDLE_ENFORCE_LT( + op_num, + LARS_MAX_MERGED_OPS, + errors::InvalidArgument( + "The maximum number of merged-ops supported is (%d), but" + "lars op required for trainning this model is (%d)\n", + LARS_MAX_MERGED_OPS, + op_num)); + + /* Implementation of lars optimizer consists of following two steps: + 1. Figure out the L2 norm statistic result of grad data and param data. + 2. Update param and velocity with usage of L2 norm statistic result. + Step1 and step2 can be merged with api provided by nvida + cudaLaunchCooperativeKernel: + - The thread quantity shall less than pyhsical SM limited threads + - Launche as thread-block can synchronizlly execute. 
*/ + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, + MergedMomentumLarsKernel, + LARS_BLOCK_SIZE, + sizeof(MT) << 1); + + size_t total_numel = 0; + for (int i = 0; i < op_num; ++i) { + size_t temp_numel = param[i]->numel(); + total_numel += temp_numel; + lars_warpper.numel_arr[i] = temp_numel; + lars_warpper.g_arr[i] = grad[i]->data(); + lars_warpper.lr_arr[i] = learning_rate[i]->data(); + lars_warpper.p_out_arr[i] = dev_ctx.template Alloc(param_out[i]); + lars_warpper.v_out_arr[i] = dev_ctx.template Alloc(velocity_out[i]); + lars_warpper.weight_decay_arr[i] = static_cast(weight_decay_arr[i]); + PADDLE_ENFORCE_EQ( + param[i]->data(), + lars_warpper.p_out_arr[i], + errors::InvalidArgument( + "Input(Param) and Output(ParamOut) must be the same Tensors.")); + PADDLE_ENFORCE_EQ(velocity[i]->data(), + lars_warpper.v_out_arr[i], + errors::InvalidArgument( + "Input(Velocity) and Output(VelocityOut) must be " + "the same Tensors.")); + } + int64_t avg_numel = total_numel / op_num; + LarsThreadConfig lars_thread_config( + avg_numel, sm_num, num_blocks_per_sm); + for (int i = 0; i < op_num; ++i) { + lars_warpper.repeat_arr[i] = + lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]); + } + if (multi_precision) { for (int i = 0; i < op_num; ++i) { - lars_warpper.repeat_arr[i] = - lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]); - } - if (multi_precision) { - for (int i = 0; i < op_num; ++i) { - lars_warpper.master_p_out_arr[i] = - master_param_out[i]->mutable_data(ctx.GetPlace()); - PADDLE_ENFORCE_EQ(master_param[i]->data(), - lars_warpper.master_p_out_arr[i], - platform::errors::InvalidArgument( - "Input(MasterParam) and Output(MasterParamOut) " - "must be the same Tensors.")); - } + lars_warpper.master_p_out_arr[i] = + dev_ctx.template Alloc(master_param_out[i]); + PADDLE_ENFORCE_EQ(master_param.get()[i]->data(), + lars_warpper.master_p_out_arr[i], + errors::InvalidArgument( + "Input(MasterParam) and Output(MasterParamOut) " + "must be the same Tensors.")); } - void* cuda_param[] = {reinterpret_cast(&lars_warpper), - reinterpret_cast(&p_buffer), - reinterpret_cast(&g_buffer), - reinterpret_cast(&op_num), - reinterpret_cast(&mu), - reinterpret_cast(&lars_coeff), - reinterpret_cast(&epsilon), - reinterpret_cast(&rescale_grad), - reinterpret_cast(&multi_precision)}; - // Lanuch all sm theads, and thead of each block synchronizedly cooperate. - cudaLaunchCooperativeKernel( - reinterpret_cast(MergedMomentumLarsKernel), - lars_thread_config.grid_for_lars, - LARS_BLOCK_SIZE, - cuda_param, - 0, - cuda_ctx.stream()); - } else { - auto* param_data = param[0]->data(); - auto* grad_data = grad[0]->data(); - auto* velocity_data = velocity[0]->data(); - auto* lr = learning_rate[0]->data(); - auto* param_out_data = param_out[0]->mutable_data(ctx.GetPlace()); - auto* velocity_out_data = - velocity_out[0]->mutable_data(ctx.GetPlace()); - const MT* master_param_data = - multi_precision ? master_param[0]->data() : nullptr; - MT* master_param_out_data = - multi_precision - ? master_param_out[0]->mutable_data(ctx.GetPlace()) - : nullptr; - int64_t numel = param[0]->numel(); - MT lars_weight_decay = weight_decay_arr[0]; - - // Figure out how many blocks can be active in each sm. 
- cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, - MomentumLarsKernel, - LARS_BLOCK_SIZE, - sizeof(MT) << 1); - LarsThreadConfig lars_thread_config( - numel, sm_num, num_blocks_per_sm); - int repeat_times = lars_thread_config.GetRepeatTimes(numel); - int thresh = 0; - void* cuda_param[] = { - reinterpret_cast(¶m_data), - reinterpret_cast(&grad_data), - reinterpret_cast(&velocity_data), - reinterpret_cast(¶m_out_data), - reinterpret_cast(&velocity_out_data), - reinterpret_cast(&master_param_data), - reinterpret_cast(&master_param_out_data), - reinterpret_cast(&lr), - reinterpret_cast(&p_buffer), - reinterpret_cast(&g_buffer), - reinterpret_cast(&mu), - reinterpret_cast(&lars_coeff), - reinterpret_cast(&lars_weight_decay), - reinterpret_cast(&epsilon), - reinterpret_cast(&rescale_grad), - reinterpret_cast(&repeat_times), - reinterpret_cast(&thresh), // Just a placeholder - reinterpret_cast(&numel), - reinterpret_cast(&multi_precision)}; - // Lanuch all sm theads. - cudaLaunchCooperativeKernel( - reinterpret_cast(MomentumLarsKernel), - lars_thread_config.grid_for_lars, - LARS_BLOCK_SIZE, - cuda_param, - 0, - cuda_ctx.stream()); } + void* cuda_param[] = {reinterpret_cast(&lars_warpper), + reinterpret_cast(&p_buffer), + reinterpret_cast(&g_buffer), + reinterpret_cast(&op_num), + reinterpret_cast(&mu_), + reinterpret_cast(&lars_coeff_), + reinterpret_cast(&epsilon_), + reinterpret_cast(&rescale_grad_), + reinterpret_cast(&multi_precision)}; + // Lanuch all sm theads, and thead of each block synchronizedly cooperate. + cudaLaunchCooperativeKernel( + reinterpret_cast(MergedMomentumLarsKernel), + lars_thread_config.grid_for_lars, + LARS_BLOCK_SIZE, + cuda_param, + 0, + dev_ctx.stream()); + } else { + auto* param_data = param[0]->data(); + auto* grad_data = grad[0]->data(); + auto* velocity_data = velocity[0]->data(); + auto* lr = learning_rate[0]->data(); + auto* param_out_data = dev_ctx.template Alloc(param_out[0]); + auto* velocity_out_data = dev_ctx.template Alloc(velocity_out[0]); + const MT* master_param_data = + multi_precision ? master_param.get()[0]->data() : nullptr; + MT* master_param_out_data = + multi_precision ? dev_ctx.template Alloc(master_param_out[0]) + : nullptr; + int64_t numel = param[0]->numel(); + MT lars_weight_decay = weight_decay_arr[0]; + + // Figure out how many blocks can be active in each sm. + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, + MomentumLarsKernel, + LARS_BLOCK_SIZE, + sizeof(MT) << 1); + LarsThreadConfig lars_thread_config( + numel, sm_num, num_blocks_per_sm); + int repeat_times = lars_thread_config.GetRepeatTimes(numel); + int thresh = 0; + void* cuda_param[] = { + reinterpret_cast(¶m_data), + reinterpret_cast(&grad_data), + reinterpret_cast(&velocity_data), + reinterpret_cast(¶m_out_data), + reinterpret_cast(&velocity_out_data), + reinterpret_cast(&master_param_data), + reinterpret_cast(&master_param_out_data), + reinterpret_cast(&lr), + reinterpret_cast(&p_buffer), + reinterpret_cast(&g_buffer), + reinterpret_cast(&mu_), + reinterpret_cast(&lars_coeff_), + reinterpret_cast(&lars_weight_decay), + reinterpret_cast(&epsilon_), + reinterpret_cast(&rescale_grad_), + reinterpret_cast(&repeat_times), + reinterpret_cast(&thresh), // Just a placeholder + reinterpret_cast(&numel), + reinterpret_cast(&multi_precision)}; + // Lanuch all sm theads. 
+ cudaLaunchCooperativeKernel( + reinterpret_cast(MomentumLarsKernel), + lars_thread_config.grid_for_lars, + LARS_BLOCK_SIZE, + cuda_param, + 0, + dev_ctx.stream()); + } #else - for (int i = 0; i < op_num; ++i) { - const MT* master_param_data = - multi_precision ? master_param[i]->data() : nullptr; - MT* master_param_out_data = - multi_precision - ? master_param_out[i]->mutable_data(ctx.GetPlace()) - : nullptr; - SeparatedLarsMomentumOpCUDAKernel( - cuda_ctx, - param[i]->data(), - param_out[i]->mutable_data(ctx.GetPlace()), - velocity[i]->data(), - velocity_out[i]->mutable_data(ctx.GetPlace()), - grad[i]->data(), - learning_rate[i]->data(), - p_buffer, - g_buffer, - mu, - lars_coeff, - weight_decay_arr[i], - epsilon, - rescale_grad, - param[i]->numel(), - master_param_data, - master_param_out_data, - multi_precision); - } + for (int i = 0; i < op_num; ++i) { + const MT* master_param_data = + multi_precision ? master_param.get()[i]->data() : nullptr; + MT* master_param_out_data = + multi_precision ? dev_ctx.template Alloc(master_param_out[i]) + : nullptr; + SeparatedLarsMomentumOpCUDAKernel( + dev_ctx, + param[i]->data(), + dev_ctx.template Alloc(param_out[i]), + velocity[i]->data(), + dev_ctx.template Alloc(velocity_out[i]), + grad[i]->data(), + learning_rate[i]->data(), + p_buffer, + g_buffer, + mu_, + lars_coeff_, + weight_decay_arr[i], + epsilon_, + rescale_grad_, + param[i]->numel(), + master_param_data, + master_param_out_data, + multi_precision); + } #endif +} +} // namespace phi + +PD_REGISTER_KERNEL(lars_momentum, + GPU, + ALL_LAYOUT, + phi::LarsMomentumKernel, + float, + double, + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL(lars_momentum, - GPU, - ALL_LAYOUT, - ops::LarsMomentumOpCUDAKernel, - float, - double, - plat::float16) {} +} diff --git a/paddle/phi/kernels/lars_momentum_kernel.h b/paddle/phi/kernels/lars_momentum_kernel.h new file mode 100644 index 00000000000..2c1981762ae --- /dev/null +++ b/paddle/phi/kernels/lars_momentum_kernel.h @@ -0,0 +1,40 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LarsMomentumKernel( + const Context& dev_ctx, + const std::vector& param, + const std::vector& velocity, + const std::vector& learning_rate, + const std::vector& grad, + const paddle::optional>& master_param, + const std::vector& weight_decay_arr, + float mu, + float lars_coeff, + float epsilon, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/lars_momentum_sig.cc b/paddle/phi/ops/compat/lars_momentum_sig.cc new file mode 100644 index 00000000000..031677de12f --- /dev/null +++ b/paddle/phi/ops/compat/lars_momentum_sig.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature LarsMomentumOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "lars_momentum", + {"Param", "Velocity", "LearningRate", "Grad", "MasterParam"}, + {"lars_weight_decay", + "mu", + "lars_coeff", + "epsilon", + "multi_precision", + "rescale_grad"}, + {"ParamOut", "VelocityOut", "MasterParamOut"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(lars_momentum, phi::LarsMomentumOpArgumentMapping); diff --git a/test/legacy_test/test_momentum_op.py b/test/legacy_test/test_momentum_op.py index 23dbab84e78..b23183996c0 100644 --- a/test/legacy_test/test_momentum_op.py +++ b/test/legacy_test/test_momentum_op.py @@ -312,7 +312,7 @@ class TestLarsMomentumOp(OpTest): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_dygraph=False) def config(self): self.params_num = 1 -- GitLab
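
Both the new CPU kernel (paddle/phi/kernels/cpu/lars_momentum_kernel.cc) and the GPU kernel (paddle/phi/kernels/gpu/lars_momentum_kernel.cu) implement the same per-tensor LARS momentum update: first compute the L2 norms of the parameter and the rescaled gradient, then derive a trust-ratio ("local") learning rate from them, and finally apply a momentum step. The snippet below is a minimal standalone sketch of that math for a single tensor, assuming plain float buffers; the function name and signature are hypothetical and are not part of this patch — the real kernels operate on lists of phi::DenseTensor through Eigen (CPU) or fused cooperative CUDA launches (GPU).

#include <cmath>
#include <vector>

// Illustrative sketch of one LARS momentum step for a single parameter
// tensor. Names and the std::vector interface are hypothetical; the phi
// kernels in this patch apply the same update per tensor in the input lists.
void LarsStep(std::vector<float>& param,
              std::vector<float>& velocity,
              const std::vector<float>& grad,
              float lr,
              float mu,
              float lars_coeff,
              float lars_weight_decay,
              float epsilon,
              float rescale_grad) {
  // Step 1: L2 norms of the parameter and the rescaled gradient.
  double p_norm = 0.0, g_norm = 0.0;
  for (size_t i = 0; i < param.size(); ++i) {
    p_norm += static_cast<double>(param[i]) * param[i];
    const double g = static_cast<double>(rescale_grad) * grad[i];
    g_norm += g * g;
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);

  // Step 2: trust-ratio ("local") learning rate, then the momentum update
  //   v = mu * v + local_lr * (g + weight_decay * p);  p = p - v
  float local_lr = lr;
  if (lars_weight_decay > 0 && p_norm > 0 && g_norm > 0) {
    local_lr = lr * lars_coeff * static_cast<float>(p_norm) /
               static_cast<float>(g_norm + lars_weight_decay * p_norm + epsilon);
  }
  for (size_t i = 0; i < param.size(); ++i) {
    const float g = rescale_grad * grad[i];
    velocity[i] = velocity[i] * mu + local_lr * (g + lars_weight_decay * param[i]);
    param[i] -= velocity[i];
  }
}

On the GPU path with CUDA >= 11.0, the two steps are fused into a single cooperative launch via cudaLaunchCooperativeKernel, and multiple parameter tensors are merged into one launch as long as the op count stays below LARS_MAX_MERGED_OPS; older CUDA versions fall back to the per-tensor SeparatedLarsMomentumOpCUDAKernel path.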