Unverified commit fb3313e9 authored by zhangbo9674, committed by GitHub

Add multi tensor for adam (#38010)

* add multi tensor for adam

* add merged_adam op

* refine code

* refine adam compute logic
Parent 0883cf37
......@@ -29,20 +29,18 @@ __global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = moment1[id];
MT mom2 = moment2[id];
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
......@@ -65,9 +63,6 @@ __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
......@@ -77,8 +72,9 @@ __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
......@@ -105,8 +101,6 @@ __global__ void SparseAdamCUDAKernelREG(
int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
MT lr = *lr_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
for (; id < ndim; id += blockDim.x * gridDim.x) {
auto row_idx =
......@@ -122,8 +116,10 @@ __global__ void SparseAdamCUDAKernelREG(
: static_cast<MT>(0);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
MT denom =
(sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
// Write back to global memory
mom1_out_[id] = mom1;
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MergedAdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto param_dtype =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(param_dtype, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
var_name == "SkipUpdate") {
return expected_kernel_type;
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
};
class MergedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "(Tensor, default Tensor<float>) Input parameter")
.AsDuplicable();
AddInput("Grad", "(Tensor, default Tensor<float>) Input gradient")
.AsDuplicable();
AddInput("LearningRate", "(Tensor, default Tensor<float>) Learning rate")
.AsDuplicable();
AddInput("Moment1", "(Tensor, default Tensor<float>) Input first moment")
.AsDuplicable();
AddInput("Moment2", "(Tensor, default Tensor<float>) Input second moment")
.AsDuplicable();
AddInput("Beta1Pow",
"(Tensor, default Tensor<float>) Input beta1 power accumulator")
.AsDuplicable();
AddInput("Beta2Pow",
"(Tensor, default Tensor<float>) Input beta2 power accumulator")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDispensable()
.AsDuplicable();
AddOutput("ParamOut", "(Tensor) Output parameter").AsDuplicable();
AddOutput("Moment1Out", "(Tensor) Output first moment").AsDuplicable();
AddOutput("Moment2Out", "(Tensor) Output second moment").AsDuplicable();
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator")
.AsDuplicable();
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable()
.AsDuplicable();
AddAttr<float>("beta1",
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates.")
.SetDefault(0.9f);
AddAttr<float>("beta2",
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates.")
.SetDefault(0.999f);
AddAttr<float>("epsilon",
"(float, default 1.0e-8) "
"Constant for numerical stability")
.SetDefault(1.0e-8f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
// TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut
// as dispensable since they are not used when use_global_beta_pow is true.
AddAttr<bool>("use_global_beta_pow",
"(bool, default false) "
"Whether to use global beta_pow for whole model instead of "
"creating beta_pow for each parameter.")
.SetDefault(false);
AddComment(R"DOC(
Adam Optimizer.
This implements the Adam optimizer from Section 2 of the Adam
paper: https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2\_out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
\frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(merged_adam, ops::MergedAdamOp,
ops::MergedAdamOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(merged_adamw, ops::MergedAdamOp,
ops::MergedAdamOpMaker);
REGISTER_OP_CPU_KERNEL(
merged_adam,
ops::MergedAdamOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::MergedAdamOpKernel<paddle::platform::CPUDeviceContext, double>);
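As a reading aid (not part of the commit), the following is a minimal NumPy sketch of the per-element update that the merged_adam kernels perform; the function name adam_step and all variable names are illustrative only.

import numpy as np

def adam_step(p, g, mom1, mom2, lr, beta1, beta2, epsilon, beta1_pow, beta2_pow):
    # First- and second-moment estimates.
    mom1 = beta1 * mom1 + (1.0 - beta1) * g
    mom2 = beta2 * mom2 + (1.0 - beta2) * g * g
    # Bias-corrected denominator, matching the kernel's "denom" formulation.
    denom = np.sqrt(mom2) / np.sqrt(1.0 - beta2_pow) + epsilon
    # Bias-corrected step: p -= (lr / (1 - beta1_pow)) * mom1 / denom.
    p = p - (lr / (1.0 - beta1_pow)) * (mom1 / denom)
    return p, mom1, mom2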
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/merged_adam_op.h"
namespace paddle {
namespace operators {
template <typename T, typename MT>
__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
MT beta2_pow_, const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out, const MT* lr_,
const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T, typename MT>
__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
const MT* beta1_pow_, const MT* beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out, const MT* lr_,
const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T>
__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
const T* beta2_pow_, T* beta1_pow_out,
T* beta2_pow_out) {
*beta1_pow_out = beta1 * beta1_pow_[0];
*beta2_pow_out = beta2 * beta2_pow_[0];
}
template <typename T>
class MergedAdamOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using MPDType = typename details::MPTypeTrait<T>::Type;
auto param = ctx.MultiInput<framework::Tensor>("Param");
auto grad = ctx.MultiInput<framework::Tensor>("Grad");
auto lr = ctx.MultiInput<framework::Tensor>("LearningRate");
auto mom1 = ctx.MultiInput<framework::Tensor>("Moment1");
auto mom2 = ctx.MultiInput<framework::Tensor>("Moment2");
auto beta1_pow = ctx.MultiInput<framework::Tensor>("Beta1Pow");
auto beta2_pow = ctx.MultiInput<framework::Tensor>("Beta2Pow");
auto param_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
auto mom1_out = ctx.MultiOutput<framework::Tensor>("Moment1Out");
auto mom2_out = ctx.MultiOutput<framework::Tensor>("Moment2Out");
auto beta1_pow_out = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
auto beta2_pow_out = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
const bool multi_precision = ctx.Attr<bool>("multi_precision");
auto master_param = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_param_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
size_t param_num = param.size();
for (size_t idx = 0; idx < param_num; idx++) {
const MPDType* master_in_data =
multi_precision ? master_param[idx]->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out[idx]->mutable_data<MPDType>(ctx.GetPlace())
: nullptr;
// update param and moment
int threads = 512;
int blocks = (param[idx]->numel() + threads - 1) / threads;
if (beta1_pow[idx]->place() == platform::CPUPlace() &&
beta2_pow[idx]->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow[idx]->data<MPDType>(),
*beta2_pow[idx]->data<MPDType>(), mom1[idx]->data<MPDType>(),
mom1_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
mom2[idx]->data<MPDType>(),
mom2_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
lr[idx]->data<MPDType>(), grad[idx]->data<T>(),
param[idx]->data<T>(),
param_out[idx]->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, param[idx]->numel());
if (!use_global_beta_pow) {
// Cpu update
beta1_pow_out[idx]->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta1 * beta1_pow[idx]->data<MPDType>()[0];
beta2_pow_out[idx]->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta2 * beta2_pow[idx]->data<MPDType>()[0];
}
} else {
AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(), mom1[idx]->data<MPDType>(),
mom1_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
mom2[idx]->data<MPDType>(),
mom2_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
lr[idx]->data<MPDType>(), grad[idx]->data<T>(),
param[idx]->data<T>(),
param_out[idx]->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, param[idx]->numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow[idx]->data<MPDType>(),
beta2_pow[idx]->data<MPDType>(),
beta1_pow_out[idx]->mutable_data<MPDType>(ctx.GetPlace()),
beta2_pow_out[idx]->mutable_data<MPDType>(ctx.GetPlace()));
}
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(merged_adam, ops::MergedAdamOpCUDAKernel<float>,
ops::MergedAdamOpCUDAKernel<double>,
ops::MergedAdamOpCUDAKernel<plat::float16>);
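The kernels above replace the old pre-scaled learning-rate formulation with an explicit denom term. A quick NumPy check (illustrative only, not part of the commit) confirms the two formulations are algebraically equivalent up to floating-point error:

import numpy as np

rng = np.random.default_rng(0)
mom1, mom2 = rng.random(8), rng.random(8)
lr, epsilon = 1e-3, 1e-8
beta1_pow, beta2_pow = 0.9 ** 3, 0.999 ** 3

# Previous formulation: lr pre-scaled by sqrt(1 - beta2_pow) / (1 - beta1_pow).
old = (lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)) * mom1 / (
    np.sqrt(mom2) + epsilon * np.sqrt(1 - beta2_pow))
# New formulation used by the kernels in this commit.
new = (lr / (1 - beta1_pow)) * mom1 / (
    np.sqrt(mom2) / np.sqrt(1 - beta2_pow) + epsilon)
assert np.allclose(old, new)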
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
namespace paddle {
namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename DeviceContext, typename T>
class MergedAdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.MultiInput<framework::Tensor>("Param");
size_t n = param.size();
auto grad = ctx.MultiInput<framework::Tensor>("Grad");
PADDLE_ENFORCE_EQ(n, grad.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to "
"Input(Param), but got the size of Input(Grad) "
"is %d, the size of Input(Param) is %d.",
grad.size(), n));
auto lr = ctx.MultiInput<framework::Tensor>("LearningRate");
PADDLE_ENFORCE_EQ(
n, lr.size(),
platform::errors::InvalidArgument(
"The size of Input(LearningRate) must be equal to "
"Input(Param), but got the size of Input(LearningRate) "
"is %d, the size of Input(Param) is %d.",
lr.size(), n));
auto mom1 = ctx.MultiInput<framework::Tensor>("Moment1");
PADDLE_ENFORCE_EQ(n, mom1.size(),
platform::errors::InvalidArgument(
"The size of Input(Moment1) must be equal to "
"Input(Param), but got the size of Input(Moment1) "
"is %d, the size of Input(Param) is %d.",
mom1.size(), n));
auto mom2 = ctx.MultiInput<framework::Tensor>("Moment2");
PADDLE_ENFORCE_EQ(n, mom2.size(),
platform::errors::InvalidArgument(
"The size of Input(Moment2) must be equal to "
"Input(Param), but got the size of Input(Moment2) "
"is %d, the size of Input(Param) is %d.",
mom2.size(), n));
auto beta1_pow = ctx.MultiInput<framework::Tensor>("Beta1Pow");
PADDLE_ENFORCE_EQ(n, beta1_pow.size(),
platform::errors::InvalidArgument(
"The size of Input(Beta1Pow) must be equal to "
"Input(Param), but got the size of Input(Beta1Pow) "
"is %d, the size of Input(Param) is %d.",
beta1_pow.size(), n));
auto beta2_pow = ctx.MultiInput<framework::Tensor>("Beta2Pow");
PADDLE_ENFORCE_EQ(n, beta2_pow.size(),
platform::errors::InvalidArgument(
"The size of Input(Beta2Pow) must be equal to "
"Input(Param), but got the size of Input(Beta2Pow) "
"is %d, the size of Input(Param) is %d.",
beta2_pow.size(), n));
auto param_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
auto mom1_out = ctx.MultiOutput<framework::Tensor>("Moment1Out");
auto mom2_out = ctx.MultiOutput<framework::Tensor>("Moment2Out");
auto beta1_pow_out = ctx.MultiOutput<framework::Tensor>("Beta1PowOut");
auto beta2_pow_out = ctx.MultiOutput<framework::Tensor>("Beta2PowOut");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
size_t param_num = param.size();
for (size_t idx = 0; idx < param_num; idx++) {
AdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow[idx]->data<T>(),
beta2_pow[idx]->data<T>(), mom1[idx]->data<T>(),
mom1_out[idx]->mutable_data<T>(ctx.GetPlace()), mom2[idx]->data<T>(),
mom2_out[idx]->mutable_data<T>(ctx.GetPlace()), lr[idx]->data<T>(),
grad[idx]->data<T>(), param[idx]->data<T>(),
param_out[idx]->mutable_data<T>(ctx.GetPlace()));
functor(param[idx]->numel());
if (!use_global_beta_pow) {
beta1_pow_out[idx]->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow[idx]->data<T>()[0];
beta2_pow_out[idx]->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow[idx]->data<T>()[0];
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -71,6 +71,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"adam",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"merged_adam",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"adamw",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
......@@ -123,6 +126,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"merged_adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
......@@ -148,6 +154,9 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"merged_adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
"MasterParamOut"}},
......
......@@ -1011,5 +1011,186 @@ class TestAdamOpV2Group(TestAdamOpV2):
adam.clear_gradients()
class TestMultiTensorAdam(unittest.TestCase):
def _adam_optimize_dygraph(self,
place,
use_param_attr=False,
use_param_group=False,
use_amp=False,
use_multi_tensor=False):
paddle.disable_static()
paddle.seed(10)
paddle.set_device(place)
input = paddle.randn((5, 5))
weight_attr = paddle.ParamAttr(
learning_rate=0.5,
regularizer=paddle.regularizer.L2Decay(1.0),
trainable=True)
if use_param_attr:
model = paddle.nn.Linear(5, 5, weight_attr)
else:
model = paddle.nn.Linear(5, 5)
if not use_param_group:
optimizer = paddle.optimizer.Adam(
parameters=model.parameters(),
use_multi_tensor=use_multi_tensor,
multi_precision=use_amp)
else:
optimizer = paddle.optimizer.Adam(
parameters=[{
'params': model.parameters(),
'weight_decay': 0.001,
'beta1': 0.1,
'beta2': 0.99
}],
use_multi_tensor=use_multi_tensor,
multi_precision=use_amp)
for idx in range(2):
if place == 'gpu' and use_amp == True:
model = paddle.amp.decorate(models=model, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
if place == 'gpu' and use_amp == True:
with paddle.amp.auto_cast(level='O2'):
output = model(input)
loss = paddle.mean(output)
scaled = scaler.scale(loss)
scaled.backward()
scaler.step(optimizer)
optimizer.clear_grad()
else:
output = model(input)
loss = paddle.mean(output)
loss.backward()
optimizer.step()
optimizer.clear_grad()
return output, model.parameters()
def _adam_optimize_static(self,
place,
use_amp=False,
use_multi_tensor=False):
paddle.enable_static()
paddle.seed(10)
np.random.seed(10)
if place == 'cpu':
use_amp = False
exe = paddle.static.Executor(place=place)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
optimizer = paddle.optimizer.Adam(
multi_precision=use_amp, use_multi_tensor=use_multi_tensor)
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True,
use_fp16_guard=False)
with paddle.static.program_guard(train_program, startup_program):
if use_amp:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float16')
else:
data = paddle.static.data(
shape=[2, 2], name='X', dtype='float32')
hidden = paddle.static.nn.fc(x=data, size=10)
loss = paddle.fluid.layers.mean(hidden)
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
out = []
for idx in range(5):
loss_data, = exe.run(train_program,
feed={"X": x},
fetch_list=[loss.name])
out.append(loss_data)
return out
def _get_places(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
return places
def _check_with_place_amp(self, place, use_amp):
# test dygraph mode
output_dygraph1, params_dygraph1 = self._adam_optimize_dygraph(
place=place, use_amp=use_amp, use_multi_tensor=True)
output_dygraph2, params_dygraph2 = self._adam_optimize_dygraph(
place=place, use_amp=use_amp, use_multi_tensor=False)
self.assertEqual(
np.allclose(
output_dygraph1, output_dygraph2, rtol=1e-05), True)
for idx in range(len(params_dygraph1)):
self.assertEqual(
np.allclose(
params_dygraph1[idx], params_dygraph2[idx], rtol=1e-05),
True)
# test static mode
output_static1 = self._adam_optimize_static(
place=place, use_amp=use_amp, use_multi_tensor=True)
output_static2 = self._adam_optimize_static(
place=place, use_amp=use_amp, use_multi_tensor=False)
for idx in range(len(output_static1)):
self.assertEqual(
np.allclose(
output_static1[idx], output_static2[idx], rtol=1e-05),
True)
def _check_with_param_arrt(self, place, use_amp):
output1, params1 = self._adam_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_attr=True,
use_multi_tensor=True)
output2, params2 = self._adam_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_attr=True,
use_multi_tensor=False)
self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True)
for idx in range(len(params1)):
self.assertEqual(
np.allclose(
params1[idx], params2[idx], rtol=1e-05), True)
def _check_with_param_group(self, place, use_amp):
output1, params1 = self._adam_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_group=True,
use_multi_tensor=True)
output2, params2 = self._adam_optimize_dygraph(
place=place,
use_amp=use_amp,
use_param_group=True,
use_multi_tensor=False)
self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True)
for idx in range(len(params1)):
self.assertEqual(
np.allclose(
params1[idx], params2[idx], rtol=1e-05), True)
def test_main(self):
for place in self._get_places():
use_amp_list = [True, False]
for use_amp in use_amp_list:
self._check_with_place_amp(place, use_amp)
self._check_with_param_arrt(place, use_amp)
self._check_with_param_group(place, use_amp)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
from paddle import _C_ops
def run_adam_op(params,
grads,
lrs,
moment1s,
moment2s,
beta1_pows,
beta2_pows,
master_params,
epsilon,
beta1,
beta2,
place,
multi_precision=False,
use_merged=False):
assert len(params) == len(grads)
assert len(params) == len(lrs)
assert len(params) == len(moment1s)
assert len(params) == len(moment2s)
assert len(params) == len(beta1_pows)
assert len(params) == len(beta2_pows)
assert len(params) == len(master_params)
paddle.disable_static()
paddle.set_device(place)
param_vars = [paddle.fluid.dygraph.to_variable(p) for p in params]
grad_vars = [paddle.fluid.dygraph.to_variable(g) for g in grads]
lr_vars = [paddle.fluid.dygraph.to_variable(l) for l in lrs]
moment1_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment1s]
moment2_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment2s]
beta1_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta1_pows]
beta2_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta2_pows]
master_param_vars = [
paddle.fluid.dygraph.to_variable(m_p) for m_p in master_params
]
if not use_merged:
for i in range(len(param_vars)):
_, _, _, _, _, _ = _C_ops.adam(
param_vars[i], grad_vars[i], lr_vars[i], moment1_vars[i],
moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i],
master_param_vars[i], param_vars[i], moment1_vars[i],
moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i],
master_param_vars[i], 'epsilon', epsilon, 'beta1', beta1,
'beta2', beta2, 'multi_precision', multi_precision)
else:
_, _, _, _, _, _ = _C_ops.merged_adam(
param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars,
beta1_pow_vars, beta2_pow_vars, master_param_vars, param_vars,
moment1_vars, moment2_vars, beta1_pow_vars, beta2_pow_vars,
master_param_vars, 'epsilon', epsilon, 'beta1', beta1, 'beta2',
beta2, 'multi_precision', multi_precision)
outputs = {
'ParamOut': param_vars,
'Moment1Out': moment1_vars,
'Moment2Out': moment2_vars,
'Beta1PowOut': beta1_pow_vars,
'Beta2PowOut': beta2_pow_vars,
'MasterParamOut': master_param_vars
}
return outputs
class TestMergedAdam(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]]
self.seed = 10
def gen_rand_data(self, shapes, dtype):
return [np.random.random(s).astype(dtype) for s in shapes]
def prepare_data(self, shapes, multi_precision, seed, place):
np.random.seed(seed)
mp_dtype = np.float32
dtype = np.float16 if multi_precision and place == 'gpu' else np.float32
params = self.gen_rand_data(shapes, dtype)
grads = self.gen_rand_data(shapes, dtype)
lrs = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
moment1s = self.gen_rand_data(shapes, mp_dtype)
moment2s = self.gen_rand_data(shapes, mp_dtype)
beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype)
master_params = [p.astype(mp_dtype) for p in params]
return params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params
def check_with_place(self, place, multi_precision):
params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params = self.prepare_data(
self.shapes, multi_precision, self.seed, place)
def run_op(use_merged):
return run_adam_op(
params=params,
grads=grads,
lrs=lrs,
moment1s=moment1s,
moment2s=moment2s,
beta1_pows=beta1_pows,
beta2_pows=beta2_pows,
master_params=master_params,
epsilon=0.9,
beta1=0.9,
beta2=0.99,
place=place,
multi_precision=multi_precision,
use_merged=use_merged)
outs1 = run_op(True)
outs2 = run_op(False)
self.assertEqual(len(outs1), len(outs2))
for key in outs1.keys():
value1 = outs1[key]
value2 = outs2[key]
for i in range(len(value1)):
if place == 'gpu':
self.assertTrue(np.array_equal(value1[i], value2[i]))
else:
self.assertTrue(
np.allclose(
value1[i], value2[i], atol=1e-7))
def get_places(self):
places = ['cpu']
if paddle.is_compiled_with_cuda():
places.append('gpu')
return places
def test_main(self):
for multi_precision in [False, True]:
for place in self.get_places():
self.check_with_place(place, multi_precision)
if __name__ == "__main__":
unittest.main()
......@@ -92,6 +92,7 @@ class Adam(Optimizer):
different semantics with the original Adam algorithm and may lead to different result.
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
use_multi_tensor (bool, optional): Whether to use the multi-tensor strategy to update all parameters at once. Default is false.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
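A minimal dygraph usage sketch for the new flag (illustrative only; it mirrors the unit test added in this commit):

import paddle

model = paddle.nn.Linear(5, 5)
opt = paddle.optimizer.Adam(learning_rate=0.001,
                            parameters=model.parameters(),
                            use_multi_tensor=True)  # route updates through merged_adam
loss = paddle.mean(model(paddle.randn((5, 5))))
loss.backward()
opt.step()
opt.clear_grad()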
......@@ -172,6 +173,7 @@ class Adam(Optimizer):
grad_clip=None,
lazy_mode=False,
multi_precision=False,
use_multi_tensor=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
......@@ -209,6 +211,24 @@ class Adam(Optimizer):
'lazy_mode': lazy_mode,
}
self._use_multi_tensor = use_multi_tensor
if self._use_multi_tensor:
self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
self._moment1_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
self._moment2_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
self._beta1_pow_acc_dict = {
'FP32_LODTensor': [],
'FP16_LODTensor': []
}
self._beta2_pow_acc_dict = {
'FP32_LODTensor': [],
'FP16_LODTensor': []
}
self._master_weight_dict = {
'FP32_LODTensor': None,
'FP16_LODTensor': []
}
def _create_master_weight(self, param):
if param.name in self._master_weights:
var = self._master_weights[param.name]
......@@ -436,6 +456,157 @@ class Adam(Optimizer):
self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads)
def _multi_tensor_init(self, target_block, parameters):
"""
All tensors used by the optimizer (parameters, master weights, and the moment1/moment2 and beta_pow accumulators) are grouped into Python lists by data type (float16, float32).
This overrides the base Optimizer implementation.
Args:
target_block: the block in which the loss tensor is present
parameters: list of parameter tensors for the optimizer
"""
self._create_accumulators(target_block, parameters)
for param in parameters:
moment1 = self._get_accumulator(self._moment1_acc_str, param)
moment2 = self._get_accumulator(self._moment2_acc_str, param)
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param)
if param.dtype == paddle.float32:
self._param_dict['FP32_LODTensor'].append(param)
self._moment1_dict['FP32_LODTensor'].append(moment1)
self._moment2_dict['FP32_LODTensor'].append(moment2)
self._beta1_pow_acc_dict['FP32_LODTensor'].append(beta1_pow_acc)
self._beta2_pow_acc_dict['FP32_LODTensor'].append(beta2_pow_acc)
elif param.dtype == paddle.float16:
self._param_dict['FP16_LODTensor'].append(param)
self._moment1_dict['FP16_LODTensor'].append(moment1)
self._moment2_dict['FP16_LODTensor'].append(moment2)
self._beta1_pow_acc_dict['FP16_LODTensor'].append(beta1_pow_acc)
self._beta2_pow_acc_dict['FP16_LODTensor'].append(beta2_pow_acc)
if self._multi_precision:
self._master_weight_dict['FP16_LODTensor'].append(
self._master_weights[param.name])
else:
self._master_weight_dict['FP16_LODTensor'] = None
else:
raise ValueError(
"Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR."
)
def _append_optimize_multi_tensor_op(self, target_block,
parameters_and_grads):
"""
For the multi-tensor path, append the merged optimize op (merged_adam) to the target block.
"""
assert isinstance(target_block, framework.Block)
grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
if isinstance(parameters_and_grads, list):
for param_and_grad in parameters_and_grads:
if param_and_grad[1] is None:
continue
if param_and_grad[0].stop_gradient is False:
if param_and_grad[
0].dtype == paddle.float32 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP32_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP32_LODTensor'].append(lr)
elif param_and_grad[
0].dtype == paddle.float16 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP16_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP16_LODTensor'].append(lr)
else:
for param_and_grad in parameters_and_grads['params']:
if param_and_grad[1] is None:
continue
if param_and_grad[0].stop_gradient is False:
param_grad_dict = dict()
param_grad_dict['params'] = param_and_grad
param_grad_dict.update({
k: v
for k, v in parameters_and_grads.items()
if k != 'params'
})
param_and_grad = self._update_param_group(param_grad_dict)
if param_and_grad[
0].dtype == paddle.float32 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP32_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP32_LODTensor'].append(lr)
elif param_and_grad[
0].dtype == paddle.float16 and param_and_grad[
1].type == core.VarDesc.VarType.LOD_TENSOR:
grad_dict['FP16_LODTensor'].append(param_and_grad[1])
lr = self._create_param_lr(param_and_grad)
lr_dict['FP16_LODTensor'].append(lr)
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
if key == 'FP32_LODTensor':
self._multi_precision = False
_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
_beta2 = self._beta2 if not isinstance(
self._beta2, Variable) else self._beta2.numpy().item(0)
if framework.in_dygraph_mode():
_, _, _, _, _, _ = _C_ops.merged_adam(
self._param_dict[key], grad_dict[key], lr_dict[key],
self._moment1_dict[key], self._moment2_dict[key],
self._beta1_pow_acc_dict[key],
self._beta2_pow_acc_dict[key],
self._master_weight_dict[key], self._param_dict[key],
self._moment1_dict[key], self._moment2_dict[key],
self._beta1_pow_acc_dict[key],
self._beta2_pow_acc_dict[key],
self._master_weight_dict[key], 'epsilon', self._epsilon,
'beta1', _beta1, 'beta2', _beta2, 'multi_precision',
self._multi_precision)
else:
inputs = {
"Param": self._param_dict[key],
"Grad": grad_dict[key],
"LearningRate": lr_dict[key],
"Moment1": self._moment1_dict[key],
"Moment2": self._moment2_dict[key],
"Beta1Pow": self._beta1_pow_acc_dict[key],
"Beta2Pow": self._beta2_pow_acc_dict[key]
}
outputs = {
"ParamOut": self._param_dict[key],
"Moment1Out": self._moment1_dict[key],
"Moment2Out": self._moment2_dict[key],
"Beta1PowOut": self._beta1_pow_acc_dict[key],
"Beta2PowOut": self._beta2_pow_acc_dict[key]
}
attrs = {
"epsilon": self._epsilon,
"beta1": _beta1,
"beta2": _beta2
}
if self._multi_precision:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
target_block.append_op(
type="merged_adam",
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)
return None
def _update_param_group(self, parameters):
self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
......
......@@ -218,7 +218,7 @@ class Optimizer(object):
self._param_groups = self._parameter_list
# NOTE: Multi Tensor: pass all parameters and gradients to the optimizer's op kernel at once for updating in dygraph mode.
# Optimizer support list: [ paddle.optimizer.Momentum ].
# Optimizer support list: [ paddle.optimizer.Momentum, paddle.optimizer.Adam].
self._use_multi_tensor = None
self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []}
......@@ -684,8 +684,10 @@ class Optimizer(object):
self._create_global_learning_rate()
# NOTE: Multi Tensor support [ Momentum ] for dygraph mode
if self._use_multi_tensor and self.__class__.__name__ in ['Momentum']:
# NOTE: Multi Tensor support [ Momentum, Adam ] for dygraph mode
if self._use_multi_tensor and self.__class__.__name__ in [
'Momentum', 'Adam'
]:
if len(self._param_dict['FP32_LODTensor']) == 0 and len(
self._param_dict['FP16_LODTensor']) == 0:
if isinstance(parameters_and_grads, list):
......