未验证 提交 77a8a394 编写于 作者: Z zhaoyingli 提交者: GitHub

add adamw cuda kernel (#35020)

* adamw support cuda

* adamw support cuda
上级 cf99c0d5
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adamw_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T, typename MT>
__global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
MT beta1_pow_, MT beta2_pow_, const MT* moment1,
MT* moment1_out, const MT* moment2,
MT* moment2_out, const MT* lr_, const T* grad,
const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = moment1[id];
MT mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T, typename MT>
__global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
const MT* beta1_pow_, const MT* beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out,
const MT* lr_, const T* grad, const T* param,
T* param_out, const MT* master_param,
MT* master_param_out, int ndim) {
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T>
__global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_,
const T* beta2_pow_, T* beta1_pow_out,
T* beta2_pow_out) {
*beta1_pow_out = beta1 * beta1_pow_[0];
*beta2_pow_out = beta2 * beta2_pow_[0];
}
template <typename T, typename MT>
__global__ void SparseAdamWCUDAKernelREG(
MT beta1, MT beta2, MT epsilon, MT coeff, const MT beta1_pow,
const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_,
MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_,
T* param_out_, const MT* master_param, MT* master_param_out,
const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode,
int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
MT lr = *lr_;
MT wd = static_cast<MT>(1.0) - coeff * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
for (; id < ndim; id += blockDim.x * gridDim.x) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count, id / row_numel);
if (lazy_mode && row_idx < 0) {
return;
} else {
MT mom1 = mom1_[id];
MT mom2 = mom2_[id];
MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
MT g = row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
: static_cast<MT>(0);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p = wd * p -
lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
// Write back to global memory
mom1_out_[id] = mom1;
mom2_out_[id] = mom2;
param_out_[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
}
template <typename T>
class AdamWOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
using MPDType = typename details::MPTypeTrait<T>::Type;
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
float coeff = ctx.Attr<float>("coeff");
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
auto* mom2 = ctx.Input<LoDTensor>("Moment2");
auto* lr = ctx.Input<LoDTensor>("LearningRate");
auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
auto* param_out = ctx.Output<LoDTensor>("ParamOut");
auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
TensorToVector(*skip_update_tensor, ctx.device_context(),
&skip_update_vec);
skip_update = skip_update_vec[0];
}
// skip_update=true, just copy input to output, and TensorCopy will call
// mutable_data
if (skip_update) {
VLOG(4) << "Adamw skip update";
framework::TensorCopy(
*param, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), param_out);
framework::TensorCopy(
*mom1, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom1_out);
framework::TensorCopy(
*mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom2_out);
framework::TensorCopy(
*beta1_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out);
framework::TensorCopy(
*beta2_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out);
return;
}
// if with_decay = false, coeff = 0
bool with_decay = ctx.Attr<bool>("with_decay");
if (!with_decay) {
coeff = static_cast<float>(0.0);
}
MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(Beta1Tensor) size must be 1, but get %d",
beta1_tensor->numel()));
beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
}
MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(Beta2Tensor) size must be 1, but get %d",
beta2_tensor->numel()));
beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
}
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
const bool multi_precision = ctx.Attr<bool>("multi_precision");
const LoDTensor* master_param = nullptr;
LoDTensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<LoDTensor>("MasterParam");
master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
}
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
: nullptr;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (grad_var->IsType<framework::LoDTensor>()) {
auto* grad = ctx.Input<LoDTensor>("Grad");
// update param and moment
int threads = 512;
int blocks = (param->numel() + threads - 1) / threads;
if (beta1_pow->place() == platform::CPUPlace() &&
beta2_pow->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, param->numel());
if (!use_global_beta_pow) {
// Cpu update
beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<MPDType>()[0];
beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<MPDType>()[0];
}
} else {
AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, param->numel());
if (!use_global_beta_pow) {
// Update with gpu
UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(),
beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
if (grad->rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = grad;
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::CUDADeviceContext>(),
*grad, &tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
if (beta1_pow->place() == platform::CPUPlace() &&
beta2_pow->place() == platform::CPUPlace()) {
int threads = 512;
int ndim = param->numel();
int blocks = (ndim + threads - 1) / threads;
SparseAdamWCUDAKernelREG<
T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, rows, row_numel, grad_merge.rows().size(),
lazy_mode, ndim);
if (!use_global_beta_pow) {
// Update with cpu
beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<MPDType>()[0];
beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<MPDType>()[0];
}
} else {
SparseAdamWFunctor<T, GPUAdamW, MPDType> functor(
beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, rows, row_numel, grad_merge.rows().size(),
lazy_mode);
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<platform::CUDADeviceContext> for_range(
static_cast<const platform::CUDADeviceContext&>(
ctx.device_context()),
param->numel());
for_range(functor);
if (!use_global_beta_pow) {
// update beta1 and beta2
UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(),
beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
}
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type not supported by adamw_op"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(adamw, ops::AdamWOpCUDAKernel<float>,
ops::AdamWOpCUDAKernel<double>,
ops::AdamWOpCUDAKernel<plat::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -22,6 +22,7 @@ class AdamWOp : public AdamOp { ...@@ -22,6 +22,7 @@ class AdamWOp : public AdamOp {
using AdamOp::AdamOp; using AdamOp::AdamOp;
}; };
struct GPUAdamW;
struct CPUAdamW; struct CPUAdamW;
template <typename T, typename Flavour> template <typename T, typename Flavour>
...@@ -46,6 +47,107 @@ class AdamWFunctor<T, CPUAdamW> { ...@@ -46,6 +47,107 @@ class AdamWFunctor<T, CPUAdamW> {
} }
}; };
template <typename T, typename Flavour, typename MT = T>
class SparseAdamWFunctor;
template <typename T, typename MT>
class SparseAdamWFunctor<T, GPUAdamW, MT> {
private:
MT beta1_;
MT beta2_;
MT epsilon_;
MT coeff_;
const MT* beta1_pow_;
const MT* beta2_pow_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
MT* moment2_out_;
const MT* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const MT* master_param_;
MT* master_param_out_;
const int64_t* rows_;
int64_t row_numel_;
int64_t row_count_;
bool lazy_mode_;
public:
SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff,
const MT* beta1_pow, const MT* beta2_pow, const MT* mom1,
MT* mom1_out, const MT* mom2, MT* mom2_out, const MT* lr,
const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
const int64_t* rows, int64_t row_numel, int64_t row_count,
bool lazy_mode)
: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
coeff_(coeff),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out),
master_param_(master_param),
master_param_out_(master_param_out),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count),
lazy_mode_(lazy_mode) {}
inline HOSTDEVICE void adamw_update(size_t i, MT g) const {
// The following code is the same as dense
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
// Calculation
MT wd = static_cast<MT>(1.0) - coeff_ * lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
p = wd * p -
lr * (mom1 /
(sqrt(mom2) + epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = static_cast<T>(p);
if (master_param_out_) {
master_param_out_[i] = p;
}
}
inline HOSTDEVICE void operator()(size_t i) const {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
if (lazy_mode_ && row_idx < 0) {
return;
} else {
MT g = row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_])
: static_cast<MT>(0);
adamw_update(i, g);
}
}
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> { class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
public: public:
......
...@@ -118,6 +118,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = { ...@@ -118,6 +118,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"sgd", {"ParamOut"}}, {"sgd", {"ParamOut"}},
{"adam", {"adam",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"adamw",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"average_accumulates", {"average_accumulates",
{"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates", {"out_sum_1", "out_sum_2", "out_sum_3", "out_num_accumulates",
"out_old_num_accumulates", "out_num_updates"}}, "out_old_num_accumulates", "out_num_updates"}},
......
...@@ -93,33 +93,6 @@ class TestAdamWOp(unittest.TestCase): ...@@ -93,33 +93,6 @@ class TestAdamWOp(unittest.TestCase):
adam = paddle.optimizer.AdamW( adam = paddle.optimizer.AdamW(
0.1, epsilon=-1, parameters=linear.parameters()) 0.1, epsilon=-1, parameters=linear.parameters())
def test_adamw_lr_decay(self):
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
lr = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=10)
wd = 0.1
adam = paddle.optimizer.AdamW(
learning_rate=lr,
parameters=linear.parameters(),
apply_decay_param_fun=lambda name: True,
weight_decay=wd)
for _ in range(2):
out = linear(a)
out.backward()
lr_to_coeff = adam._lr_to_coeff
adam.step()
for i, value in enumerate(lr_to_coeff.values()):
self.assertAlmostEqual(value.numpy()[0], 1.0 - lr() * wd)
self.assertEqual(len(adam._lr_to_coeff), 0)
lr.step()
adam.clear_gradients()
class TestAdamWOpGroup(TestAdamWOp): class TestAdamWOpGroup(TestAdamWOp):
def test_adamw_op_dygraph(self): def test_adamw_op_dygraph(self):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -162,7 +162,6 @@ class AdamW(Adam): ...@@ -162,7 +162,6 @@ class AdamW(Adam):
self._params_name = set() self._params_name = set()
self._apply_decay_param_fun = apply_decay_param_fun self._apply_decay_param_fun = apply_decay_param_fun
self._coeff = coeff self._coeff = coeff
self._lr_to_coeff = dict()
super(AdamW, self).__init__( super(AdamW, self).__init__(
learning_rate=learning_rate, learning_rate=learning_rate,
...@@ -178,9 +177,6 @@ class AdamW(Adam): ...@@ -178,9 +177,6 @@ class AdamW(Adam):
self.type = "adamw" self.type = "adamw"
# now the adamw op doesn't support cuda
if core.is_compiled_with_cuda():
self.type = "adam"
# Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that. # Use _auxiliary_vars together with _set_auxiliary_var/_get_auxiliary_var to achieve that.
self._auxiliary_vars = dict() self._auxiliary_vars = dict()
...@@ -193,64 +189,7 @@ class AdamW(Adam): ...@@ -193,64 +189,7 @@ class AdamW(Adam):
else: else:
return None return None
def _append_decoupled_weight_decay(self, block, param_and_grad):
"""
Add decoupled weight decay op.
parameter = parameter - parameter * coeff * lr
Args:
block: block in which variable is to be created
param_and_grad: (parameters, gradients) pairs,
the parameters need to decay.
Raises:
Exception: The type of coeff and parameter is not consistent.
"""
if isinstance(param_and_grad, dict):
param_and_grad = self._update_param_group(param_and_grad)
param, grad = param_and_grad
if self._apply_decay_param_fun is not None \
and not self._apply_decay_param_fun(param.name):
return
if isinstance(self._learning_rate, float):
learning_rate = self._learning_rate
else:
# NOTE. We add this function to the _append_optimize_op(),
# for we must make sure _create_param_lr() be called after
# optimizer._create_global_learning_rate().
learning_rate = self._create_param_lr(param_and_grad)
with block.program._optimized_guard(
[param, grad]), framework.name_scope('weight decay'):
self._params_name.add(param.name)
# If it has been calculated, the result will be reused.
# NOTE(wangxi): In dygraph mode, apply_gradient will be executed
# every step, so need clear _lr_to_coeff every step,
# we do this in _create_optimization_pass
decay_coeff = self._lr_to_coeff.get(learning_rate, None)
if decay_coeff is None:
# NOTE(wangxi): for pipeline to set device:all
with paddle.static.device_guard(None):
decay_coeff = 1.0 - learning_rate * self._coeff
self._lr_to_coeff[learning_rate] = decay_coeff
find_master = (self._multi_precision and
param.dtype == core.VarDesc.VarType.FP16)
if find_master:
master_weight = self._master_weights[param.name]
scaled_param = master_weight * decay_coeff
paddle.fluid.layers.assign(
input=scaled_param, output=master_weight)
else:
scaled_param = param * decay_coeff
paddle.fluid.layers.assign(input=scaled_param, output=param)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
if not core.is_compiled_with_npu():
self._append_decoupled_weight_decay(block, param_and_grad)
return super(AdamW, self)._append_optimize_op(block, param_and_grad)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
if isinstance(param_and_grad, dict): if isinstance(param_and_grad, dict):
...@@ -262,6 +201,8 @@ class AdamW(Adam): ...@@ -262,6 +201,8 @@ class AdamW(Adam):
if self._apply_decay_param_fun is not None \ if self._apply_decay_param_fun is not None \
and not self._apply_decay_param_fun(param.name): and not self._apply_decay_param_fun(param.name):
with_decay = False with_decay = False
else:
self._params_name.add(param.name)
moment1 = self._get_accumulator(self._moment1_acc_str, moment1 = self._get_accumulator(self._moment1_acc_str,
param_and_grad[0]) param_and_grad[0])
...@@ -277,19 +218,19 @@ class AdamW(Adam): ...@@ -277,19 +218,19 @@ class AdamW(Adam):
if find_master else None) if find_master else None)
lr = self._create_param_lr(param_and_grad) lr = self._create_param_lr(param_and_grad)
# create the adam optimize op # create the adamw optimize op
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
_beta1 = self._beta1 if not isinstance( _beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0) self._beta1, Variable) else self._beta1.numpy().item(0)
_beta2 = self._beta2 if not isinstance( _beta2 = self._beta2 if not isinstance(
self._beta2, Variable) else self._beta2.numpy().item(0) self._beta2, Variable) else self._beta2.numpy().item(0)
_, _, _, _, _ = _C_ops.adam( _, _, _, _, _ = _C_ops.adamw(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2, param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1, beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon, moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread', 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
1000, 'beta1', _beta1, 'beta2', _beta2) 1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff)
return None return None
...@@ -350,13 +291,6 @@ class AdamW(Adam): ...@@ -350,13 +291,6 @@ class AdamW(Adam):
return adamw_op return adamw_op
def _create_optimization_pass(self, parameters_and_grads):
optimize_ops = super(
AdamW, self)._create_optimization_pass(parameters_and_grads)
# In dygraph mode, clear _lr_to_coeff after applied gradient
self._lr_to_coeff = dict()
return optimize_ops
def __str__(self): def __str__(self):
return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册