From 6b0c57cf65945e97d87a8fba89c0a2fc18dd8544 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Thu, 20 Jan 2022 11:53:25 +0800 Subject: [PATCH] Fix master weight bug for multi_tensor optimizer(momentum, adam) (#38991) * fix mp * support merged_momentum for mp --- .../operators/optimizers/merged_momentum_op.h | 110 ++++++++++-------- python/paddle/optimizer/adam.py | 9 +- python/paddle/optimizer/momentum.py | 9 +- 3 files changed, 68 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/optimizers/merged_momentum_op.h b/paddle/fluid/operators/optimizers/merged_momentum_op.h index 7560b4fd8e..c1ac2e366f 100644 --- a/paddle/fluid/operators/optimizers/merged_momentum_op.h +++ b/paddle/fluid/operators/optimizers/merged_momentum_op.h @@ -48,13 +48,13 @@ struct MergedMomentumKernelParam T *PADDLE_RESTRICT params[N]; const T *PADDLE_RESTRICT grads[N]; MT *PADDLE_RESTRICT velocitys[N]; - const MT *PADDLE_RESTRICT lr; + const MultiPrecisionType *PADDLE_RESTRICT lr; MT mu; MT rescale_grad; uint32_t param_num; HOSTDEVICE void operator()(size_t i) const { - const auto lr_val = *lr; + const MT lr_val = static_cast(*lr); for (uint32_t idx = 0; idx < param_num; ++idx) { auto size = sizes[idx]; if (i >= size) continue; @@ -81,8 +81,22 @@ struct MergedMomentumKernelParam template class MergedMomentumOpKernel : public framework::OpKernel { + using MPType = typename operators::details::MPTypeTrait::Type; + public: void Compute(const framework::ExecutionContext &ctx) const override { + const bool multi_precision = ctx.Attr("multi_precision"); + if (multi_precision) { + InnerCompute(ctx, multi_precision); + } else { + InnerCompute(ctx, multi_precision); + } + } + + private: + template + void InnerCompute(const framework::ExecutionContext &ctx, + const bool multi_precision) const { auto params = ctx.MultiInput("Param"); auto params_out = ctx.MultiOutput("ParamOut"); size_t n = params.size(); @@ -133,7 +147,6 @@ class MergedMomentumOpKernel : public framework::OpKernel { auto master_params = ctx.MultiInput("MasterParam"); auto master_params_out = ctx.MultiOutput("MasterParamOut"); - auto multi_precision = ctx.Attr("multi_precision"); if (multi_precision) { PADDLE_ENFORCE_EQ( n, master_params.size(), @@ -206,39 +219,37 @@ class MergedMomentumOpKernel : public framework::OpKernel { << ", regularization_coeffs.size(): " << regularization_coeffs.size(); - using MPType = typename operators::details::MPTypeTrait::Type; - auto &dev_ctx = ctx.template device_context(); if (lrs.size() == 1 && use_nesterov == false && regularization_methods.size() == 0) { -#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ - MergedMomentumKernelParam kernel_params; \ - constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ - size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \ - kernel_params.mu = static_cast(mu); \ - kernel_params.rescale_grad = static_cast(rescale_grad); \ - kernel_params.lr = lrs[0]->data(); \ - for (size_t i = 0; i < kernel_num; ++i) { \ - size_t start = i * kMaxMergedNum; \ - size_t end = std::min((i + 1) * kMaxMergedNum, n); \ - kernel_params.param_num = static_cast(end - start); \ - size_t max_size = 0; \ - for (size_t j = 0; j < kernel_params.param_num; ++j) { \ - auto size = static_cast(params_out[j + start]->numel()); \ - max_size = std::max(max_size, size); \ - kernel_params.sizes[j] = size; \ - kernel_params.params[j] = params_out[j + start]->data(); \ - kernel_params.grads[j] = grads[j + start]->data(); \ - kernel_params.velocitys[j] = velocitys_out[j + start]->data(); \ - kernel_params.SetMasterParam( \ - j, kMultiPrecision ? master_params_out[j + start]->data() \ - : nullptr); \ - } \ - platform::ForRange for_range(dev_ctx, max_size); \ - for_range(kernel_params); \ - VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ - << kernel_params.param_num; \ +#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \ + MergedMomentumKernelParam kernel_params; \ + constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \ + size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \ + kernel_params.mu = static_cast(mu); \ + kernel_params.rescale_grad = static_cast(rescale_grad); \ + kernel_params.lr = lrs[0]->data(); \ + for (size_t i = 0; i < kernel_num; ++i) { \ + size_t start = i * kMaxMergedNum; \ + size_t end = std::min((i + 1) * kMaxMergedNum, n); \ + kernel_params.param_num = static_cast(end - start); \ + size_t max_size = 0; \ + for (size_t j = 0; j < kernel_params.param_num; ++j) { \ + auto size = static_cast(params_out[j + start]->numel()); \ + max_size = std::max(max_size, size); \ + kernel_params.sizes[j] = size; \ + kernel_params.params[j] = params_out[j + start]->data(); \ + kernel_params.grads[j] = grads[j + start]->data(); \ + kernel_params.velocitys[j] = velocitys_out[j + start]->data(); \ + kernel_params.SetMasterParam( \ + j, kMultiPrecision ? master_params_out[j + start]->data() \ + : nullptr); \ + } \ + platform::ForRange for_range(dev_ctx, max_size); \ + for_range(kernel_params); \ + VLOG(10) << "Launch MergedMomentum kernel " << i << " " \ + << kernel_params.param_num; \ } if (multi_precision) { PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true); @@ -254,34 +265,33 @@ class MergedMomentumOpKernel : public framework::OpKernel { ? RegularizationType::kL2DECAY : RegularizationType::kNONE; - MPType regularization_coeff = static_cast(0.0); + MT regularization_coeff = static_cast(0.0); if (regularization_coeffs.size() != 0) { - regularization_coeff = - static_cast(regularization_coeffs[idx]); + regularization_coeff = static_cast(regularization_coeffs[idx]); } auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0]; - const MPType *master_in_data = - multi_precision ? master_params[idx]->data() : nullptr; - MPType *master_out_data = - multi_precision ? master_params_out[idx]->data() : nullptr; + const MT *master_in_data = + multi_precision ? master_params[idx]->data() : nullptr; + MT *master_out_data = + multi_precision ? master_params_out[idx]->data() : nullptr; if (platform::is_cpu_place(ctx.GetPlace())) { - CPUDenseMomentumFunctor functor; - functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu, - use_nesterov, regularization_flag, regularization_coeff, - params_out[idx], velocitys_out[idx]); + CPUDenseMomentumFunctor functor; + functor(params[idx], grads[idx], velocitys[idx], lr_temp, + static_cast(mu), use_nesterov, regularization_flag, + regularization_coeff, params_out[idx], velocitys_out[idx]); VLOG(10) << "Launch MergedMomentum cpu kernel."; } else if (platform::is_gpu_place(ctx.GetPlace())) { platform::ForRange for_range( static_cast(ctx.device_context()), params[idx]->numel()); -#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ - DenseMomentumFunctor functor( \ - params[idx]->data(), grads[idx]->data(), \ - velocitys[idx]->data(), lr_temp->data(), master_in_data, \ - mu, rescale_grad, params[idx]->numel(), regularization_coeff, \ - params_out[idx]->data(), velocitys_out[idx]->data(), \ - master_out_data); \ +#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \ + DenseMomentumFunctor functor( \ + params[idx]->data(), grads[idx]->data(), \ + velocitys[idx]->data(), lr_temp->data(), master_in_data, \ + static_cast(mu), static_cast(rescale_grad), \ + params[idx]->numel(), regularization_coeff, params_out[idx]->data(), \ + velocitys_out[idx]->data(), master_out_data); \ for_range(functor); if (use_nesterov) { if (regularization_flag == RegularizationType::kL2DECAY) { diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 8134c9f71b..bbed2bed1d 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -551,8 +551,7 @@ class Adam(Optimizer): multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor'] for key in multi_tensor_list: if len(self._param_dict[key]) > 0: - if key == 'FP32_LODTensor': - self._multi_precision = False + find_master = self._multi_precision and key == 'FP16_LODTensor' _beta1 = self._beta1 if not isinstance( self._beta1, Variable) else self._beta1.numpy().item(0) @@ -571,7 +570,7 @@ class Adam(Optimizer): self._beta2_pow_acc_dict[key], self._master_weight_dict[key], 'epsilon', self._epsilon, 'beta1', _beta1, 'beta2', _beta2, 'multi_precision', - self._multi_precision) + find_master) else: inputs = { "Param": self._param_dict[key], @@ -594,11 +593,11 @@ class Adam(Optimizer): "beta1": _beta1, "beta2": _beta2 } - if self._multi_precision: + if find_master: inputs["MasterParam"] = self._master_weight_dict[key] outputs["MasterParamOut"] = self._master_weight_dict[ key] - attrs["multi_precision"] = self._multi_precision + attrs["multi_precision"] = find_master target_block.append_op( type="merged_adam", inputs=inputs, diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index ada6b06eb6..12d9fb997b 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -464,8 +464,7 @@ class Momentum(Optimizer): multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor'] for key in multi_tensor_list: if len(self._param_dict[key]) > 0: - if key == 'FP32_LODTensor': - self._multi_precision = False + find_master = self._multi_precision and key == 'FP16_LODTensor' if framework.in_dygraph_mode(): _, _, _ = _C_ops.merged_momentum( @@ -478,7 +477,7 @@ class Momentum(Optimizer): self._regularization_method_dict[key], 'regularization_coeff', self._regularization_coeff_dict[key], 'multi_precision', - self._multi_precision) + find_master) else: inputs = { "Param": self._param_dict[key], @@ -498,11 +497,11 @@ class Momentum(Optimizer): "regularization_coeff": self._regularization_coeff_dict[key], } - if self._multi_precision: + if find_master: inputs["MasterParam"] = self._master_weight_dict[key] outputs["MasterParamOut"] = self._master_weight_dict[ key] - attrs["multi_precision"] = self._multi_precision + attrs["multi_precision"] = find_master target_block.append_op( type="merged_momentum", inputs=inputs, -- GitLab