Unverified commit e1e3e3b4, authored by Leo Chen and committed by GitHub

adam op adds input SkipUpdate (#34075)

* adam add input SkipUpdate

* add unittest

* add npu unittest

* fix xpu compile

* remove param stream
Parent eeae91b6
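For orientation before the diff: SkipUpdate is a new dispensable, one-element bool input to the adam op. When it is true, every kernel touched by this patch (CPU, CUDA, NPU, XPU) copies Param, Moment1, Moment2, Beta1Pow and Beta2Pow straight to the corresponding outputs and returns without updating anything. Below is a minimal sketch of that early-return branch, condensed from the kernel hunks in this diff; the helper name is hypothetical and not part of the patch, and it assumes Paddle's internal framework headers.

```cpp
// Condensed sketch of the skip-update branch that the kernels below repeat.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace paddle {
namespace operators {

// Hypothetical helper: copy the Adam inputs to the outputs unchanged.
// TensorCopy calls mutable_data on each destination, so the outputs are
// allocated even though no update is performed.
static void SkipAdamUpdate(const framework::ExecutionContext& ctx) {
  auto& dev_ctx = ctx.device_context();
  auto place = ctx.GetPlace();
  const char* pairs[][2] = {{"Param", "ParamOut"},
                            {"Moment1", "Moment1Out"},
                            {"Moment2", "Moment2Out"},
                            {"Beta1Pow", "Beta1PowOut"},
                            {"Beta2Pow", "Beta2PowOut"}};
  for (auto& p : pairs) {
    framework::TensorCopy(*ctx.Input<framework::LoDTensor>(p[0]), place,
                          dev_ctx, ctx.Output<framework::LoDTensor>(p[1]));
  }
}

}  // namespace operators
}  // namespace paddle
```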
......
@@ -314,6 +314,13 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
#if defined(PADDLE_WITH_XPU)
else if (platform::is_xpu_place(src.place())) { // NOLINT
memory::Copy(dst_place, dst_ptr,
BOOST_GET_CONST(platform::XPUPlace, src.place()), src_ptr,
size);
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(src.place())) { // NOLINT
memory::Copy(dst_place, dst_ptr,
......
@@ -341,7 +348,7 @@ inline void TensorToVector(const Tensor& src,
BOOST_GET_CONST(platform::CPUPlace, src.place()), src_ptr,
size);
}
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, BOOST_GET_CONST(platform::CUDAPlace, src.place()),
......
@@ -349,6 +356,13 @@ inline void TensorToVector(const Tensor& src,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
#if defined(PADDLE_WITH_XPU)
else if (platform::is_xpu_place(src.place())) { // NOLINT
memory::Copy(dst_place, dst_ptr,
BOOST_GET_CONST(platform::XPUPlace, src.place()), src_ptr,
size);
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
else if (platform::is_npu_place(src.place())) { // NOLINT
memory::Copy(dst_place, dst_ptr,
......
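The XPU branches added above let TensorToVector handle a tensor that lives on an XPU device; the Adam kernels in this patch read the SkipUpdate flag back to the host through exactly this helper, which is presumably the "fix xpu compile" item in the commit message. A minimal sketch of that read follows, assuming Paddle's internal headers; `ReadBoolFlag` is a hypothetical name, the kernels below inline the same calls.

```cpp
// Sketch: copy a one-element bool flag from whatever device it lives on
// (CPU, GPU, NPU, or -- with the branch above -- XPU) back to host memory.
#include <vector>

#include "paddle/fluid/framework/tensor_util.h"

static bool ReadBoolFlag(const paddle::framework::Tensor& flag,
                         const paddle::platform::DeviceContext& dev_ctx) {
  std::vector<bool> host_flag;
  // TensorToVector dispatches on flag.place() and issues the matching memcpy.
  paddle::framework::TensorToVector(flag, dev_ctx, &host_flag);
  return !host_flag.empty() && host_flag[0];
}
```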
......
@@ -157,6 +157,8 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddInput("SkipUpdate", "(Tensor<bool>, optional), Skip the update or not.")
.AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
......
@@ -265,4 +267,10 @@ REGISTER_OP_VERSION(adam)
"In that case, the outputs(Beta1PowOut, Beta2PowOut) will not be "
"used in adam op, "
"and beta_pow will be updated after all adam op in the model.",
false));
false))
.AddCheckpoint(
R"ROC(
Upgrade adam, add 1 dispensable input [SkipUpdate].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"SkipUpdate", "If the value is true, Adam will skip the update."));
......
@@ -172,6 +172,42 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
TensorToVector(*skip_update_tensor, ctx.device_context(),
&skip_update_vec);
skip_update = skip_update_vec[0];
}
    // If skip_update is true, just copy the inputs to the outputs; TensorCopy
    // calls mutable_data on the outputs, so they are still allocated.
if (skip_update) {
VLOG(4) << "Adam skip update";
framework::TensorCopy(
*param, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), param_out);
framework::TensorCopy(
*mom1, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom1_out);
framework::TensorCopy(
*mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom2_out);
framework::TensorCopy(
*beta1_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out);
framework::TensorCopy(
*beta2_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out);
return;
}
MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
......
......
@@ -414,7 +414,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
auto* mom2 = ctx.Input<LoDTensor>("Moment2");
auto* lr = ctx.Input<LoDTensor>("LearningRate");
auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
......
@@ -424,6 +423,42 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
TensorToVector(*skip_update_tensor, ctx.device_context(),
&skip_update_vec);
skip_update = skip_update_vec[0];
}
    // If skip_update is true, just copy the inputs to the outputs; TensorCopy
    // calls mutable_data on the outputs, so they are still allocated.
if (skip_update) {
VLOG(4) << "Adam skip update";
framework::TensorCopy(
*param, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), param_out);
framework::TensorCopy(
*mom1, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom1_out);
framework::TensorCopy(
*mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom2_out);
framework::TensorCopy(
*beta1_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out);
framework::TensorCopy(
*beta2_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out);
return;
}
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
......
@@ -451,6 +486,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
epsilon_tensor->numel()));
epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
......
......
@@ -58,6 +58,42 @@ class AdamNPUKernel : public framework::OpKernel<T> {
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
TensorToVector(*skip_update_tensor, ctx.device_context(),
&skip_update_vec);
skip_update = skip_update_vec[0];
}
    // If skip_update is true, just copy the inputs to the outputs; TensorCopy
    // calls mutable_data on the outputs, so they are still allocated.
if (skip_update) {
VLOG(4) << "Adam skip update";
framework::TensorCopy(
*param, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), param_out);
framework::TensorCopy(
*mom1, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom1_out);
framework::TensorCopy(
*mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), mom2_out);
framework::TensorCopy(
*beta1_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out);
framework::TensorCopy(
*beta2_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out);
return;
}
bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
......
......
@@ -59,6 +59,43 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
bool skip_update = false;
if (ctx.HasInput("SkipUpdate")) {
auto* skip_update_tensor = ctx.Input<framework::Tensor>("SkipUpdate");
PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(SkipUpdate) size must be 1, but get %d",
skip_update_tensor->numel()));
std::vector<bool> skip_update_vec;
TensorToVector(*skip_update_tensor, ctx.device_context(),
&skip_update_vec);
skip_update = skip_update_vec[0];
}
    // If skip_update is true, just copy the inputs to the outputs; TensorCopy
    // calls mutable_data on the outputs, so they are still allocated.
if (skip_update) {
VLOG(4) << "Adam skip update";
framework::TensorCopy(
param, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &param_out);
framework::TensorCopy(
mom1, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &mom1_out);
framework::TensorCopy(
mom2, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), &mom2_out);
framework::TensorCopy(
beta1_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta1_pow_out);
framework::TensorCopy(
beta2_pow, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(),
beta2_pow_out);
return;
}
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"Tensor holds the wrong size, Expected beta1 pow "
......
......
@@ -134,6 +134,60 @@ class TestAdamWithEpsilonTensor(OpTest):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestAdamOpWithSkipUpdate(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.004
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
'Beta1Tensor': np.array([beta1]).astype("float32"),
'Beta2Tensor': np.array([beta2]).astype("float32"),
'EpsilonTensor': np.array([epsilon]).astype("float32"),
"SkipUpdate": np.array([True]).astype("bool"),
}
self.attrs = {'epsilon': epsilon}
self.outputs = {
'Moment1Out': moment1,
'Moment2Out': moment2,
'ParamOut': param,
'Beta1PowOut': self.inputs['Beta1Pow'],
'Beta2PowOut': self.inputs['Beta2Pow'],
}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestAdamOpWithGlobalBetaPow(OpTest):
......
......
@@ -501,6 +501,55 @@ class TestAdamOpWithGlobalBetaPow(OpTest):
self.check_output()
class TestAdamOpWithSkipUpdate(OpTest):
def setUp(self):
'''Test Adam Op with SkipUpdate
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
beta1 = 0.85
beta2 = 0.95
learning_rate = 0.001
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
"Beta1Tensor": np.array([beta1]).astype("float32"),
"Beta2Tensor": np.array([beta2]).astype("float32"),
"EpsilonTensor": np.array([epsilon]).astype("float32"),
"SkipUpdate": np.array([True]).astype("bool"),
}
attributes = {'epsilon': epsilon}
self.attrs = {'use_global_beta_pow': True}
# With SkipUpdate=True, the kernel returns early, so Beta1PowOut and Beta2PowOut are just copies of the inputs.
self.outputs = {
'Moment1Out': moment1,
'Moment2Out': moment2,
'ParamOut': param,
'Beta1PowOut': self.inputs['Beta1Pow'],
'Beta2PowOut': self.inputs['Beta2Pow'],
}
def test_check_output(self):
self.check_output()
class TestAdamOpV2(unittest.TestCase):
def test_adam_op(self):
place = fluid.CPUPlace()
......