// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {

// Holds the master-parameter pointers used by multi-precision training.
template <typename MT, uint32_t kParamNum, bool kHasMasterParams>
struct MergedMomentumMasterParams {
  MT *PADDLE_RESTRICT master_params[kParamNum];

  HOSTDEVICE MT *MasterParam(size_t idx) const { return master_params[idx]; }
  HOSTDEVICE void SetMasterParam(size_t idx, MT *p) { master_params[idx] = p; }
};

// Specialization for the single-precision path: no master-parameter storage,
// and MasterParam() constant-folds to nullptr.
template <typename MT, uint32_t kParamNum>
struct MergedMomentumMasterParams<MT, kParamNum, false> {
  HOSTDEVICE constexpr MT *MasterParam(size_t) const { return nullptr; }
  HOSTDEVICE constexpr void SetMasterParam(size_t, MT *) {}
};

// Plain-old-data argument pack for the merged kernel. The capacity kParamNum
// bounds how many parameter tensors one launch can fuse; it is kept small
// enough that the whole struct fits within typical kernel-argument limits.
template <typename T, typename MT, bool kHasMasterParams,
          uint32_t kParamNum = kHasMasterParams ? 55 : 110>
struct MergedMomentumKernelParam
    : public MergedMomentumMasterParams<MT, kParamNum, kHasMasterParams> {
  static constexpr auto N = kParamNum;
  size_t sizes[N];
  T *PADDLE_RESTRICT params[N];
  const T *PADDLE_RESTRICT grads[N];
  MT *PADDLE_RESTRICT velocitys[N];
  const MT *PADDLE_RESTRICT lr;
  MT mu;
  MT rescale_grad;
  uint32_t param_num;

  HOSTDEVICE void operator()(size_t i) const {
    const auto lr_val = *lr;
    for (uint32_t idx = 0; idx < param_num; ++idx) {
      auto size = sizes[idx];
      if (i >= size) continue;

      auto param_p = params[idx];
      auto grad_p = grads[idx];
      auto velocity_p = velocitys[idx];
      auto master_param_p = this->MasterParam(idx);

      const MT param =
          master_param_p ? master_param_p[i] : static_cast<MT>(param_p[i]);
      const MT grad = static_cast<MT>(grad_p[i]) * rescale_grad;
      const MT velocity = velocity_p[i];
      const MT velocity_out = velocity * mu + grad;
      const MT param_out = param - lr_val * velocity_out;
      velocity_p[i] = velocity_out;
      param_p[i] = static_cast<T>(param_out);
      if (master_param_p) {
        master_param_p[i] = param_out;
      }
    }
  }
};
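// For clarity, the per-element update above, written out for one scalar.
// Illustration only: the function below is hypothetical and not part of the
// operator's interface. The arithmetic is
//
//   v_out = v * mu + g * rescale_grad
//   p_out = p - lr * v_out
//
// carried out in the master type MT and cast back to the storage type T.
template <typename T, typename MT>
HOSTDEVICE inline void MomentumUpdateExample(T *param, T grad, MT *velocity,
                                             MT *master_param, MT lr, MT mu,
                                             MT rescale_grad) {
  // Read the parameter from the master copy when one exists (fp16 params
  // keep an fp32 master copy so small updates are not rounded away).
  const MT p = master_param ? *master_param : static_cast<MT>(*param);
  const MT g = static_cast<MT>(grad) * rescale_grad;
  const MT v_out = (*velocity) * mu + g;
  const MT p_out = p - lr * v_out;
  *velocity = v_out;
  *param = static_cast<T>(p_out);
  if (master_param) {
    *master_param = p_out;  // keep the high-precision copy in sync
  }
}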
// Merged momentum: applies the momentum update to a list of parameters with
// one functor launch per chunk, instead of one launch per parameter.
template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto params = ctx.MultiInput<framework::Tensor>("Param");
    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
    size_t n = params.size();
    PADDLE_ENFORCE_EQ(
        n, params_out.size(),
        platform::errors::InvalidArgument(
            "Output(ParamOut) number must be equal to Input(Param) number."));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(
          params[i], params_out[i],
          platform::errors::InvalidArgument(
              "Input(Param) and Output(ParamOut) must be the same Tensors."));
    }

    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
    PADDLE_ENFORCE_EQ(
        n, grads.size(),
        platform::errors::InvalidArgument(
            "Input(Grad) number must be equal to Input(Param) number."));

    auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
    PADDLE_ENFORCE_EQ(
        n, velocitys.size(),
        platform::errors::InvalidArgument(
            "Input(Velocity) number must be equal to Input(Param) number."));

    auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
    PADDLE_ENFORCE_EQ(
        n, velocitys_out.size(),
        platform::errors::InvalidArgument("Output(VelocityOut) number must be "
                                          "equal to Input(Param) number."));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
                        platform::errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
    }

    auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
    auto master_params_out =
        ctx.MultiOutput<framework::Tensor>("MasterParamOut");
    auto multi_precision = ctx.Attr<bool>("multi_precision");
    if (multi_precision) {
      PADDLE_ENFORCE_EQ(
          n, master_params.size(),
          platform::errors::InvalidArgument("Input(MasterParam) number must be "
                                            "equal to Input(Param) number."));
      PADDLE_ENFORCE_EQ(n, master_params_out.size(),
                        platform::errors::InvalidArgument(
                            "Output(MasterParamOut) number must be equal to "
                            "Input(MasterParam) number."));
      for (size_t i = 0; i < n; ++i) {
        PADDLE_ENFORCE_EQ(master_params[i], master_params_out[i],
                          platform::errors::InvalidArgument(
                              "Input(MasterParam) and Output(MasterParamOut) "
                              "must be the same Tensors."));
        PADDLE_ENFORCE_NOT_NULL(master_params[i],
                                platform::errors::InvalidArgument(
                                    "Input(MasterParam) must be provided when "
                                    "multi_precision=True."));
      }
    } else {
      master_params.clear();
      master_params_out.clear();
    }

    auto lr = ctx.Input<framework::Tensor>("LearningRate");
    auto mu = ctx.Attr<float>("mu");
    auto rescale_grad = ctx.Attr<float>("rescale_grad");
    using MPType = typename operators::details::MPTypeTrait<T>::Type;

    auto &dev_ctx = ctx.template device_context<DeviceContext>();

// Fill one MergedMomentumKernelParam per chunk of at most N tensors and
// run a ForRange over the largest tensor in the chunk; elements past a
// tensor's own size are skipped inside the functor.
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision)                \
  MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params;       \
  constexpr auto kMaxMergedNum = decltype(kernel_params)::N;                 \
  size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum;               \
  kernel_params.mu = static_cast<MPType>(mu);                                \
  kernel_params.rescale_grad = static_cast<MPType>(rescale_grad);            \
  kernel_params.lr = lr->data<MPType>();                                     \
  for (size_t i = 0; i < kernel_num; ++i) {                                  \
    size_t start = i * kMaxMergedNum;                                        \
    size_t end = std::min((i + 1) * kMaxMergedNum, n);                       \
    kernel_params.param_num = static_cast<uint32_t>(end - start);            \
    size_t max_size = 0;                                                     \
    for (size_t j = 0; j < kernel_params.param_num; ++j) {                   \
      auto size = static_cast<size_t>(params_out[j + start]->numel());       \
      max_size = std::max(max_size, size);                                   \
      kernel_params.sizes[j] = size;                                         \
      kernel_params.params[j] = params_out[j + start]->data<T>();            \
      kernel_params.grads[j] = grads[j + start]->data<T>();                  \
      kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
      kernel_params.SetMasterParam(                                          \
          j, kMultiPrecision ? master_params_out[j + start]->data<MPType>()  \
                             : nullptr);                                     \
    }                                                                        \
    platform::ForRange<DeviceContext> for_range(dev_ctx, max_size);          \
    for_range(kernel_params);                                                \
    VLOG(10) << "Launch MergedMomentum kernel " << i << " "                  \
             << kernel_params.param_num;                                     \
  }

    if (multi_precision) {
      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
    } else {
      PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(false);
    }
#undef PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL
  }
};
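// A small worked example of the chunking arithmetic above (illustration
// only; this helper is hypothetical and unused by the kernel). With
// n = 120 tensors and a capacity of 55, the launch loop runs
// ceil(120 / 55) = 3 times, covering [0, 55), [55, 110), and [110, 120).
inline void MergedMomentumChunkBoundsExample(size_t chunk_idx,
                                             size_t max_merged, size_t n,
                                             size_t *start, size_t *end) {
  *start = chunk_idx * max_merged;
  *end = std::min((chunk_idx + 1) * max_merged, n);
}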
}  // namespace operators
}  // namespace paddle
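// Sketch: how a kernel like this is typically registered. The real
// registrations live in the corresponding .cc/.cu files, not in this header,
// and the exact device/type lists there may differ:
//
//   namespace ops = paddle::operators;
//   namespace plat = paddle::platform;
//   REGISTER_OP_CPU_KERNEL(
//       merged_momentum,
//       ops::MergedMomentumOpKernel<plat::CPUDeviceContext, float>,
//       ops::MergedMomentumOpKernel<plat::CPUDeviceContext, double>);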