Unverified commit 6b0c57cf, authored by zhangbo9674, committed by GitHub

Fix master weight bug for multi_tensor optimizer(momentum, adam) (#38991)

* fix mp

* support merged_momentum for mp
Parent c0f27282
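For context (not part of the diff): in mixed-precision training the optimizer keeps an fp32 "master" copy of each fp16 parameter, runs the update in fp32, and casts the result back to fp16; the bug fixed here made the multi_tensor (merged) momentum/adam paths lose that master copy in some configurations. A minimal NumPy sketch of the idea, with hypothetical names rather than Paddle's actual API:

```python
import numpy as np

def momentum_step_with_master(param_fp16, grad_fp16, velocity, master_fp32,
                              lr=0.1, mu=0.9):
    """One SGD-momentum step done in fp32 against the master weight."""
    g = grad_fp16.astype(np.float32)                 # upcast the fp16 gradient
    velocity[:] = mu * velocity + g                  # velocity is kept in fp32
    master_fp32[:] = master_fp32 - lr * velocity     # update the fp32 master copy
    param_fp16[:] = master_fp32.astype(np.float16)   # write back the fp16 view
    return param_fp16, velocity, master_fp32

p = np.full(4, 0.5, dtype=np.float16)
m = p.astype(np.float32)                 # master weight starts as an fp32 copy
v = np.zeros(4, dtype=np.float32)
g = np.full(4, 1e-4, dtype=np.float16)
momentum_step_with_master(p, g, v, m)    # the tiny update survives in the fp32 master
```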
@@ -48,13 +48,13 @@ struct MergedMomentumKernelParam
T *PADDLE_RESTRICT params[N];
const T *PADDLE_RESTRICT grads[N];
MT *PADDLE_RESTRICT velocitys[N];
const MT *PADDLE_RESTRICT lr;
const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
MT mu;
MT rescale_grad;
uint32_t param_num;
HOSTDEVICE void operator()(size_t i) const {
const auto lr_val = *lr;
const MT lr_val = static_cast<MT>(*lr);
for (uint32_t idx = 0; idx < param_num; ++idx) {
auto size = sizes[idx];
if (i >= size) continue;
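The change above reflects that the learning-rate tensor stays in the multi-precision (fp32) type even when the kernel's math type MT is fp16, so the pointer is typed MultiPrecisionType<MT> and the loaded value is cast once per invocation. A rough NumPy illustration of the same convention (assumed names, not the actual kernel):

```python
import numpy as np

def load_lr(lr_tensor, math_dtype):
    # The learning-rate tensor is always stored as float32; cast it once
    # to whatever dtype the update math runs in (fp16 or fp32).
    assert lr_tensor.dtype == np.float32
    return math_dtype(lr_tensor[0])

lr = np.array([0.001], dtype=np.float32)
print(load_lr(lr, np.float16))  # fp16 math path
print(load_lr(lr, np.float32))  # fp32 / multi-precision math path
```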
@@ -81,8 +81,22 @@ struct MergedMomentumKernelParam
template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
using MPType = typename operators::details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPType>(ctx, multi_precision);
} else {
InnerCompute<T>(ctx, multi_precision);
}
}
private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext &ctx,
const bool multi_precision) const {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
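The dispatch above fixes the math type once per Compute call: with multi_precision=True the whole update runs in MPType (fp32) and the fp16 parameters are only a storage view, otherwise the math type is simply T. A small illustrative sketch of that decision in Python:

```python
import numpy as np

def math_dtype(param_dtype, multi_precision):
    """Pick the dtype the optimizer math runs in (mirrors the InnerCompute<MT> choice)."""
    return np.float32 if multi_precision else param_dtype

assert math_dtype(np.float16, multi_precision=True) is np.float32
assert math_dtype(np.float16, multi_precision=False) is np.float16
assert math_dtype(np.float32, multi_precision=False) is np.float32
```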
@@ -133,7 +147,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
@@ -206,39 +219,37 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();
using MPType = typename operators::details::MPTypeTrait<T>::Type;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MT>(mu); \
kernel_params.rescale_grad = static_cast<MT>(rescale_grad); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MT>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MT>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
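The launch macro splits the n merged parameters into chunks of at most kMaxMergedNum tensors and runs one ForRange over the largest tensor in each chunk; threads whose index exceeds a tensor's size skip it via the `if (i >= size) continue;` in the functor above. The chunking arithmetic, restated in Python for clarity (k_max_merged is a placeholder for the struct's compile-time N):

```python
def plan_chunks(num_params, k_max_merged):
    """Ceil-divide the parameter list into chunks of at most k_max_merged tensors."""
    kernel_num = (num_params + k_max_merged - 1) // k_max_merged
    chunks = []
    for i in range(kernel_num):
        start = i * k_max_merged
        end = min((i + 1) * k_max_merged, num_params)
        chunks.append((start, end))
    return chunks

print(plan_chunks(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```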
@@ -254,34 +265,33 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;
MPType regularization_coeff = static_cast<MPType>(0.0);
MT regularization_coeff = static_cast<MT>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff =
static_cast<MPType>(regularization_coeffs[idx]);
regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];
const MPType *master_in_data =
multi_precision ? master_params[idx]->data<MPType>() : nullptr;
MPType *master_out_data =
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
const MT *master_in_data =
multi_precision ? master_params[idx]->data<MT>() : nullptr;
MT *master_out_data =
multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<MPType> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
use_nesterov, regularization_flag, regularization_coeff,
params_out[idx], velocitys_out[idx]);
CPUDenseMomentumFunctor<MT> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp,
static_cast<MT>(mu), use_nesterov, regularization_flag,
regularization_coeff, params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
master_out_data); \
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MT>(), lr_temp->data<MPType>(), master_in_data, \
static_cast<MT>(mu), static_cast<MT>(rescale_grad), \
params[idx]->numel(), regularization_coeff, params_out[idx]->data<T>(), \
velocitys_out[idx]->data<MT>(), master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
......
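For the per-parameter fallback path above, DenseMomentumFunctor applies the usual momentum update; with L2 decay the regularization coefficient is folded into the gradient before the velocity update. A NumPy sketch of the standard formula (not Paddle's exact functor):

```python
import numpy as np

def momentum_l2_step(p, g, v, lr, mu=0.9, coeff=1e-4, use_nesterov=False):
    """Momentum with L2 decay: fold coeff*p into the gradient, then update."""
    g = g + coeff * p              # RegularizationType::kL2DECAY
    v = mu * v + g                 # velocity update
    if use_nesterov:
        p = p - lr * (g + mu * v)  # Nesterov look-ahead form
    else:
        p = p - lr * v
    return p, v

p, g, v = np.ones(3), np.full(3, 0.1), np.zeros(3)
print(momentum_l2_step(p, g, v, lr=0.01))
```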
@@ -551,8 +551,7 @@ class Adam(Optimizer):
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
if key == 'FP32_LODTensor':
self._multi_precision = False
find_master = self._multi_precision and key == 'FP16_LODTensor'
_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
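This is the core of the Python-side fix: the old code set self._multi_precision = False when it reached the FP32 group, so an FP16 group handled later in the same step silently lost its master weights. The new code derives a local find_master flag per group and leaves the optimizer-level setting alone. Roughly, in a simplified, hypothetical form:

```python
def plan_master_weight_usage(multi_precision,
                             param_groups=('FP32_LODTensor', 'FP16_LODTensor')):
    """Return which grouped-tensor launches should carry MasterParam."""
    plan = {}
    for key in param_groups:
        # Decide per group; never flip the optimizer-wide flag as the old code did.
        find_master = multi_precision and key == 'FP16_LODTensor'
        plan[key] = find_master
    return plan

print(plan_master_weight_usage(True))
# {'FP32_LODTensor': False, 'FP16_LODTensor': True}  -- FP16 group keeps its master weights
```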
@@ -571,7 +570,7 @@ class Adam(Optimizer):
self._beta2_pow_acc_dict[key],
self._master_weight_dict[key], 'epsilon', self._epsilon,
'beta1', _beta1, 'beta2', _beta2, 'multi_precision',
self._multi_precision)
find_master)
else:
inputs = {
"Param": self._param_dict[key],
@@ -594,11 +593,11 @@ class Adam(Optimizer):
"beta1": _beta1,
"beta2": _beta2
}
if self._multi_precision:
if find_master:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
attrs["multi_precision"] = find_master
target_block.append_op(
type="merged_adam",
inputs=inputs,
......
@@ -464,8 +464,7 @@ class Momentum(Optimizer):
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
if key == 'FP32_LODTensor':
self._multi_precision = False
find_master = self._multi_precision and key == 'FP16_LODTensor'
if framework.in_dygraph_mode():
_, _, _ = _C_ops.merged_momentum(
@@ -478,7 +477,7 @@ class Momentum(Optimizer):
self._regularization_method_dict[key],
'regularization_coeff',
self._regularization_coeff_dict[key], 'multi_precision',
self._multi_precision)
find_master)
else:
inputs = {
"Param": self._param_dict[key],
@@ -498,11 +497,11 @@ class Momentum(Optimizer):
"regularization_coeff":
self._regularization_coeff_dict[key],
}
if self._multi_precision:
if find_master:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
attrs["multi_precision"] = find_master
target_block.append_op(
type="merged_momentum",
inputs=inputs,
......
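From the user's side, the fixed behaviour matters when both flags are enabled on the optimizer. A hedged usage sketch (keyword arguments as they existed around this Paddle development version; verify against your installed release):

```python
import paddle

model = paddle.nn.Linear(10, 10)
# multi_precision keeps fp32 master weights for fp16 parameters;
# use_multi_tensor routes the update through the merged (multi_tensor)
# kernels that this commit fixes.
opt = paddle.optimizer.Momentum(
    learning_rate=0.1,
    parameters=model.parameters(),
    momentum=0.9,
    multi_precision=True,
    use_multi_tensor=True)
```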