未验证 提交 4779c2c1 编写于 作者: N niuliling123 提交者: GitHub

Add multi_precision for adagrad op (#50078)

上级 c647cac5
......@@ -43,14 +43,23 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
AddInput("LearningRate", "(Tensor) Learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("MomentOut", "(Tensor) Output second moment");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("epsilon",
"(float, default 1.0e-6) "
"Constant for numerical stability")
.SetDefault(1.0e-6f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddComment(R"DOC(
Adaptive Gradient Algorithm (Adagrad).
......
......@@ -205,6 +205,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
{"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
{"adagrad", {"Param", "Grad", "Moment", "LearningRate", "MasterParam"}},
{"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
{"nce",
{"Input",
......@@ -361,6 +362,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"Beta2PowOut",
"MasterParamOut"}},
{"sgd", {"ParamOut", "MasterParamOut"}},
{"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
{"lamb",
{"ParamOut",
"Moment1Out",
......@@ -399,7 +401,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
"MasterParamOut"}},
{"ftrl", {"ParamOut", "SquaredAccumOut", "LinearAccumOut"}},
{"adadelta", {"ParamOut", "AvgSquaredGradOut", "AvgSquaredUpdateOut"}},
{"adagrad", {"ParamOut", "MomentOut"}},
{"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
{"adamax", {"ParamOut", "MomentOut", "InfNormOut"}},
{"dpsgd", {"ParamOut"}},
{"decayed_adagrad", {"ParamOut", "MomentOut"}},
......
......@@ -29,15 +29,16 @@
inplace : (param -> param_out), (avg_squared_grad -> moment_out), (avg_squared_update -> inf_norm_out)
- op : adagrad_
args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, float epsilon)
output : Tensor(param_out), Tensor(moment_out)
args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, Tensor master_param, float epsilon, bool multi_precision)
output : Tensor(param_out), Tensor(moment_out), Tensor(master_param_out)
infer_meta :
func : AdagradInferMeta
kernel :
func : adagrad {dense, dense, dense, dense -> dense, dense}
adagrad_dense_param_sparse_grad {dense, selected_rows, dense, dense -> dense, dense}
func : adagrad {dense, dense, dense, dense, dense -> dense, dense, dense}
adagrad_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense-> dense, dense, dense}
data_type : param
inplace : (param -> param_out), (moment -> moment_out)
optional : master_param
inplace : (param -> param_out), (moment -> moment_out), (master_param -> master_param_out)
- op : adam_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
......
......@@ -74,9 +74,12 @@ void AdagradInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment_out) {
MetaTensor* moment_out,
MetaTensor* master_param_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
......
......@@ -53,9 +53,12 @@ void AdagradInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment_out);
MetaTensor* moment_out,
MetaTensor* master_param_out);
void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
......
......@@ -25,9 +25,12 @@ void AdagradDenseKernel(const Context& dev_ctx,
const DenseTensor& grad,
const DenseTensor& moment,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment_out);
DenseTensor* moment_out,
DenseTensor* master_param_outs);
template <typename T, typename Context>
void AdagradSparseKernel(const Context& dev_ctx,
......@@ -35,8 +38,11 @@ void AdagradSparseKernel(const Context& dev_ctx,
const SelectedRows& grad,
const DenseTensor& moment,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment_out);
DenseTensor* moment_out,
DenseTensor* master_param_outs);
} // namespace phi
......@@ -28,6 +28,42 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
}
} // namespace
template <typename T>
struct DenseAdagradFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& ctx,
const DenseTensor& param_t,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
ctx.template Alloc<T>(param_out_tensor);
ctx.template Alloc<T>(moment_out_tensor);
T epsilon = static_cast<T>(epsilon_t);
auto param = EigenVector<T>::Flatten(param_t);
auto grad = EigenVector<T>::Flatten(grad_t);
auto moment = EigenVector<T>::Flatten(moment_t);
auto param_out = EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = EigenVector<T>::Flatten(*moment_out_tensor);
auto place = *ctx.eigen_device();
moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
auto* lr = learning_rate.data<T>();
param_out.device(place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
}
};
template <typename T>
struct SparseAdagradFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& context,
......@@ -67,6 +103,8 @@ struct SparseAdagradFunctor<phi::CPUContext, T> {
template struct SparseAdagradFunctor<phi::CPUContext, float>;
template struct SparseAdagradFunctor<phi::CPUContext, double>;
template struct DenseAdagradFunctor<phi::CPUContext, float>;
template struct DenseAdagradFunctor<phi::CPUContext, double>;
} // namespace phi
......
......@@ -13,9 +13,11 @@
// limitations under the License.
#include "paddle/phi/kernels/adagrad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
......@@ -23,6 +25,79 @@
namespace phi {
template <typename T, typename MT>
__global__ void AdagradGPUKernel(const T* param,
const T* grad,
const MT* moment,
const MT* lr,
const MT* master_param,
MT epsilon,
T* param_out,
MT* moment_out,
MT* master_param_out,
int num) {
auto idx = blockDim.x * blockIdx.x + threadIdx.x;
MT lr_data = static_cast<T>(lr[0]);
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
MT grad_data = static_cast<MT>(grad[i]);
MT moment_out_data = static_cast<MT>(moment[i]) + grad_data * grad_data;
moment_out[i] = static_cast<MT>(moment_out_data);
auto in = master_param_out ? master_param[i] : static_cast<MT>(param[i]);
MT param_out_data =
in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);
param_out[i] = static_cast<MT>(param_out_data);
if (master_param_out) {
master_param_out[i] = param_out_data;
}
}
}
template <typename T>
struct DenseAdagradFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& ctx,
const DenseTensor& param_t,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type;
T* param_out_data = ctx.template Alloc<T>(param_out_tensor);
MPDType* moment_out_data = ctx.template Alloc<MPDType>(moment_out_tensor);
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision ? ctx.template Alloc<MPDType>(master_param_outs)
: nullptr;
MPDType epsilon = static_cast<MPDType>(epsilon_t);
int numel = param_t.numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = ctx.stream();
AdagradGPUKernel<T, MPDType>
<<<block, grid, 0, stream>>>(param_t.data<T>(),
grad_t.data<T>(),
moment_t.data<MPDType>(),
learning_rate.data<MPDType>(),
master_in_data,
epsilon,
param_out_data,
moment_out_data,
master_out_data,
numel);
}
};
template <typename T, int block_size>
__global__ void MergeGradKernel(const T* grad,
const int64_t* grad_rows,
......@@ -123,11 +198,19 @@ struct SparseAdagradFunctor<phi::GPUContext, T> {
template struct SparseAdagradFunctor<phi::GPUContext, float>;
template struct SparseAdagradFunctor<phi::GPUContext, double>;
template struct DenseAdagradFunctor<phi::GPUContext, float>;
template struct DenseAdagradFunctor<phi::GPUContext, double>;
template struct DenseAdagradFunctor<phi::GPUContext, phi::dtype::float16>;
} // namespace phi
PD_REGISTER_KERNEL(
adagrad, GPU, ALL_LAYOUT, phi::AdagradDenseKernel, float, double) {}
PD_REGISTER_KERNEL(adagrad,
GPU,
ALL_LAYOUT,
phi::AdagradDenseKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(adagrad_dense_param_sparse_grad,
GPU,
......
......@@ -30,6 +30,21 @@ struct SparseAdagradFunctor {
DenseTensor* param);
};
template <typename DeviceContext, typename T>
struct DenseAdagradFunctor {
void operator()(const DeviceContext& ctx,
const DenseTensor& param_t,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs);
};
template <typename DeviceContext, typename T>
phi::SelectedRows SquareSelectedRows(const DeviceContext& context,
const phi::SelectedRows& input) {
......@@ -50,35 +65,24 @@ void AdagradDenseKernel(const Context& ctx,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor) {
ctx.template Alloc<T>(param_out_tensor);
ctx.template Alloc<T>(moment_out_tensor);
T epsilon = static_cast<T>(epsilon_t);
auto param = EigenVector<T>::Flatten(param_t);
auto grad = EigenVector<T>::Flatten(grad_t);
auto moment = EigenVector<T>::Flatten(moment_t);
auto param_out = EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = EigenVector<T>::Flatten(*moment_out_tensor);
auto place = *ctx.eigen_device();
moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
auto* lr = learning_rate.data<T>();
param_out.device(place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
auto lr = EigenVector<T>::Flatten(learning_rate);
param_out.device(place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
DenseAdagradFunctor<Context, T> functor;
functor(ctx,
param_t,
grad_t,
moment_t,
learning_rate,
master_param,
epsilon_t,
multi_precision,
param_out_tensor,
moment_out_tensor,
master_param_outs);
}
template <typename T, typename Context>
......@@ -87,9 +91,12 @@ void AdagradSparseKernel(const Context& ctx,
const SelectedRows& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment_out) {
DenseTensor* moment_out,
DenseTensor* master_param_outs) {
auto* param_out_tensor = param_out;
auto* moment_out_tensor = moment_out;
......
......@@ -24,9 +24,12 @@ void AdagradDenseKernel(const Context& ctx,
const DenseTensor& grad,
const DenseTensor& moment,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor) {
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
ctx.template Alloc<T>(param_out_tensor);
ctx.template Alloc<T>(moment_out_tensor);
......
......@@ -18,15 +18,17 @@ namespace phi {
KernelSignature AdagradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("Grad")) {
return KernelSignature("adagrad",
{"Param", "Grad", "Moment", "LearningRate"},
{"epsilon"},
{"ParamOut", "MomentOut"});
return KernelSignature(
"adagrad",
{"Param", "Grad", "Moment", "LearningRate", "MasterParam"},
{"epsilon", "multi_precision"},
{"ParamOut", "MomentOut", "MasterParamOut"});
} else if (ctx.IsSelectedRowsInput("Grad")) {
return KernelSignature("adagrad_dense_param_sparse_grad",
{"Param", "Grad", "Moment", "LearningRate"},
{"epsilon"},
{"ParamOut", "MomentOut"});
return KernelSignature(
"adagrad_dense_param_sparse_grad",
{"Param", "Grad", "Moment", "LearningRate", "MasterParam"},
{"epsilon", "multi_precision"},
{"ParamOut", "MomentOut", "MasterParamOut"});
}
return KernelSignature("unregistered", {}, {}, {});
......
......@@ -2079,13 +2079,83 @@ class AdagradOptimizer(Optimizer):
name=name,
)
self.type = "adagrad"
self._multi_precision = False
self._epsilon = epsilon
self.initial_accumulator_value = initial_accumulator_value
self._master_weights = {}
def _create_master_weight(self, param):
if param.name in self._master_weights:
var = self._master_weights[param.name]
else:
assert isinstance(self.helper, LayerHelper)
var_name = param.name + '_fp32_master'
var_name = unique_name.generate(var_name)
var = paddle.static.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True,
)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
},
)
self._master_weights[param.name] = var
return var
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = (
self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
)
target_param = (
self._master_weights[param.name] if find_master else param
)
target_name = target_param.name
if (
name not in self._accumulators
or target_name not in self._accumulators[name]
):
raise Exception(
"Accumulator {} does not exist for parameter {}".format(
name, target_name
)
)
return self._accumulators[name][target_name]
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_accumulator(self._moment_acc_str, master_p)
continue
if (
p.dtype == core.VarDesc.VarType.FP16
and not self._multi_precision
):
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Lars optimizer."
)
self._add_accumulator(
self._moment_acc_str,
p,
......@@ -2098,30 +2168,52 @@ class AdagradOptimizer(Optimizer):
moment_acc = self._get_accumulator(
self._moment_acc_str, param_and_grad[0]
)
find_master = (
self._multi_precision
and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
)
master_weight = (
self._master_weights[param_and_grad[0].name]
if find_master
else None
)
if in_dygraph_mode():
_C_ops.adagrad_(
param_and_grad[0],
param_and_grad[1],
moment_acc,
self._create_param_lr(param_and_grad),
master_weight,
self._epsilon,
find_master,
)
return None
else:
# Create the adagrad optimizer op
adagrad_op = block.append_op(
type=self.type,
inputs={
inputs = {
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Moment": moment_acc,
"LearningRate": self._create_param_lr(param_and_grad),
},
outputs={
}
outputs = {
"ParamOut": param_and_grad[0],
"MomentOut": moment_acc,
},
attrs={"epsilon": self._epsilon},
}
attrs = {"epsilon": self._epsilon, "multi_precision": find_master}
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
adagrad_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True,
)
......
......@@ -23,8 +23,24 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
def adamgrad_wrapper(param, grad, moment, learning_rate, epsilon):
paddle._C_ops.adagrad_(param, grad, moment, learning_rate, epsilon)
def adamgrad_wrapper(
param,
grad,
moment,
learning_rate,
master_weight=None,
epsilon=1e-8,
multi_precision=False,
):
paddle._C_ops.adagrad_(
param,
grad,
moment,
learning_rate,
master_weight,
epsilon,
multi_precision,
)
class TestAdagradOp1(OpTest):
......@@ -79,7 +95,7 @@ class TestAdagradOp2(OpTest):
'LearningRate': np.array([lr]).astype("float32"),
}
self.attrs = {'epsilon': epsilon}
self.attrs = {'epsilon': epsilon, "multi_precision": False}
moment_out = moment + grad * grad
param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
......@@ -124,7 +140,6 @@ class TestSparseAdagradOp(unittest.TestCase):
moment_np_array = np.full((height, row_numel), 2.0).astype("float32")
moment.set(moment_np_array, place)
# create and run sgd operator
adagrad_op = Operator(
"adagrad",
Param='Param',
......@@ -196,6 +211,271 @@ class TestSparseAdagradOp(unittest.TestCase):
self.check_with_place(place)
class TestAdagradOpMultiPrecison(unittest.TestCase):
def _test_adagrad_op_dygraph_place_amp(self, place, use_amp=False):
import paddle
paddle.disable_static()
paddle.seed(10)
paddle.set_device(place)
input = paddle.randn((5, 5))
model = paddle.nn.Linear(5, 5)
optimizer = paddle.optimizer.Adagrad(0.1, parameters=model.parameters())
optimizer._multi_precision = use_amp
for idx in range(2):
if place == 'gpu' and use_amp:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
if place == 'gpu' and use_amp:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.step(optimizer)
optimizer.clear_grad()
else:
output = model(input)
loss = paddle.mean(output)
loss.backward()
optimizer.step()
optimizer.clear_grad()
paddle.enable_static()
def _get_places(self):
import paddle
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
return places
def test_main(self):
for place in self._get_places():
use_amp_list = [True, False]
for use_amp in use_amp_list:
self._test_adagrad_op_dygraph_place_amp(place, use_amp)
class TestAdagradMultiPrecision2_0(unittest.TestCase):
def dygraph_adagrad_mp(self, mp, use_amp):
paddle.disable_static()
paddle.seed(100)
paddle.set_device('gpu')
input = paddle.randn((2, 2))
model = paddle.nn.Linear(2, 2)
optimizer = paddle.optimizer.Adagrad(0.5, parameters=model.parameters())
optimizer._multi_precision = mp
if use_amp:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
for idx in range(5):
if use_amp:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_grad()
else:
output = model(input)
loss = paddle.mean(output)
loss.backward()
optimizer.step()
optimizer.clear_grad()
return output, model.parameters()
def static_adagrad_mp(self, mp, use_amp):
paddle.enable_static()
paddle.seed(100)
np.random.seed(100)
exe = paddle.static.Executor('gpu')
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.optimizer.Adagrad(0.1)
optimizer._multi_precision = mp
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False,
)
with paddle.static.program_guard(train_program, startup_program):
if use_amp:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float16'
)
else:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float32'
)
hidden = paddle.static.nn.fc(x=data, size=10)
loss = paddle.mean(hidden)
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
out = []
for idx in range(5):
(loss_data,) = exe.run(
train_program, feed={"X": x}, fetch_list=[loss.name]
)
out.append(loss_data)
return out
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
"Test dygraph mode"
output1_dy, params1_dy = self.dygraph_adagrad_mp(use_amp=True, mp=True)
output2_dy, params2_dy = self.dygraph_adagrad_mp(
use_amp=False, mp=False
)
np.testing.assert_allclose(
output1_dy.astype('float32').numpy(),
output2_dy.astype('float32').numpy(),
rtol=1e-05,
atol=0.1,
)
for idx in range(len(params1_dy)):
np.testing.assert_allclose(
params1_dy[idx].astype('float32').numpy(),
params2_dy[idx].astype('float32').numpy(),
rtol=1e-05,
atol=0.1,
)
"Test static mode"
output1_st = self.static_adagrad_mp(use_amp=True, mp=True)
output2_st = self.static_adagrad_mp(use_amp=False, mp=False)
for idx in range(len(output1_st)):
np.testing.assert_allclose(
output1_st[idx].astype('float32'),
output2_st[idx].astype('float32'),
rtol=1e-05,
atol=0.1,
)
class TestAdagradMultiPrecision1_0(unittest.TestCase):
def dygraph_adagrad_mp(self, use_amp, mp):
paddle.disable_static()
paddle.seed(10)
paddle.set_device('gpu')
input = paddle.randn((2, 2))
model = paddle.nn.Linear(2, 2)
optimizer = paddle.fluid.optimizer.Adagrad(
learning_rate=0.001, parameter_list=model.parameters()
)
optimizer._multi_precision = mp
if use_amp:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
for idx in range(5):
if use_amp:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(optimizer, scaled)
optimizer.clear_gradients()
else:
output = model(input)
loss = paddle.mean(output)
optimizer.minimize(loss)
optimizer.clear_gradients()
return output, model.parameters()
def static_adagrad_mp(self, use_amp, mp):
paddle.enable_static()
paddle.seed(100)
np.random.seed(100)
exe = paddle.static.Executor('gpu')
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.fluid.optimizer.Adagrad(learning_rate=0.001)
optimizer._multi_precision = mp
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False,
)
with paddle.static.program_guard(train_program, startup_program):
if use_amp:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float16'
)
else:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float32'
)
hidden = paddle.static.nn.fc(x=data, size=10)
loss = paddle.mean(hidden)
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
out = []
for idx in range(5):
(loss_data,) = exe.run(
train_program, feed={"X": x}, fetch_list=[loss.name]
)
out.append(loss_data)
return out
def test_main(self):
if not paddle.is_compiled_with_cuda():
return
"Test dygraph mode"
output1_dy, params1_dy = self.dygraph_adagrad_mp(use_amp=True, mp=True)
output2_dy, params2_dy = self.dygraph_adagrad_mp(
use_amp=False, mp=False
)
np.testing.assert_allclose(
output1_dy.astype('float32').numpy(),
output2_dy.astype('float32').numpy(),
rtol=1e-05,
atol=0.1,
)
for idx in range(len(params1_dy)):
np.testing.assert_allclose(
params1_dy[idx].astype('float32').numpy(),
params2_dy[idx].astype('float32').numpy(),
rtol=1e-05,
atol=0.1,
)
"Test static mode"
output1_st = self.static_adagrad_mp(use_amp=True, mp=True)
output2_st = self.static_adagrad_mp(use_amp=False, mp=False)
for idx in range(len(output1_st)):
np.testing.assert_allclose(
output1_st[idx].astype('float32'),
output2_st[idx].astype('float32'),
rtol=1e-05,
atol=0.1,
)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from ..fluid import framework
import paddle
from ..fluid import core, framework, unique_name
from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer
__all__ = []
......@@ -126,12 +130,72 @@ class Adagrad(Optimizer):
)
self.type = "adagrad"
self._epsilon = epsilon
self._multi_precision = False
self._master_weights = {}
self.initial_accumulator_value = initial_accumulator_value
self._default_dict = {
'epsilon': epsilon,
'initial_accumulator_value': initial_accumulator_value,
}
def _create_master_weight(self, param):
if param.name in self._master_weights:
var = self._master_weights[param.name]
else:
assert isinstance(self.helper, LayerHelper)
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = paddle.static.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True,
)
block = self.helper.startup_program.global_block()
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
},
)
self._master_weights[param.name] = var
return var
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = (
self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
)
target_param = (
self._master_weights[param.name] if find_master else param
)
target_name = target_param.name
if (
name not in self._accumulators
or target_name not in self._accumulators[name]
):
raise Exception(
"Accumulator {} does not exist for parameter {}".format(
name, target_name
)
)
return self._accumulators[name][target_name]
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
......@@ -139,6 +203,18 @@ class Adagrad(Optimizer):
parameters = self._update_param_group(parameters)
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_accumulator(self._moment_acc_str, master_p)
continue
if (
p.dtype == core.VarDesc.VarType.FP16
and not self._multi_precision
):
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Momentum optimizer."
)
self._add_accumulator(
self._moment_acc_str,
p,
......@@ -154,17 +230,37 @@ class Adagrad(Optimizer):
moment_acc = self._get_accumulator(
self._moment_acc_str, param_and_grad[0]
)
find_master = (
self._multi_precision
and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
)
master_weight = (
self._master_weights[param_and_grad[0].name]
if find_master
else None
)
# Create the adagrad optimizer op
adagrad_op = block.append_op(
type=self.type,
inputs={
inputs = {
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Moment": moment_acc,
"LearningRate": self._create_param_lr(param_and_grad),
},
outputs={"ParamOut": param_and_grad[0], "MomentOut": moment_acc},
attrs={"epsilon": self._epsilon},
}
outputs = {"ParamOut": param_and_grad[0], "MomentOut": moment_acc}
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
adagrad_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs={"epsilon": self._epsilon, "multi_precision": find_master},
stop_gradient=True,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册