Unverified commit 5df3cd61, authored by S sneaxiy, committed by GitHub

Add the DistributedFusedLamb optimizer (#39148)

* add DistributedFusedLamb op

* polish code

* fix compile error

* compatible with pten changes

* fix rocm compile error

* improve coverage

* update upstream/develop

* fix cast_with_ptr.h

* add FLAGS_distributed_lamb_divide_nranks_when_allreduce=1

* fix clip before allreduce

* add use_master_param_norm

* code polish

* fix bug

* fix ROCM ci
Parent 7fc04070
......@@ -51,6 +51,8 @@ class Buffer {
size_t Size() const { return allocation_ ? allocation_->size() : 0; }
platform::Place GetPlace() const { return place_; }
private:
AllocationPtr allocation_;
platform::Place place_;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
namespace paddle {
namespace operators {
namespace details {
template <typename InT, typename OutT>
struct CastFunctor {
HOSTDEVICE OutT operator()(InT x) const { return static_cast<OutT>(x); }
};
template <typename InT, typename OutT, int VecSize>
static void VecCastKernel(const platform::CUDADeviceContext &ctx, const InT *x,
OutT *y, size_t n) {
auto config = platform::GetGpuLaunchConfig1D(ctx, n, VecSize);
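  // "block" below is the grid size and "thread" the threads per block chosen
  // by GetGpuLaunchConfig1D; main_offset is the largest multiple of
  // (VecSize * thread) not exceeding n, i.e. the portion of the data that can
  // be processed with fully vectorized loads/stores before the remaining tail
  // is handled with boundary checks.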
auto block = config.GetGridSize();
auto thread = config.GetBlockSize();
auto main_offset = n / (VecSize * thread) * VecSize * thread;
auto stream = ctx.stream();
using FunctorT = CastFunctor<InT, OutT>;
pten::framework::Array<const _ptr_ char *__restrict__, 1> in_arr;
in_arr[0] = reinterpret_cast<const _ptr_ char *>(x);
pten::framework::Array<_ptr_ OutT *, 1> out_arr;
out_arr[0] = y;
pten::funcs::VectorizedElementwiseKernel<
OutT, FunctorT, 1, 1, VecSize><<<block, thread, 0, stream>>>(
in_arr, out_arr, n, main_offset, FunctorT());
}
} // namespace details
template <typename InT, typename OutT>
static void LaunchCastKernel(const platform::CUDADeviceContext &ctx,
const InT *x, OutT *y, size_t n) {
if (n == 0) return;
PADDLE_ENFORCE_NE(
static_cast<const void *>(x), static_cast<void *>(y),
platform::errors::InvalidArgument("Inplace cast is not supported yet."));
int vec_size =
std::min(platform::GetVectorizedSize(x), platform::GetVectorizedSize(y));
switch (vec_size) {
case 4:
return details::VecCastKernel<InT, OutT, 4>(ctx, x, y, n);
case 2:
return details::VecCastKernel<InT, OutT, 2>(ctx, x, y, n);
case 1:
return details::VecCastKernel<InT, OutT, 1>(ctx, x, y, n);
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The vectorized size must be 1, 2 or 4."));
}
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"
namespace paddle {
namespace operators {
class DistributedFusedLambInitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto dtype = framework::proto::VarType::FP32; // dtype is not important
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
class DistributedFusedLambInitOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "The initial parameter list.").AsDuplicable();
AddInput("Grad", "The initial gradient list.").AsDuplicable();
AddOutput("FP32FusedParam",
"The fp32 fused param and fp16 fused master weight tensor. Its "
"shape is [M1+M2], where M1 is the fp32 fused parameter size and "
"M2 is the fp16 fused master weight parameter size. Note that M1 "
"and M2 should be exactly divided by N (guaranteed by extra "
"padding 0), where N is the world size.")
.AsDispensable();
AddOutput("FP32FusedGrad", "The fp32 fused grad tensor. Its shape is [M1].")
.AsDispensable();
AddOutput("FP16FusedParam",
"The fp16 fused param tensor. Its shape is [M2].")
.AsDispensable();
AddOutput("FP16FusedGrad", "The fp16 fused grad tensor. Its shape is [M2].")
.AsDispensable();
AddOutput("Moment1",
"The sharded fp32 moment1 tensor. Its shape is [(M1+M2)/N].");
AddOutput("Moment2",
"The sharded fp32 moment2 tensor. Its shape is [(M1+M2)/N].");
AddOutput("Beta1Pow",
"The fp32 beta1 power accumulator tensor. Its shape is [1].");
AddOutput("Beta2Pow",
"The fp32 beta2 power accumulator tensor. Its shape is [1].");
AddOutput("FusedIndices",
"The param index of each element in FP32FusedParam. Its shape is "
"[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
AddOutput(
"FusedParamOffsets",
"The numel offset of each parameter inside the FP32FusedParam. Its "
"shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
"+ n_2, ...].");
AddOutput("FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1].");
AddOutput("FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1].");
AddOutput(
"WeightDecay",
"The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
AddOutput("ParamInfo",
"The param info. It should be in CPUPlace, and its shape is [6]"
"CPUPlace, and its shape is [6]. It is "
"[fp32_shard_param_start_idx, fp32_local_param_num, "
"fp32_global_param_num, fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num].");
AddOutput("ParamOut", "The output parameter list.").AsDuplicable();
AddOutput("MasterParamOut",
"The output master parameter list. It would share the memory of "
"each fp32 parameter and fp16 master parameter.")
.AsDuplicable();
AddOutput("GradOut", "The output gradient list.").AsDuplicable();
AddOutput("GlobalScale",
"The global scale. It is usually the scale factor for AMP.");
AddAttr<float>("beta1", "The initial value of Beta1Pow.");
AddAttr<float>("beta2", "The initial value of Beta2Pow.");
AddAttr<std::vector<float>>(
"weight_decay",
"The weight decay for each parameter. Its "
"shape is equal to the global parameter number.");
AddAttr<int>("alignment", "The alignment in bytes for the fused tensors.");
AddAttr<int>("rank", "The global rank of the current process.");
AddAttr<int>("nranks", "The global world size.");
AddComment(
R"DOC(The init operator for the DistributedFusedLamb optimizer.)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb_init,
ops::DistributedFusedLambInitOp,
ops::DistributedFusedLambInitOpMaker);
REGISTER_OP_CPU_KERNEL(
distributed_fused_lamb_init,
ops::DistributedFusedLambInitOpKernel<plat::CPUDeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
template <typename DevCtx, typename T>
class DistributedFusedLambInitOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"The distributed_fused_lamb_init operator does not support CPU yet."));
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
namespace paddle {
namespace operators {
class DistributedFusedLambOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto dtype = framework::proto::VarType::FP32; // dtype is not important
return framework::OpKernelType(dtype, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "ParamInfo") {
return expected_kernel_type;
} else {
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
}
};
class DistributedFusedLambOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "The initial parameter list.").AsDuplicable();
AddInput("Grad", "The initial gradient list.").AsDuplicable();
AddInput("FP32FusedParam",
"The fp32 fused param and fp16 fused master weight tensor. Its "
"shape is [M1+M2], where M1 is the fp32 fused parameter size and "
"M2 is the fp16 fused master weight parameter size. Note that M1 "
"and M2 should be exactly divided by N (guaranteed by extra "
"padding 0), where N is the world size.")
.AsDispensable();
AddInput("FP32FusedGrad", "The fp32 fused grad tensor. Its shape is [M1].")
.AsDispensable();
AddInput("FP16FusedParam",
"The fp16 fused param tensor. Its shape is [M2].")
.AsDispensable();
AddInput("FP16FusedGrad", "The fp16 fused grad tensor. Its shape is [M2].")
.AsDispensable();
AddInput("Moment1",
"The sharded fp32 moment1 tensor. Its shape is [(M1+M2)/N].");
AddInput("Moment2",
"The sharded fp32 moment2 tensor. Its shape is [(M1+M2)/N].");
AddInput("Beta1Pow",
"The fp32 beta1 power accumulator tensor. Its shape is [1].");
AddInput("Beta2Pow",
"The fp32 beta2 power accumulator tensor. Its shape is [1].");
AddInput("FusedIndices",
"The param index of each element in FP32FusedParam. Its shape is "
"[M1+M2]. It is like [0,0,0,1,1,1,1,2,2,...].");
AddInput(
"FusedParamOffsets",
"The numel offset of each parameter inside the FP32FusedParam. Its "
"shape is [param_num + 1]. It is like [0, n_0, n_0 + n_1, n_0 + n_1 "
"+ n_2, ...].");
AddInput("FP32ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp32_local_param_num + 1].");
AddInput("FP16ShardFusedParamOffsets",
"The sharded numel offset of each parameter in the local rank. "
"Its shape is [fp16_local_param_num + 1].");
AddInput("WeightDecay",
"The sharded fp32 weight decay tensor. Its shape is [(M1+M2)/N].");
AddInput("ParamInfo",
"The param info. It should be in CPUPlace, and its shape is [6]"
"CPUPlace, and its shape is [6]. It is "
"[fp32_shard_param_start_idx, fp32_local_param_num, "
"fp32_global_param_num, fp16_shard_param_start_idx, "
"fp16_local_param_num, fp16_global_param_num].");
AddInput("LearningRate",
"The fp32 learning rate tensor. Its shape is [1].");
AddInput("GlobalScale", "The fp32 global scale tensor. Its shape is [1].");
AddOutput("FP32FusedParamOut", "The updated FP32FusedParam.")
.AsDispensable();
AddOutput("FP16FusedParamOut", "The updated FP16FusedParam.")
.AsDispensable();
AddOutput("Moment1Out", "The updated Moment1.");
AddOutput("Moment2Out", "The updated Moment2.");
AddOutput("Beta1PowOut", "The updated Beta1Pow.");
AddOutput("Beta2PowOut", "The updated Beta2Pow.");
AddOutput("ParamOut", "The updated output parameter tensor list.")
.AsDuplicable();
AddOutput("FoundInf", "Whether there is NaN/Inf");
AddAttr<float>("beta1", "The initial Beta1Pow value.");
AddAttr<float>("beta2", "The initial Beta2Pow value.");
AddAttr<float>("epsilon",
"The epsilon value to maintain numeric stability.");
AddAttr<float>(
"max_global_grad_norm",
"The maximum global gradient l2-norm value for clipping. If "
"max_global_grad_norm <= 0, no clipping would be performed.");
AddAttr<bool>("clip_after_allreduce",
"Whether to clip before allreduce, only valid when the "
"world size is larger than 1.");
AddAttr<bool>(
"use_master_param_norm",
"Whether to use master parameter to calculate "
"the L2-Norm. If it is true, it would be more accurate but be more "
"NCCL communication data. If it is false, it would be less accurate "
"and be less NCCL communication data.")
.SetDefault(true);
AddAttr<bool>("is_grad_scaled_by_nranks",
"Whether the input gradient has been scaled by nranks.")
.SetDefault(true);
AddAttr<int>("ring_id", "The ring id of the NCCL communicator.")
.SetDefault(0);
AddComment("The DistributedFusedLamb optimizer.");
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
ops::DistributedFusedLambOp,
ops::DistributedFusedLambOpMaker);
REGISTER_OP_CPU_KERNEL(
distributed_fused_lamb,
ops::DistributedFusedLambOpKernel<plat::CPUDeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename DevCtx, typename T>
class DistributedFusedLambOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"The distributed_fused_lamb operator does not support CPU yet."));
}
};
} // namespace operators
} // namespace paddle
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/squared_l2_norm.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/pten/kernels/funcs/algorithm.h"
#include "paddle/pten/kernels/funcs/eigen/extensions.h"
......@@ -658,6 +659,16 @@ class LambOpKernel : public framework::OpKernel<T> {
math::SquaredL2Norm(dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr,
numel, &buffer);
if (VLOG_IS_ON(1)) {
const auto& name = ctx.GetOp().Input("Param");
auto pn = ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
auto tn = ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
auto dtype =
framework::DataTypeToString(framework::DataTypeTrait<T>::DataType());
VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
<< " , tn = " << tn[0];
}
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sstream>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace operators {
template <typename T>
static const std::vector<T> &ToVector(const std::vector<T> &vec) {
return vec;
}
template <typename T>
static std::vector<T> ToVector(const T *x, size_t n,
const platform::Place &place) {
#ifdef __NVCC__
if (platform::is_gpu_place(place)) {
using CopyT = typename std::conditional<std::is_same<T, bool>::value,
uint8_t, T>::type;
std::vector<CopyT> cpu_x(n);
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
memory::Copy(platform::CPUPlace(), cpu_x.data(), place, x, n * sizeof(T),
dev_ctx->stream());
dev_ctx->Wait();
return std::vector<T>(cpu_x.data(), cpu_x.data() + n);
}
#endif
return std::vector<T>(x, x + n);
}
template <typename T>
static std::vector<T> ToVector(const framework::Tensor &src) {
if (!src.IsInitialized()) {
return {};
}
return ToVector(src.template data<T>(), src.numel(), src.place());
}
template <typename... Args>
static std::string FlattenToString(Args &&... args) {
const auto &vec = ToVector(std::forward<Args>(args)...);
return "[" + string::join_strings(vec, ',') + "]";
}
} // namespace operators
} // namespace paddle
......@@ -36,12 +36,35 @@ __all__ = [
'ClipGradByNorm', 'ClipGradByGlobalNorm'
]
_clip_by_global_norm_using_mp_type_flag = False
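# _clip_by_global_norm_using_mp_type() below acts as a combined getter/setter
# for this module-level flag: with no argument it returns the current value;
# with a single bool it installs the new value and returns the previous one.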
def _clip_by_global_norm_using_mp_type(*args):
global _clip_by_global_norm_using_mp_type_flag
assert len(args) <= 1
if len(args) == 1:
assert isinstance(args[0], bool)
old_value = _clip_by_global_norm_using_mp_type_flag
_clip_by_global_norm_using_mp_type_flag = args[0]
return old_value
else:
return _clip_by_global_norm_using_mp_type_flag
def _cast_to_mp_type_if_enabled(x):
if x.dtype == core.VarDesc.VarType.FP16 and _clip_by_global_norm_using_mp_type(
):
return x.astype(core.VarDesc.VarType.FP32)
else:
return x
def _squared_l2_norm(x):
r"""
This OP returns the squared L2 norm of a tensor.
"""
x = _cast_to_mp_type_if_enabled(x)
if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
square = layers.square(x)
sum_square = layers.reduce_sum(square)
......@@ -595,9 +618,10 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue
with p.block.program._optimized_guard([p, g]):
new_g = _cast_to_mp_type_if_enabled(g)
# inplace
scale_input = (scale_var.astype('float16')
if g.dtype == core.VarDesc.VarType.FP16 and
scale_input = (scale_var.astype('float16') if
new_g.dtype == core.VarDesc.VarType.FP16 and
scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
......@@ -607,9 +631,18 @@ class ClipGradByGlobalNorm(ClipGradBase):
block = default_main_program().current_block()
block.append_op(
type='elementwise_mul',
inputs={'X': g,
inputs={'X': new_g,
'Y': scale_input},
outputs={'Out': g})
outputs={'Out': new_g})
if new_g is not g:
block.append_op(
type='cast',
inputs={'X': new_g},
outputs={'Out': g},
attrs={
'in_dtype': new_g.dtype,
'out_dtype': g.dtype
})
param_new_grad_name_dict[p.name] = g.name
params_and_grads.append((p, g))
......
......@@ -108,6 +108,9 @@ class OptimizerWithMixedPrecision(object):
"""
return self._scaled_loss
def _supports_check_nan_inf(self):
return getattr(self._optimizer, "_supports_check_nan_inf", False)
def _init_amp_var(self):
self._loss_scaling = layers.create_global_var(
name=unique_name.generate("loss_scaling"),
......@@ -202,8 +205,34 @@ class OptimizerWithMixedPrecision(object):
params_grads = self._optimizer.backward(
self._scaled_loss, startup_program, parameter_list, no_grad_set,
callbacks)
if self._supports_check_nan_inf():
self._add_cast_ops_to_startup_program(startup_program)
return params_grads
def _add_cast_ops_to_startup_program(self, startup_program):
names = list(self._to_fp16_var_names) if self._to_fp16_var_names else []
names.sort()
startup_program = default_startup_program(
) if startup_program is None else startup_program
block = startup_program.global_block()
param_names = [p.name for p in block.all_parameters()]
for name in names:
if name not in param_names:
continue
tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
block.append_op(
type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]})
block.append_op(
type='cast',
inputs={'X': [tmp]},
outputs={'Out': [name]},
attrs={
'in_dtype': core.VarDesc.VarType.FP32,
'out_dtype': core.VarDesc.VarType.FP16,
})
self._to_fp16_var_names = None
def amp_init(self,
place,
scope=None,
......@@ -297,13 +326,47 @@ class OptimizerWithMixedPrecision(object):
if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0:
return self._optimizer.apply_gradients(params_grads)
if self._supports_check_nan_inf():
self._optimizer._set_scale(self._loss_scaling)
optimize_ops = self._optimizer.apply_gradients(params_grads)
found_inf = self._optimizer._found_inf
self._add_dynamic_loss_scaling(params_grads, found_inf)
return optimize_ops
found_inf = self._check_finite_and_unscale(params_grads)
if self._use_dynamic_loss_scaling:
self._add_dynamic_loss_scaling(params_grads, found_inf)
# Pass found_inf to adam so that the update is skipped not only for the parameters but also for the moments and beta_pow accumulators.
# With fleet, optimizers are nested and the real optimizer set by the user is the innermost one.
real_optimizer = self._optimizer
while hasattr(real_optimizer, "inner_opt"):
real_optimizer = real_optimizer.inner_opt
if isinstance(real_optimizer, (paddle.fluid.optimizer.Adam,
paddle.optimizer.AdamW)):
# NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
# copy it in advance to avoid copying it multiple times.
with self._train_program._optimized_guard([]):
found_inf = paddle.tensor.creation._memcpy(found_inf,
paddle.CPUPlace())
real_optimizer._set_auxiliary_var('found_inf', found_inf)
elif hasattr(real_optimizer, "_set_auxiliary_var"):
real_optimizer._set_auxiliary_var('found_inf', found_inf)
optimize_ops = self._optimizer.apply_gradients(params_grads)
return optimize_ops
def _split_grads(self, params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
assert len(fp32_grads) + len(fp16_grads) == len(grads), \
"Data types of all grads must be either fp16 or fp32."
return grads, fp32_grads, fp16_grads
def _check_finite_and_unscale(self, params_grads):
grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
found_infs = []
if self._is_distributed:
# if distributed, split check_finite_and_unscale to overlap
# unscale with communication
......@@ -349,46 +412,37 @@ class OptimizerWithMixedPrecision(object):
name="find_infinite_scale",
float_status=self._float_status)
if self._use_dynamic_loss_scaling:
if self._is_distributed or self._use_pure_fp16:
with self._train_program._optimized_guard([]):
all_infs = layers.concat(found_infs)
found_inf = layers.reduce_any(all_infs)
if self._is_distributed or self._use_pure_fp16:
with self._train_program._optimized_guard([]):
all_infs = layers.concat(found_infs)
found_inf = layers.reduce_any(all_infs)
if self._use_pure_fp16:
stop_update = False
with self._train_program._optimized_guard([]):
if fp32_grads:
update_loss_scaling(
fp32_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp32")
stop_update = True
if fp16_grads:
update_loss_scaling(
fp16_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp16")
else:
with self._train_program._optimized_guard([]):
return found_inf
def _add_dynamic_loss_scaling(self, params_grads, found_inf):
if self._supports_check_nan_inf():
with self._train_program._optimized_guard([]):
update_loss_scaling(
[],
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=False,
name="update_loss_scaling")
return
grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
if self._use_pure_fp16:
stop_update = False
with self._train_program._optimized_guard([]):
if fp32_grads:
update_loss_scaling(
grads,
fp32_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
......@@ -397,24 +451,35 @@ class OptimizerWithMixedPrecision(object):
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
# With fleet, optimizers are nested and the real optimizer set by user is the inner most one.
real_optimizer = self._optimizer
while hasattr(real_optimizer, "inner_opt"):
real_optimizer = real_optimizer.inner_opt
if isinstance(real_optimizer, (paddle.fluid.optimizer.Adam,
paddle.optimizer.AdamW)):
# NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
# copy it in advance to avoid multiple time copies.
stop_update=stop_update,
name="update_loss_scaling_fp32")
stop_update = True
if fp16_grads:
update_loss_scaling(
fp16_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp16")
else:
with self._train_program._optimized_guard([]):
found_inf = paddle.tensor.creation._memcpy(found_inf,
paddle.CPUPlace())
real_optimizer._set_auxiliary_var('found_inf', found_inf)
elif hasattr(real_optimizer, "_set_auxiliary_var"):
real_optimizer._set_auxiliary_var('found_inf', found_inf)
optimize_ops = self._optimizer.apply_gradients(params_grads)
return optimize_ops
update_loss_scaling(
grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
def apply_optimize(self, loss, startup_program, params_grads):
program = loss.block.program
......
......@@ -846,6 +846,8 @@ set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inp
test_parallel_executor_seresnext_base_gpu
test_parallel_executor_seresnext_with_reduce_gpu
test_parallel_executor_seresnext_with_fuse_all_reduce_gpu
test_distributed_fused_lamb_op_with_clip
test_distributed_fused_lamb_op_without_clip
test_parallel_executor_fetch_isolated_var
PROPERTIES LABELS "RUN_TYPE=DIST")
......@@ -974,6 +976,8 @@ set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 120)
set_tests_properties(test_elementwise_sub_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120)
set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT 120)
set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120)
set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120)
set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle
import paddle.fluid.core as core
import paddle.distributed.fleet as fleet
from paddle.incubate import DistributedFusedLamb
from paddle.vision.models import resnet18 as resnet
from paddle.distributed.fleet.meta_optimizers.common import CollectiveHelper
from paddle.fluid.clip import ClipGradBase
import paddle.nn as nn
import numpy as np
import os
import unittest
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, is_backward_op
from paddle.fluid.clip import _clip_by_global_norm_using_mp_type
import distutils
def get_role_maker():
return fleet.PaddleCloudRoleMaker(is_collective=True)
def set_seed(seed):
paddle.seed(seed)
rank = paddle.distributed.get_rank()
np_seed = seed + rank
np.random.seed(np_seed)
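# Mark every parameter gradient variable as persistable and collect the
# (params, grads) lists; the test feeds precomputed gradient tensors for these
# names directly, so the gradient variables must stay alive in the scope.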
def set_gradient_persistable(program):
block = program.global_block()
params = []
grads = []
for p in block.all_parameters():
p_name = p.name
g_name = p_name + '@GRAD'
g = block.vars.get(g_name)
if g is None:
continue
g.persistable = True
params.append(p)
grads.append(g)
return params, grads
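# Remove all forward/backward ops in [0, start_idx), together with any
# variable that is no longer referenced by the remaining (optimizer) ops, so
# that only the optimizer part of the program runs when gradients are fed
# directly.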
def prune_fwd_bwd_ops(program, start_idx):
for i in reversed(range(start_idx)):
program.global_block()._remove_op(i, sync=False)
program._sync_with_cpp()
ops = program.global_block().ops
all_vars = set(program.global_block().vars.keys())
for op in ops:
args = op.input_arg_names + op.output_arg_names
for arg in args:
if arg in all_vars:
all_vars.remove(arg)
for var in all_vars:
program.global_block()._remove_var(var)
program._sync_with_cpp()
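# GradClipDecorator is applied to the reference (non-fused) optimizer: it
# inserts a c_allreduce_sum plus a scale(1/world_size) over each gradient and
# applies the wrapped clip either before or after the allreduce, mirroring the
# clip_after_allreduce behavior of DistributedFusedLamb.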
class GradClipDecorator(ClipGradBase):
def __init__(self, clip, clip_after_allreduce):
self.clip = clip
self.clip_after_allreduce = clip_after_allreduce
def _dygraph_clip(self, params_grads):
raise NotImplementedError()
def _insert_allreduce_ops(self, params_grads):
world_size = paddle.distributed.get_world_size()
if world_size == 1:
return
block = params_grads[0][0].block
scale = 1.0 / world_size
# scale = 1.0
for p, g in params_grads:
block.append_op(
type='c_allreduce_sum',
inputs={'X': [g]},
outputs={'Out': [g]},
attrs={'ring_id': 0,
'use_calc_stream': True})
block.append_op(
type='scale',
inputs={'X': [g]},
outputs={'Out': [g]},
attrs={'scale': scale})
def _static_clip(self, params_grads):
if self.clip_after_allreduce:
self._insert_allreduce_ops(params_grads)
params_grads = self.clip(params_grads)
if not self.clip_after_allreduce:
self._insert_allreduce_ops(params_grads)
return params_grads
class IdentityGradClip(ClipGradBase):
def _dygraph_clip(self, params_grads):
return params_grads
def _static_clip(self, params_grads):
return params_grads
def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
nranks = paddle.distributed.get_world_size()
set_seed(1000)
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
with paddle.fluid.unique_name.guard():
with paddle.static.amp.fp16_guard():
image = paddle.static.data(
name='image',
shape=[None, 3, 224, 224],
dtype=paddle.float32)
label = paddle.static.data(
name='label', shape=[None, 1], dtype=paddle.int64)
model = resnet()
pred = model(image)
loss_fn = paddle.nn.loss.CrossEntropyLoss()
loss = loss_fn(pred, label)
grad_clip = kwargs.get('grad_clip', None)
clip_after_allreduce = kwargs.get('clip_after_allreduce', True)
if use_distributed_lamb:
optimizer_class = DistributedFusedLamb
kwargs = dict(kwargs)
kwargs['is_grad_scaled_by_nranks'] = False
kwargs['use_master_param_norm'] = use_master_param_norm
else:
optimizer_class = paddle.optimizer.Lamb
kwargs = dict(kwargs)
kwargs.pop('clip_after_allreduce', None)
kwargs.pop('alignment', None)
base_clip = grad_clip if grad_clip is not None else IdentityGradClip(
)
kwargs['grad_clip'] = GradClipDecorator(base_clip,
clip_after_allreduce)
optimizer = optimizer_class(**kwargs)
get_parameter = optimizer._get_parameter
amp_list = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=[
'batch_norm', 'batch_norm_grad', 'conv2d', 'conv2d_grad'
])
if use_fp16:
if not use_distributed_lamb:
optimizer._multi_precision = True
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=1.0,
use_dynamic_loss_scaling=False,
use_pure_fp16=use_fp16,
use_fp16_guard=use_fp16)
params_grads = optimizer.backward(loss, startup)
op_num = len(main.global_block().ops)
if use_fp16:
optimizer.apply_optimize(loss, startup, params_grads)
else:
optimizer.apply_gradients(params_grads)
if nranks > 1:
collective_helper = CollectiveHelper(role_maker=get_role_maker())
collective_helper.update_startup_program(startup)
set_gradient_persistable(startup)
params, grads = set_gradient_persistable(main)
prune_fwd_bwd_ops(main, op_num)
def pd_dtype_to_np_dtype(pd_dtype):
if pd_dtype == paddle.float32:
return np.float32
elif pd_dtype == paddle.float16:
return np.float16
else:
raise ValueError("supported dtype {}".format(pd_dtype))
def gen_random_grad_tensor(grad):
np_dtype = pd_dtype_to_np_dtype(grad.dtype)
grad_np = np.random.random(size=grad.shape).astype(np_dtype)
grad_t = core.Tensor()
grad_t.set(grad_np, paddle.CPUPlace())
return grad_t
def reader():
for _ in range(5):
yield dict(
[(grad.name, gen_random_grad_tensor(grad)) for grad in grads])
scope = paddle.static.Scope()
fetch_list = params
fetches = None
with paddle.static.scope_guard(scope):
dev_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = paddle.CUDAPlace(dev_id)
exe = paddle.static.Executor(place)
exe.run(startup)
if use_fp16:
optimizer.amp_init(place)
master_p_ts = []
for p in params:
p_ts = get_parameter(p.name)
assert len(p_ts) == 2
if p_ts[1] is not None:
master_p_ts.append(p_ts[1])
if use_fp16:
assert len(master_p_ts) > 0
else:
assert len(master_p_ts) == 0
for feed in reader():
fetches = exe.run(main, feed=feed, fetch_list=fetch_list)
return fetches
class TestDistributedFusedLamb(unittest.TestCase):
@classmethod
def setUpClass(cls):
if not paddle.is_compiled_with_cuda():
return
paddle.enable_static()
paddle.set_flags({'FLAGS_cudnn_deterministic': True})
_clip_by_global_norm_using_mp_type(True)
fleet.init(role_maker=get_role_maker())
def config(self):
clip_after_allreduce = bool(
distutils.util.strtobool(
os.getenv('CLIP_AFTER_ALLREDUCE', 'True')))
max_global_norm = float(os.getenv('MAX_GLOBAL_NORM', -1.0))
print('clip_after_allreduce = {}, max_global_norm = {}'.format(
clip_after_allreduce, max_global_norm))
return {
'clip_after_allreduce': clip_after_allreduce,
'grad_clip': paddle.nn.ClipGradByGlobalNorm(max_global_norm)
if max_global_norm > 0 else None,
}
def run_main(self, use_fp16, use_master_param_norm=True):
if not paddle.is_compiled_with_cuda():
return
if not use_fp16:
self.assertTrue(use_master_param_norm)
base_config = self.config()
config1 = dict(base_config)
config1['use_distributed_lamb'] = True
config1['use_fp16'] = use_fp16
config1['use_master_param_norm'] = use_master_param_norm
config2 = dict(base_config)
config2['use_distributed_lamb'] = False
config2['use_fp16'] = use_fp16
config2['use_master_param_norm'] = use_master_param_norm
result1 = run_model(**config1)
result2 = run_model(**config2)
self.assertEqual(len(result1), len(result2))
if use_fp16:
atol = 8e-4 if use_master_param_norm else 1e-3
else:
atol = 1e-7
for ret1, ret2 in zip(result1, result2):
max_diff = np.max(np.abs(ret1 - ret2))
msg = 'max_diff = {} atol = {} when use_fp16 = {} , use_master_param_norm = {}'.format(
max_diff, atol, use_fp16, use_master_param_norm)
self.assertTrue(max_diff < atol, msg)
print(msg)
def test_main(self):
self.run_main(use_fp16=False)
self.run_main(use_fp16=True, use_master_param_norm=True)
self.run_main(use_fp16=True, use_master_param_norm=False)
touch_file_name = os.environ.get('SUCCESS_TOUCH_FILE')
if touch_file_name:
with open(touch_file_name, 'w') as f:
f.write('success')
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shlex
import sys
import shutil
import unittest
import paddle
def get_test_file():
dirname = os.path.dirname(os.path.abspath(__file__))
return os.path.join(dirname, 'distributed_fused_lamb_test_base.py')
def remove_file_if_exists(file_name):
if not os.path.exists(file_name):
return
if os.path.isfile(file_name):
os.remove(file_name)
else:
shutil.rmtree(file_name)
def run_test(clip_after_allreduce=True, max_global_norm=-1.0):
if not paddle.is_compiled_with_cuda():
return
if os.name == 'nt':
return
args = locals()
log_dir = 'log_{}'.format(os.getpid())
cmd = [
sys.executable,
'-u',
'-m',
'paddle.distributed.launch',
'--log_dir',
log_dir,
get_test_file(),
]
cmd = ' '.join([shlex.quote(c) for c in cmd])
os.environ['CLIP_AFTER_ALLREDUCE'] = str(clip_after_allreduce)
os.environ['MAX_GLOBAL_NORM'] = str(max_global_norm)
touch_file_env = 'SUCCESS_TOUCH_FILE'
touch_file_name = 'distributed_fused_lamb_touch_file_{}'.format(os.getpid())
os.environ[touch_file_env] = touch_file_name
remove_file_if_exists(touch_file_name)
try:
assert os.system(cmd) == 0 and os.path.exists(
touch_file_name), 'Test failed when {}'.format(args)
finally:
remove_file_if_exists(touch_file_name)
remove_file_if_exists(log_dir)
class TestDistributedFusedLambWithClip(unittest.TestCase):
def test_1(self):
run_test(clip_after_allreduce=True, max_global_norm=0.01)
def _test_2(self):
run_test(clip_after_allreduce=False, max_global_norm=0.01)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test_distributed_fused_lamb_op_with_clip import run_test
import unittest
class TestDistributedFusedLambWithoutClip(unittest.TestCase):
def test_1(self):
run_test(clip_after_allreduce=True, max_global_norm=-1.0)
def test_2(self):
run_test(clip_after_allreduce=False, max_global_norm=-1.0)
if __name__ == "__main__":
unittest.main()
......@@ -14,6 +14,7 @@
from .optimizer import LookAhead # noqa: F401
from .optimizer import ModelAverage # noqa: F401
from .optimizer import DistributedFusedLamb # noqa: F401
from .checkpoint import auto_checkpoint # noqa: F401
from ..fluid.layer_helper import LayerHelper # noqa: F401
from .operators import softmax_mask_fuse_upper_triangle # noqa: F401
......
......@@ -14,5 +14,6 @@
from .lookahead import LookAhead # noqa: F401
from .modelaverage import ModelAverage # noqa: F401
from .distributed_fused_lamb import DistributedFusedLamb # noqa: F401
__all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import framework, core, layers, unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
from paddle.optimizer import Optimizer
from paddle.distributed import get_rank, get_world_size
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import name_scope
import numpy as np
class DistributedFusedLamb(Optimizer):
def __init__(self,
learning_rate=0.001,
lamb_weight_decay=0.01,
beta1=0.9,
beta2=0.999,
epsilon=1e-6,
parameters=None,
grad_clip=None,
exclude_from_weight_decay_fn=None,
clip_after_allreduce=True,
is_grad_scaled_by_nranks=True,
alignment=128,
use_master_param_norm=True,
name=None):
assert not framework.in_dygraph_mode(
), "DistributedFusedLamb does not support dygraph mode"
super(DistributedFusedLamb, self).__init__(
learning_rate=learning_rate,
parameters=parameters,
weight_decay=None,
grad_clip=None,
name=name)
self._beta1 = beta1
self._beta2 = beta2
self._epsilon = epsilon
self._weight_decay = lamb_weight_decay if lamb_weight_decay is not None else 0.0
if grad_clip is not None:
assert isinstance(
grad_clip, ClipGradByGlobalNorm
), "Only ClipGradByGlobalNorm is supported in DistributedFusedLamb"
max_global_grad_norm = grad_clip.clip_norm
else:
max_global_grad_norm = -1.0
self._max_global_grad_norm = max_global_grad_norm
self._alignment = alignment if alignment is not None else -1
self._clip_after_allreduce = clip_after_allreduce
self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
self._scale = None
self._ring_id = 0
self._use_master_param_norm = use_master_param_norm
self.helper = LayerHelper('distributed_fused_lamb')
self._supports_check_nan_inf = True  # very important flag for AMP
main_block = self.helper.main_program.global_block()
self._found_inf = main_block.create_var(
name=unique_name.generate('found_inf'),
shape=[1],
dtype=core.VarDesc.VarType.BOOL)
self._param_to_master_param = {}
def _set_scale(self, scale):
assert scale is not None
if not isinstance(scale, Variable):
scale = self._create_scale_from_constant(scale)
self._scale = scale
def _create_scale_from_constant(self, value):
name = unique_name.generate('global_scale')
return layers.create_global_var(
name=name,
shape=[1],
dtype='float32',
value=float(value),
persistable=True)
def _get_or_create_scale(self):
if self._scale is None:
self._scale = self._create_scale_from_constant(1.0)
return self._scale
def _create_persistable_var(self, name=None, shape=[-1], dtype='float32'):
startup_block = self.helper.startup_program.global_block()
if name is not None:
name = unique_name.generate(name)
startup_var = startup_block.create_var(
name=name,
shape=shape,
dtype=dtype,
persistable=True,
stop_gradient=True)
main_block = self.helper.main_program.global_block()
main_var = main_block.create_var(
name=startup_var.name,
shape=startup_var.shape,
dtype=startup_var.dtype,
persistable=True,
stop_gradient=True)
return main_var
def _get_parameter(self, name, scope=None):
if scope is None:
scope = global_scope()
master_param = self._param_to_master_param.get(name)
assert master_param is not None
master_param_t = scope.find_var(master_param).get_tensor()
assert master_param_t._dtype() == core.VarDesc.VarType.FP32
param_t = scope.find_var(name).get_tensor()
if param_t._dtype() == core.VarDesc.VarType.FP32:
assert param_t._ptr() == master_param_t._ptr()
return param_t, None
else:
assert param_t._dtype() == core.VarDesc.VarType.FP16
assert param_t.shape() == master_param_t.shape()
return param_t, master_param_t
def apply_optimize(self, params_grads):
self.apply_gradients(params_grads)
def apply_gradients(self, params_grads):
flattened = []
for p, g in params_grads:
flattened.extend([p, g])
with flattened[0].block.program._optimized_guard(flattened), name_scope(
"optimizer"):
self._apply_gradients_impl(params_grads)
def _apply_gradients_impl(self, params_grads):
for p, g in params_grads:
assert g.type == core.VarDesc.VarType.LOD_TENSOR, "Only support dense gradient"
g.persistable = True # the gradient must be persistable for fusion
fp32_fused_param = self._create_persistable_var('fp32_fused_param')
fp32_fused_grad = self._create_persistable_var('fp32_fused_grad')
fp16_fused_param = self._create_persistable_var(
'fp16_fused_param', dtype='float16')
fp16_fused_grad = self._create_persistable_var(
'fp16_fused_grad', dtype='float16')
master_params = []
for p, g in params_grads:
master_p = self._create_persistable_var('master_weight')
self._param_to_master_param[p.name] = master_p.name
master_params.append(master_p)
moment1 = self._create_persistable_var('moment1')
moment1.is_distributed = True
moment2 = self._create_persistable_var('moment2')
moment2.is_distributed = True
beta1pow = self._create_persistable_var('beta1pow')
beta2pow = self._create_persistable_var('beta2pow')
fused_indices = self._create_persistable_var(
'fused_indices', dtype='int32')
weight_decay = self._create_persistable_var('weight_decay')
weight_decay.is_distributed = True
param_info = self._create_persistable_var('param_info', dtype='int32')
param_info.is_distributed = True
fused_offsets = self._create_persistable_var('fused_offsets')
fp32_partial_fused_offsets = self._create_persistable_var(
'fp32_partial_fused_offsets', dtype='int32')
fp32_partial_fused_offsets.is_distributed = True
fp16_partial_fused_offsets = self._create_persistable_var(
'fp16_partial_fused_offsets', dtype='int32')
fp16_partial_fused_offsets.is_distributed = True
rank = get_rank()
nranks = get_world_size()
scale = self._get_or_create_scale()
params = [p for p, _ in params_grads]
grads = [g for _, g in params_grads]
weight_decay_values = [self._weight_decay] * len(params)
if self._exclude_from_weight_decay_fn is not None:
for i, p in enumerate(params):
if self._exclude_from_weight_decay_fn(p):
weight_decay_values[i] = 0.0
startup_block = self.helper.startup_program.global_block()
for g in grads:
startup_block.create_var(
name=g.name,
type=g.type,
dtype=g.dtype,
persistable=g.persistable,
shape=g.shape)
startup_block.append_op(
type='distributed_fused_lamb_init',
inputs={
'Param': params,
'Grad': grads,
},
outputs={
'FP32FusedParam': [fp32_fused_param],
'FP32FusedGrad': [fp32_fused_grad],
'FP16FusedParam': [fp16_fused_param],
'FP16FusedGrad': [fp16_fused_grad],
'Moment1': [moment1],
'Moment2': [moment2],
'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow],
'FusedIndices': [fused_indices],
'WeightDecay': [weight_decay],
'GlobalScale': [scale],
'ParamInfo': [param_info],
'ParamOut': params,
'MasterParamOut': master_params,
'GradOut': grads,
'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
'FusedParamOffsets': [fused_offsets],
},
attrs={
'alignment': self._alignment,
'rank': rank,
'nranks': nranks,
'weight_decay': weight_decay_values,
'moment1': 0.0,
'moment2': 0.0,
'beta1': self._beta1,
'beta2': self._beta2,
})
main_block = self.helper.main_program.global_block()
self._create_global_learning_rate()
lr = None
for p_g in params_grads:
if lr is None:
lr = self._create_param_lr(p_g)
else:
new_lr = self._create_param_lr(p_g)
assert id(lr) == id(
new_lr
), "The learning rate for each parameter should be the same"
assert lr is not None
lamb_op = main_block.append_op(
type='distributed_fused_lamb',
inputs={
'FP32FusedParam': [fp32_fused_param],
'FP32FusedGrad': [fp32_fused_grad],
'FP16FusedParam': [fp16_fused_param],
'FP16FusedGrad': [fp16_fused_grad],
'LearningRate': [lr],
'Moment1': [moment1],
'Moment2': [moment2],
'Beta1Pow': [beta1pow],
'Beta2Pow': [beta2pow],
'FusedIndices': [fused_indices],
'WeightDecay': [weight_decay],
'GlobalScale': [scale],
'ParamInfo': [param_info],
'Param': params,
'Grad': grads,
'FusedParamOffsets': [fused_offsets],
'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
},
outputs={
'FP32FusedParamOut': [fp32_fused_param],
'FP16FusedParamOut': [fp16_fused_param],
'Moment1Out': [moment1],
'Moment2Out': [moment2],
'Beta1PowOut': [beta1pow],
'Beta2PowOut': [beta2pow],
'ParamOut': params,
'GradOut': grads,
'FoundInf': [self._found_inf],
},
attrs={
'beta1': self._beta1,
'beta2': self._beta2,
'epsilon': self._epsilon,
'max_global_grad_norm': self._max_global_grad_norm,
'clip_after_allreduce': self._clip_after_allreduce,
'rank': rank,
'ring_id': self._ring_id,
'use_master_param_norm': self._use_master_param_norm,
'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
})
return [lamb_op]
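For orientation, here is a minimal static-graph usage sketch of the new optimizer, adapted loosely from the test added in this commit (distributed_fused_lamb_test_base.py). The data feed and the multi-rank NCCL setup (handled by CollectiveHelper in the test) are simplified or elided, and the example assumes a CUDA build launched through paddle.distributed.launch; treat it as an illustration under those assumptions rather than authoritative API documentation.

import paddle
import paddle.distributed.fleet as fleet
from paddle.incubate import DistributedFusedLamb

paddle.enable_static()
fleet.init(role_maker=fleet.PaddleCloudRoleMaker(is_collective=True))

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    image = paddle.static.data('image', [None, 3, 224, 224], 'float32')
    label = paddle.static.data('label', [None, 1], 'int64')
    pred = paddle.vision.models.resnet18()(image)
    loss = paddle.nn.loss.CrossEntropyLoss()(pred, label)

    optimizer = DistributedFusedLamb(
        learning_rate=1e-3,
        lamb_weight_decay=0.01,
        clip_after_allreduce=True,
        is_grad_scaled_by_nranks=False,
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0))
    # The test drives the two phases explicitly instead of calling minimize().
    params_grads = optimizer.backward(loss, startup)
    optimizer.apply_gradients(params_grads)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup)
# exe.run(main, feed={'image': ..., 'label': ...}, fetch_list=[loss])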