Unverified commit 119cda3d, authored by Leo Chen, committed by GitHub

[NPU] add input EpsilonTensor for adam (#32605)

* add input EpsilonTensor for adam

* update python api

* add unit test

* add npu test

* add more ut
Parent bc379ca3
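
For reference, a minimal usage sketch of the new behavior, mirroring the updated docstring and unit tests below (static-graph fluid API; the variable name "epsilon" is illustrative): when epsilon is passed as a 1-element float32 Variable instead of a Python float, the optimizer feeds it to the adam op through the new dispensable EpsilonTensor input rather than the epsilon attribute.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    # epsilon as a persistable 1-element float32 tensor; its value can be
    # changed at runtime (e.g. via fluid.layers.assign) without rebuilding the op.
    epsilon = fluid.layers.create_global_var(
        shape=[1], value=1e-8, dtype='float32', persistable=True, name="epsilon")

    adam = fluid.optimizer.Adam(learning_rate=0.01, epsilon=epsilon)
    # adam.minimize(loss) would then append the adam op with
    # inputs['EpsilonTensor'] = epsilon instead of attrs['epsilon'].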
...@@ -151,6 +151,11 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("EpsilonTensor",
"(Tensor<float32>, optional) If provided, Adam will use this "
"as epsilon, this has a higher priority than attr(epsilon), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable(); AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter"); AddOutput("ParamOut", "(Tensor) Output parameter");
...@@ -232,4 +237,13 @@ REGISTER_OP_VERSION(adam) ...@@ -232,4 +237,13 @@ REGISTER_OP_VERSION(adam)
paddle::framework::compatible::OpVersionDesc().NewAttr( paddle::framework::compatible::OpVersionDesc().NewAttr(
"multi_precision", "multi_precision",
"(bool) Whether to use multi-precision during weight updating.", "(bool) Whether to use multi-precision during weight updating.",
false)); false))
.AddCheckpoint(
R"ROC(
Upgrade adam, add 1 dispensable input [EpsilonTensor].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"EpsilonTensor",
"If provided, Adam will use this as epsilon, "
"this has a higher priority than attr(epsilon). "
"For better performance in npu kernel. "));
...@@ -154,7 +154,7 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
...@@ -188,6 +188,15 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
beta2_tensor->numel()));
beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
}
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel(); << "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel(); VLOG(3) << "param.numel(): " << param->numel();
......
...@@ -406,7 +406,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
...@@ -440,6 +440,15 @@ class AdamOpKernel : public framework::OpKernel<T> {
beta2_tensor->numel()));
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel(); << "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel(); VLOG(3) << "param.numel(): " << param->numel();
......
...@@ -80,24 +80,53 @@ class AdamNPUKernel : public framework::OpKernel<T> {
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
}
const Tensor* beta1_tensor = nullptr;
const Tensor* beta2_tensor = nullptr;
const Tensor* epsilon_tensor = nullptr;
Tensor beta1_tmp(framework::proto::VarType::FP32);
Tensor beta2_tmp(framework::proto::VarType::FP32);
Tensor epsilon_tmp(framework::proto::VarType::FP32);
if (ctx.HasInput("Beta1Tensor")) { if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor"); beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(Beta1Tensor) size must be 1, but get %d", "Input(Beta1Tensor) size must be 1, but get %d",
beta1_tensor->numel())); beta1_tensor->numel()));
beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor)); } else {
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
beta1_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta1_tmp, beta1);
beta1_tensor = &beta1_tmp;
} }
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) { if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor"); beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input(Beta2Tensor) size must be 1, but get %d", "Input(Beta2Tensor) size must be 1, but get %d",
beta2_tensor->numel())); beta2_tensor->numel()));
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor)); } else {
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
beta2_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta2_tmp, beta2);
beta2_tensor = &beta2_tmp;
} }
if (ctx.HasInput("EpsilonTensor")) {
epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
} else {
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
epsilon_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&epsilon_tmp, epsilon);
epsilon_tensor = &epsilon_tmp;
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel(); << "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel(); VLOG(3) << "param.numel(): " << param->numel();
...@@ -113,19 +142,6 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -113,19 +142,6 @@ class AdamNPUKernel : public framework::OpKernel<T> {
"beta2 pow output size should be 1, but received " "beta2 pow output size should be 1, but received "
"value is:%d.", "value is:%d.",
beta2_pow_out->numel())); beta2_pow_out->numel()));
// reshape
Tensor beta1_tensor(framework::proto::VarType::FP32);
beta1_tensor.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta1_tensor, beta1);
Tensor beta2_tensor(framework::proto::VarType::FP32);
beta2_tensor.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta2_tensor, beta2);
Tensor epsilon_tensor(framework::proto::VarType::FP32);
TensorFromVector(std::vector<T>{epsilon},
ctx.template device_context<platform::DeviceContext>(),
&epsilon_tensor);
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
...@@ -133,7 +149,7 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -133,7 +149,7 @@ class AdamNPUKernel : public framework::OpKernel<T> {
NpuOpRunner("ApplyAdamD", NpuOpRunner("ApplyAdamD",
{ {
*param, *mom1, *mom2, *beta1_pow, *beta2_pow, *lr, *param, *mom1, *mom2, *beta1_pow, *beta2_pow, *lr,
beta1_tensor, beta2_tensor, epsilon_tensor, *grad, *beta1_tensor, *beta2_tensor, *epsilon_tensor, *grad,
}, },
{ {
*param_out, *mom1_out, *mom2_out, *param_out, *mom1_out, *mom2_out,
...@@ -159,10 +175,10 @@ class AdamNPUKernel : public framework::OpKernel<T> { ...@@ -159,10 +175,10 @@ class AdamNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::DeviceContext>(), mom2_out); ctx.template device_context<platform::DeviceContext>(), mom2_out);
} }
auto runner_m1 = auto runner_m1 =
NpuOpRunner("Mul", {*beta1_pow, beta1_tensor}, {*beta1_pow_out}, {}); NpuOpRunner("Mul", {*beta1_pow, *beta1_tensor}, {*beta1_pow_out}, {});
runner_m1.Run(stream); runner_m1.Run(stream);
auto runner_m2 = auto runner_m2 =
NpuOpRunner("Mul", {*beta2_pow, beta2_tensor}, {*beta2_pow_out}, {}); NpuOpRunner("Mul", {*beta2_pow, *beta2_tensor}, {*beta2_pow_out}, {});
runner_m2.Run(stream); runner_m2.Run(stream);
} }
}; };
......
...@@ -35,8 +35,6 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
"Param", "Adam");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
...@@ -85,6 +83,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
}
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
"Grad", "Adam");
......
...@@ -1890,7 +1890,8 @@ class AdamOptimizer(Optimizer):
beta2 (float|Variable, optional): The exponential decay rate for the 2nd moment estimates.
It should be a float number or a Variable with shape [1] and data type as float32.
The default value is 0.999.
epsilon (float|Tensor, optional): A small float value for numerical stability.
It should be a float number or a Variable with shape [1] and data type as float32.
The default value is 1e-08.
parameter_list (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
...@@ -1959,7 +1960,7 @@ class AdamOptimizer(Optimizer):
avg_cost = fluid.layers.mean(cost)
# define beta decay variable
def get_decayed_betas(beta1_init, beta2_init, decay_steps, decay_rate, epsilon_init):
global_step = lr_scheduler._decay_step_counter()
beta1 = fluid.layers.create_global_var(
...@@ -1976,6 +1977,13 @@ class AdamOptimizer(Optimizer):
# set persistable for save checkpoints and resume
persistable=True,
name="beta2")
epsilon = fluid.layers.create_global_var(
shape=[1],
value=float(epsilon_init),
dtype='float32',
# set persistable for save checkpoints and resume
persistable=True,
name="epsilon")
div_res = global_step / decay_steps
decayed_beta1 = beta1_init * (decay_rate**div_res)
...@@ -1983,13 +1991,14 @@ class AdamOptimizer(Optimizer):
fluid.layers.assign(decayed_beta1, beta1)
fluid.layers.assign(decayed_beta2, beta2)
return beta1, beta2, epsilon
beta1, beta2, epsilon = get_decayed_betas(0.9, 0.99, 1e5, 0.9, 1e-8)
adam_optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=0.01,
beta1=beta1,
beta2=beta2,
epsilon=epsilon)
adam_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
...@@ -2099,7 +2108,6 @@ class AdamOptimizer(Optimizer):
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000
}
...@@ -2112,6 +2120,10 @@ class AdamOptimizer(Optimizer):
inputs['Beta2Tensor'] = self._beta2
else:
attrs['beta2'] = self._beta2
if isinstance(self._epsilon, Variable):
inputs['EpsilonTensor'] = self._epsilon
else:
attrs['epsilon'] = self._epsilon
adam_op = block.append_op(
type=self.type,
......
...@@ -27,7 +27,7 @@ SEED = 2021
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestAdam(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
...@@ -78,9 +78,61 @@ class TestSGD(OpTest):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestAdamWithEpsilonTensor(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.004
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
'Beta1Tensor': np.array([beta1]).astype("float32"),
'Beta2Tensor': np.array([beta2]).astype("float32"),
'EpsilonTensor': np.array([epsilon]).astype("float32"),
}
self.attrs = {'epsilon': epsilon}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
...@@ -140,9 +192,93 @@ class TestNet(unittest.TestCase):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-4))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-4))
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestNetWithEpsilonTensor(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
beta1_init = 0.9
beta2_init = 0.999
epsilon_init = 1e-8
beta1 = fluid.layers.create_global_var(
shape=[1],
value=float(beta1_init),
dtype='float32',
persistable=True,
name="beta1")
beta2 = fluid.layers.create_global_var(
shape=[1],
value=float(beta2_init),
dtype='float32',
persistable=True,
name="beta2")
epsilon = fluid.layers.create_global_var(
shape=[1],
value=float(epsilon_init),
dtype='float32',
persistable=True,
name="epsilon")
adam = fluid.optimizer.Adam(
learning_rate=0.01, beta1=beta1, beta2=beta2, epsilon=epsilon)
adam.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-4))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-4))
if __name__ == '__main__':
unittest.main()
...@@ -402,6 +402,54 @@ class TestAdamOpBetaVariable(OpTest):
self.check_output()
class TestAdamOpBetaEpsilonVariable(OpTest):
def setUp(self):
'''Test Adam Op with beta/epsilon as Variable
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
beta1 = 0.85
beta2 = 0.95
learning_rate = 0.001
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
"Beta1Tensor": np.array([beta1]).astype("float32"),
"Beta2Tensor": np.array([beta2]).astype("float32"),
"EpsilonTensor": np.array([epsilon]).astype("float32"),
}
attributes = {'epsilon': epsilon}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, attributes)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
self.check_output()
class TestAdamOpV2(unittest.TestCase):
def test_adam_op(self):
place = fluid.CPUPlace()
...@@ -531,5 +579,121 @@ class TestAdamOpV2(unittest.TestCase):
adam.step()
class TestNetWithEpsilonTensor(unittest.TestCase):
def _test(self, place, use_tensor=True, use_fluid_api=True):
paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
SEED = 2021
paddle.seed(SEED)
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
beta1_init = 0.9
beta2_init = 0.999
epsilon_init = 1e-8
if use_tensor:
beta1 = fluid.layers.create_global_var(
shape=[1],
value=float(beta1_init),
dtype='float32',
persistable=True,
name="beta1")
beta2 = fluid.layers.create_global_var(
shape=[1],
value=float(beta2_init),
dtype='float32',
persistable=True,
name="beta2")
epsilon = fluid.layers.create_global_var(
shape=[1],
value=float(epsilon_init),
dtype='float32',
persistable=True,
name="epsilon")
if use_fluid_api:
adam = fluid.optimizer.Adam(
learning_rate=0.01,
beta1=beta1,
beta2=beta2,
epsilon=epsilon)
else:
adam = paddle.optimizer.Adam(
learning_rate=0.01,
beta1=beta1,
beta2=beta2,
epsilon=epsilon)
else:
if use_fluid_api:
adam = fluid.optimizer.Adam(
learning_rate=0.01,
beta1=beta1_init,
beta2=beta2_init,
epsilon=epsilon_init)
else:
adam = paddle.optimizer.Adam(
learning_rate=0.01,
beta1=beta1_init,
beta2=beta2_init,
epsilon=epsilon_init)
adam.minimize(loss)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(10):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(epoch, pred_res[
0], loss_res))
paddle.disable_static()
return pred_res, loss_res
def _test_with_place(self, place):
preds = []
losses = []
for use_tensor in [True, False]:
for use_fluid_api in [True, False]:
pred, loss = self._test(place, use_tensor, use_fluid_api)
preds.append(pred)
losses.append(loss)
for pred in preds:
self.assertTrue(np.allclose(pred, preds[0]))
for loss in losses:
self.assertTrue(np.allclose(loss, losses[0]))
def test_adam_api(self):
# NOTE(zhiqiu): cpu and gpu have different seeds, so results should be compared separately.
self._test_with_place(paddle.CPUPlace())
if core.is_compiled_with_cuda():
self._test_with_place(paddle.CUDAPlace(0))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -58,7 +58,8 @@ class Adam(Optimizer):
beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
It should be a float number or a Tensor with shape [1] and data type as float32.
The default value is 0.999.
epsilon (float|Tensor, optional): A small float value for numerical stability.
It should be a float number or a Tensor with shape [1] and data type as float32.
The default value is 1e-08.
parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
...@@ -144,12 +145,18 @@ class Adam(Optimizer):
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
if not isinstance(beta1, Variable):
if not 0 <= beta1 < 1:
raise ValueError(
"Invalid value of beta1, expect beta1 in [0,1).")
if not isinstance(beta2, Variable):
if not 0 <= beta2 < 1:
raise ValueError(
"Invaild value of beta2, expect beta2 in [0,1).")
if not isinstance(epsilon, Variable):
if not 0 <= epsilon:
raise ValueError(
"Invaild value of epsilon, expect epsilon >= 0.")
super(Adam, self).__init__(
learning_rate=learning_rate,
parameters=parameters,
...@@ -295,7 +302,6 @@ class Adam(Optimizer):
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000,
"multi_precision": find_master
...@@ -309,6 +315,10 @@ class Adam(Optimizer):
inputs['Beta2Tensor'] = self._beta2
else:
attrs['beta2'] = self._beta2
if isinstance(self._epsilon, Variable):
inputs['EpsilonTensor'] = self._epsilon
else:
attrs['epsilon'] = self._epsilon
if find_master:
inputs["MasterParam"] = master_weight
......