Unverified · Commit 6b0c57cf · authored by zhangbo9674 · committed by GitHub

Fix master weight bug for multi_tensor optimizer(momentum, adam) (#38991)

* fix mp

* support merged_momentum for mp
Parent c0f27282
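
Summary of the Python-side bug (taken from the diff below): the multi-tensor path flipped `self._multi_precision` to `False` whenever it reached the `FP32_LODTensor` group, so the `FP16_LODTensor` group handled later in the same loop (and every subsequent `step()`) silently stopped using master weights. The patch computes a local per-group flag instead. A minimal, self-contained sketch of the before/after gating logic; the helpers `old_gate`/`new_gate` are illustrative only and not part of the patch:

```python
# Minimal sketch of the gating logic fixed in this commit (standalone, no Paddle needed).
param_groups = ['FP32_LODTensor', 'FP16_LODTensor']

def old_gate(multi_precision):
    # Old behaviour: the optimizer-level flag was overwritten when the FP32
    # group was seen, so the FP16 group processed afterwards lost master weights.
    use_master = []
    for key in param_groups:
        if key == 'FP32_LODTensor':
            multi_precision = False  # destructive: affects later groups and later steps too
        use_master.append((key, multi_precision))
    return use_master

def new_gate(multi_precision):
    # New behaviour: a local per-group flag; only FP16 params get master weights,
    # and the optimizer-level setting is never mutated.
    return [(key, multi_precision and key == 'FP16_LODTensor')
            for key in param_groups]

print(old_gate(True))  # [('FP32_LODTensor', False), ('FP16_LODTensor', False)]  <- bug
print(new_gate(True))  # [('FP32_LODTensor', False), ('FP16_LODTensor', True)]
```

The C++ side makes the same distinction explicit: `MergedMomentumOpKernel::Compute` now dispatches to an `InnerCompute` templated on the math type, so FP16 parameters are updated through their FP32 master copies while FP32 parameters use the plain path.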
@@ -48,13 +48,13 @@ struct MergedMomentumKernelParam
   T *PADDLE_RESTRICT params[N];
   const T *PADDLE_RESTRICT grads[N];
   MT *PADDLE_RESTRICT velocitys[N];
-  const MT *PADDLE_RESTRICT lr;
+  const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
   MT mu;
   MT rescale_grad;
   uint32_t param_num;
 
   HOSTDEVICE void operator()(size_t i) const {
-    const auto lr_val = *lr;
+    const MT lr_val = static_cast<MT>(*lr);
     for (uint32_t idx = 0; idx < param_num; ++idx) {
       auto size = sizes[idx];
       if (i >= size) continue;
@@ -81,8 +81,22 @@ struct MergedMomentumKernelParam
 template <typename DeviceContext, typename T>
 class MergedMomentumOpKernel : public framework::OpKernel<T> {
+  using MPType = typename operators::details::MPTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    if (multi_precision) {
+      InnerCompute<MPType>(ctx, multi_precision);
+    } else {
+      InnerCompute<T>(ctx, multi_precision);
+    }
+  }
+
+ private:
+  template <typename MT>
+  void InnerCompute(const framework::ExecutionContext &ctx,
+                    const bool multi_precision) const {
     auto params = ctx.MultiInput<framework::Tensor>("Param");
     auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
     size_t n = params.size();
@@ -133,7 +147,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
     auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
     auto master_params_out =
         ctx.MultiOutput<framework::Tensor>("MasterParamOut");
-    auto multi_precision = ctx.Attr<bool>("multi_precision");
     if (multi_precision) {
       PADDLE_ENFORCE_EQ(
           n, master_params.size(),
@@ -206,39 +219,37 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
             << ", regularization_coeffs.size(): "
             << regularization_coeffs.size();
 
-    using MPType = typename operators::details::MPTypeTrait<T>::Type;
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     if (lrs.size() == 1 && use_nesterov == false &&
         regularization_methods.size() == 0) {
 #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
-  MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
+  MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \
   constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
   size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
-  kernel_params.mu = static_cast<MPType>(mu); \
-  kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
+  kernel_params.mu = static_cast<MT>(mu); \
+  kernel_params.rescale_grad = static_cast<MT>(rescale_grad); \
   kernel_params.lr = lrs[0]->data<MPType>(); \
   for (size_t i = 0; i < kernel_num; ++i) { \
     size_t start = i * kMaxMergedNum; \
     size_t end = std::min((i + 1) * kMaxMergedNum, n); \
     kernel_params.param_num = static_cast<uint32_t>(end - start); \
     size_t max_size = 0; \
     for (size_t j = 0; j < kernel_params.param_num; ++j) { \
       auto size = static_cast<size_t>(params_out[j + start]->numel()); \
       max_size = std::max(max_size, size); \
       kernel_params.sizes[j] = size; \
       kernel_params.params[j] = params_out[j + start]->data<T>(); \
       kernel_params.grads[j] = grads[j + start]->data<T>(); \
-      kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
+      kernel_params.velocitys[j] = velocitys_out[j + start]->data<MT>(); \
       kernel_params.SetMasterParam( \
-          j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
+          j, kMultiPrecision ? master_params_out[j + start]->data<MT>() \
                              : nullptr); \
     } \
     platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
     for_range(kernel_params); \
     VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
              << kernel_params.param_num; \
   }
       if (multi_precision) {
         PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
@@ -254,34 +265,33 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
                 ? RegularizationType::kL2DECAY
                 : RegularizationType::kNONE;
 
-        MPType regularization_coeff = static_cast<MPType>(0.0);
+        MT regularization_coeff = static_cast<MT>(0.0);
         if (regularization_coeffs.size() != 0) {
-          regularization_coeff =
-              static_cast<MPType>(regularization_coeffs[idx]);
+          regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
         }
         auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
 
-        const MPType *master_in_data =
-            multi_precision ? master_params[idx]->data<MPType>() : nullptr;
-        MPType *master_out_data =
-            multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
+        const MT *master_in_data =
+            multi_precision ? master_params[idx]->data<MT>() : nullptr;
+        MT *master_out_data =
+            multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
         if (platform::is_cpu_place(ctx.GetPlace())) {
-          CPUDenseMomentumFunctor<MPType> functor;
-          functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
-                  use_nesterov, regularization_flag, regularization_coeff,
-                  params_out[idx], velocitys_out[idx]);
+          CPUDenseMomentumFunctor<MT> functor;
+          functor(params[idx], grads[idx], velocitys[idx], lr_temp,
+                  static_cast<MT>(mu), use_nesterov, regularization_flag,
+                  regularization_coeff, params_out[idx], velocitys_out[idx]);
           VLOG(10) << "Launch MergedMomentum cpu kernel.";
         } else if (platform::is_gpu_place(ctx.GetPlace())) {
           platform::ForRange<DeviceContext> for_range(
               static_cast<const DeviceContext &>(ctx.device_context()),
               params[idx]->numel());
 #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
-  DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
+  DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
       params[idx]->data<T>(), grads[idx]->data<T>(), \
-      velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
-      mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
-      params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
-      master_out_data); \
+      velocitys[idx]->data<MT>(), lr_temp->data<MPType>(), master_in_data, \
+      static_cast<MT>(mu), static_cast<MT>(rescale_grad), \
+      params[idx]->numel(), regularization_coeff, params_out[idx]->data<T>(), \
+      velocitys_out[idx]->data<MT>(), master_out_data); \
   for_range(functor);
         if (use_nesterov) {
           if (regularization_flag == RegularizationType::kL2DECAY) {
...
@@ -551,8 +551,7 @@ class Adam(Optimizer):
         multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
         for key in multi_tensor_list:
             if len(self._param_dict[key]) > 0:
-                if key == 'FP32_LODTensor':
-                    self._multi_precision = False
+                find_master = self._multi_precision and key == 'FP16_LODTensor'
 
                 _beta1 = self._beta1 if not isinstance(
                     self._beta1, Variable) else self._beta1.numpy().item(0)
@@ -571,7 +570,7 @@ class Adam(Optimizer):
                         self._beta2_pow_acc_dict[key],
                         self._master_weight_dict[key], 'epsilon', self._epsilon,
                         'beta1', _beta1, 'beta2', _beta2, 'multi_precision',
-                        self._multi_precision)
+                        find_master)
                 else:
                     inputs = {
                         "Param": self._param_dict[key],
@@ -594,11 +593,11 @@ class Adam(Optimizer):
                         "beta1": _beta1,
                         "beta2": _beta2
                     }
-                    if self._multi_precision:
+                    if find_master:
                         inputs["MasterParam"] = self._master_weight_dict[key]
                         outputs["MasterParamOut"] = self._master_weight_dict[
                             key]
-                        attrs["multi_precision"] = self._multi_precision
+                        attrs["multi_precision"] = find_master
                     target_block.append_op(
                         type="merged_adam",
                         inputs=inputs,
...
@@ -464,8 +464,7 @@ class Momentum(Optimizer):
         multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
         for key in multi_tensor_list:
             if len(self._param_dict[key]) > 0:
-                if key == 'FP32_LODTensor':
-                    self._multi_precision = False
+                find_master = self._multi_precision and key == 'FP16_LODTensor'
 
                 if framework.in_dygraph_mode():
                     _, _, _ = _C_ops.merged_momentum(
@@ -478,7 +477,7 @@ class Momentum(Optimizer):
                         self._regularization_method_dict[key],
                         'regularization_coeff',
                         self._regularization_coeff_dict[key], 'multi_precision',
-                        self._multi_precision)
+                        find_master)
                 else:
                     inputs = {
                         "Param": self._param_dict[key],
@@ -498,11 +497,11 @@ class Momentum(Optimizer):
                         "regularization_coeff":
                         self._regularization_coeff_dict[key],
                     }
-                    if self._multi_precision:
+                    if find_master:
                         inputs["MasterParam"] = self._master_weight_dict[key]
                         outputs["MasterParamOut"] = self._master_weight_dict[
                             key]
-                        attrs["multi_precision"] = self._multi_precision
+                        attrs["multi_precision"] = find_master
                     target_block.append_op(
                         type="merged_momentum",
                         inputs=inputs,
...
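
For context, this is roughly the scenario the fix targets: an AMP O2 network whose parameters mix FP16 and FP32, stepped through the multi-tensor (merged) optimizer path with `multi_precision=True`. The sketch below uses the public Paddle APIs as I understand them for this release line (`paddle.amp.decorate`, `use_multi_tensor`); treat it as an assumed repro outline rather than a test taken from the PR:

```python
# Assumed repro outline (not from the PR): multi-tensor Momentum on mixed
# FP16/FP32 parameters. Before this commit, the FP16 group lost its FP32
# master weights because self._multi_precision was cleared mid-loop.
import paddle

paddle.set_device('gpu')  # master weights only matter for fp16 params on GPU

model = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),    # cast to fp16 by AMP O2
    paddle.nn.BatchNorm1D(16),   # norm params are kept in fp32
)
model = paddle.amp.decorate(models=model, level='O2')  # mixed fp16/fp32 params

opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    momentum=0.9,
    parameters=model.parameters(),
    multi_precision=True,   # keep fp32 master copies for fp16 params
    use_multi_tensor=True)  # exercise the merged_momentum path patched above

x = paddle.randn([4, 16])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
loss.backward()
opt.step()        # fp16 params are now updated through their fp32 masters
opt.clear_grad()
```

With the pre-fix code this script would still run, but the FP16 `Linear` weights would be updated directly in half precision, because `multi_precision` was silently disabled for their group.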