Unverified commit 119cda3d authored by Leo Chen, committed by GitHub

[NPU] add input EpsilonTensor for adam (#32605)

* add input EpsilonTensor for adam

* update python api

* add unit test

* add npu test

* add more ut
Parent bc379ca3
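In short, after this change `epsilon` may be given to Adam either as a plain float attribute or as a float32 Tensor of shape [1], which is fed to the new dispensable `EpsilonTensor` op input and, like the existing `Beta1Tensor`/`Beta2Tensor` inputs, can be created and updated at run time; the NPU kernel then consumes it directly as a device tensor. Below is a minimal static-graph sketch of the intended usage, modeled on the unit tests in this commit; the tiny network and variable names are placeholders, not part of the change itself.

```python
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[32, 32], dtype='float32')
    fc = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.reduce_mean(fc)

    # epsilon lives in a persistable [1] float32 tensor; because it is a
    # Variable, the optimizer wires it to the adam op's EpsilonTensor input
    # instead of setting the epsilon attribute.
    epsilon = fluid.layers.create_global_var(
        shape=[1],
        value=1e-8,
        dtype='float32',
        persistable=True,
        name="epsilon")

    adam = fluid.optimizer.Adam(learning_rate=0.01, epsilon=epsilon)
    adam.minimize(loss)
```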
......@@ -151,6 +151,11 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("EpsilonTensor",
"(Tensor<float32>, optional) If provided, Adam will use this "
"as epsilon, this has a higher priority than attr(epsilon), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter");
......@@ -232,4 +237,13 @@ REGISTER_OP_VERSION(adam)
paddle::framework::compatible::OpVersionDesc().NewAttr(
"multi_precision",
"(bool) Whether to use multi-precision during weight updating.",
false));
false))
.AddCheckpoint(
R"ROC(
Upgrade adam, add 1 dispensable input [EpsilonTensor].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"EpsilonTensor",
"If provided, Adam will use this as epsilon, "
"this has a higher priority than attr(epsilon). "
"For better performance in npu kernel. "));
......@@ -154,7 +154,7 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
......@@ -188,6 +188,15 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
beta2_tensor->numel()));
beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
}
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
epsilon = static_cast<MPDType>(GetAttrFromTensor(epsilon_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
......
......@@ -406,7 +406,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
......@@ -440,6 +440,15 @@ class AdamOpKernel : public framework::OpKernel<T> {
beta2_tensor->numel()));
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
......
......@@ -80,24 +80,53 @@ class AdamNPUKernel : public framework::OpKernel<T> {
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
}
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
const Tensor* beta1_tensor = nullptr;
const Tensor* beta2_tensor = nullptr;
const Tensor* epsilon_tensor = nullptr;
Tensor beta1_tmp(framework::proto::VarType::FP32);
Tensor beta2_tmp(framework::proto::VarType::FP32);
Tensor epsilon_tmp(framework::proto::VarType::FP32);
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(Beta1Tensor) size must be 1, but get %d",
beta1_tensor->numel()));
beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
} else {
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
beta1_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta1_tmp, beta1);
beta1_tensor = &beta1_tmp;
}
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(Beta2Tensor) size must be 1, but get %d",
beta2_tensor->numel()));
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
} else {
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
beta2_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta2_tmp, beta2);
beta2_tensor = &beta2_tmp;
}
if (ctx.HasInput("EpsilonTensor")) {
epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
PADDLE_ENFORCE_EQ(epsilon_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(EpsilonTensor) size must be 1, but get %d",
epsilon_tensor->numel()));
} else {
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
epsilon_tmp.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&epsilon_tmp, epsilon);
epsilon_tensor = &epsilon_tmp;
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
VLOG(3) << "param.numel(): " << param->numel();
......@@ -113,19 +142,6 @@ class AdamNPUKernel : public framework::OpKernel<T> {
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
// reshape
Tensor beta1_tensor(framework::proto::VarType::FP32);
beta1_tensor.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta1_tensor, beta1);
Tensor beta2_tensor(framework::proto::VarType::FP32);
beta2_tensor.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&beta2_tensor, beta2);
Tensor epsilon_tensor(framework::proto::VarType::FP32);
TensorFromVector(std::vector<T>{epsilon},
ctx.template device_context<platform::DeviceContext>(),
&epsilon_tensor);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......@@ -133,7 +149,7 @@ class AdamNPUKernel : public framework::OpKernel<T> {
NpuOpRunner("ApplyAdamD",
{
*param, *mom1, *mom2, *beta1_pow, *beta2_pow, *lr,
beta1_tensor, beta2_tensor, epsilon_tensor, *grad,
*beta1_tensor, *beta2_tensor, *epsilon_tensor, *grad,
},
{
*param_out, *mom1_out, *mom2_out,
......@@ -159,10 +175,10 @@ class AdamNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::DeviceContext>(), mom2_out);
}
auto runner_m1 =
NpuOpRunner("Mul", {*beta1_pow, beta1_tensor}, {*beta1_pow_out}, {});
NpuOpRunner("Mul", {*beta1_pow, *beta1_tensor}, {*beta1_pow_out}, {});
runner_m1.Run(stream);
auto runner_m2 =
NpuOpRunner("Mul", {*beta2_pow, beta2_tensor}, {*beta2_pow_out}, {});
NpuOpRunner("Mul", {*beta2_pow, *beta2_tensor}, {*beta2_pow_out}, {});
runner_m2.Run(stream);
}
};
......
......@@ -35,8 +35,6 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
"Param", "Adam");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
......@@ -85,6 +83,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
}
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
if (ctx.HasInput("EpsilonTensor")) {
auto* epsilon_tensor = ctx.Input<framework::Tensor>("EpsilonTensor");
epsilon = static_cast<T>(GetAttrFromTensor(epsilon_tensor));
}
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
"Grad", "Adam");
......
......@@ -1890,7 +1890,8 @@ class AdamOptimizer(Optimizer):
beta2 (float|Variable, optional): The exponential decay rate for the 2nd moment estimates.
It should be a float number or a Variable with shape [1] and data type as float32.
The default value is 0.999.
epsilon (float, optional): A small float value for numerical stability.
epsilon (float|Tensor, optional): A small float value for numerical stability.
It should be a float number or a Variable with shape [1] and data type as float32.
The default value is 1e-08.
parameter_list (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
......@@ -1959,7 +1960,7 @@ class AdamOptimizer(Optimizer):
avg_cost = fluid.layers.mean(cost)
# define beta decay variable
def get_decayed_betas(beta1_init, beta2_init, decay_steps, decay_rate):
def get_decayed_betas(beta1_init, beta2_init, decay_steps, decay_rate, epsilon_init):
global_step = lr_scheduler._decay_step_counter()
beta1 = fluid.layers.create_global_var(
......@@ -1976,6 +1977,13 @@ class AdamOptimizer(Optimizer):
# set persistable for save checkpoints and resume
persistable=True,
name="beta2")
epsilon = fluid.layers.create_global_var(
shape=[1],
value=float(epsilon_init),
dtype='float32',
# set persistable for save checkpoints and resume
persistable=True,
name="epsilon")
div_res = global_step / decay_steps
decayed_beta1 = beta1_init * (decay_rate**div_res)
......@@ -1983,13 +1991,14 @@ class AdamOptimizer(Optimizer):
fluid.layers.assign(decayed_beta1, beta1)
fluid.layers.assign(decayed_beta2, beta2)
return beta1, beta2
return beta1, beta2, epsilon
beta1, beta2 = get_decayed_betas(0.9, 0.99, 1e5, 0.9)
beta1, beta2, epsilon = get_decayed_betas(0.9, 0.99, 1e5, 0.9, 1e-8)
adam_optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=0.01,
beta1=beta1,
beta2=beta2)
beta2=beta2,
epsilon=epsilon)
adam_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
......@@ -2099,7 +2108,6 @@ class AdamOptimizer(Optimizer):
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"epsilon": self._epsilon,
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000
}
......@@ -2112,6 +2120,10 @@ class AdamOptimizer(Optimizer):
inputs['Beta2Tensor'] = self._beta2
else:
attrs['beta2'] = self._beta2
if isinstance(self._epsilon, Variable):
inputs['EpsilonTensor'] = self._epsilon
else:
attrs['epsilon'] = self._epsilon
adam_op = block.append_op(
type=self.type,
......
......@@ -27,7 +27,7 @@ SEED = 2021
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestSGD(OpTest):
class TestAdam(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
......@@ -78,9 +78,61 @@ class TestSGD(OpTest):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)
'''
# TODO(zhiqiu): The following test may bring cards 0-3 down.
# We need to analyze it before re-enabling it.
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestAdamWithEpsilonTensor(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
learning_rate = 0.004
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
'Beta1Tensor': np.array([beta1]).astype("float32"),
'Beta2Tensor': np.array([beta2]).astype("float32"),
'EpsilonTensor': np.array([epsilon]).astype("float32"),
}
self.attrs = {'epsilon': epsilon}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
......@@ -140,9 +192,93 @@ class TestNet(unittest.TestCase):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred))
self.assertTrue(np.allclose(npu_loss, cpu_loss))
'''
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-4))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-4))
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestNetWithEpsilonTensor(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
beta1_init = 0.9
beta2_init = 0.999
epsilon_init = 1e-8
beta1 = fluid.layers.create_global_var(
shape=[1],
value=float(beta1_init),
dtype='float32',
persistable=True,
name="beta1")
beta2 = fluid.layers.create_global_var(
shape=[1],
value=float(beta2_init),
dtype='float32',
persistable=True,
name="beta2")
epsilon = fluid.layers.create_global_var(
shape=[1],
value=float(epsilon_init),
dtype='float32',
persistable=True,
name="epsilon")
adam = fluid.optimizer.Adam(
learning_rate=0.01, beta1=beta1, beta2=beta2, epsilon=epsilon)
adam.minimize(loss)
if run_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-4))
self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-4))
if __name__ == '__main__':
unittest.main()
......@@ -402,6 +402,54 @@ class TestAdamOpBetaVariable(OpTest):
self.check_output()
class TestAdamOpBetaEpsilonVariable(OpTest):
def setUp(self):
'''Test Adam Op with beta/epsilon as Variable
'''
self.op_type = "adam"
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
# The second moment is positive
moment2 = np.random.random((102, 105)).astype("float32")
beta1 = 0.85
beta2 = 0.95
learning_rate = 0.001
epsilon = 1e-8
beta1_pow = beta1**10
beta2_pow = beta2**10
self.inputs = {
'Param': param,
'Grad': grad,
'Moment1': moment1,
'Moment2': moment2,
'LearningRate': np.array([learning_rate]).astype("float32"),
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
'Beta2Pow': np.array([beta2_pow]).astype("float32"),
"Beta1Tensor": np.array([beta1]).astype("float32"),
"Beta2Tensor": np.array([beta2]).astype("float32"),
"EpsilonTensor": np.array([epsilon]).astype("float32"),
}
attributes = {'epsilon': epsilon}
param_out, moment1_out, \
moment2_out = adam_step(self.inputs, attributes)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
}
def test_check_output(self):
self.check_output()
class TestAdamOpV2(unittest.TestCase):
def test_adam_op(self):
place = fluid.CPUPlace()
......@@ -531,5 +579,121 @@ class TestAdamOpV2(unittest.TestCase):
adam.step()
class TestNetWithEpsilonTensor(unittest.TestCase):
def _test(self, place, use_tensor=True, use_fluid_api=True):
paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
SEED = 2021
paddle.seed(SEED)
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(
name="label", shape=[32, 1], dtype='int64')
sum = paddle.add(a, b)
z = paddle.pow(sum, 2.0)
fc_1 = fluid.layers.fc(input=z, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
beta1_init = 0.9
beta2_init = 0.999
epsilon_init = 1e-8
if use_tensor:
beta1 = fluid.layers.create_global_var(
shape=[1],
value=float(beta1_init),
dtype='float32',
persistable=True,
name="beta1")
beta2 = fluid.layers.create_global_var(
shape=[1],
value=float(beta2_init),
dtype='float32',
persistable=True,
name="beta2")
epsilon = fluid.layers.create_global_var(
shape=[1],
value=float(epsilon_init),
dtype='float32',
persistable=True,
name="epsilon")
if use_fluid_api:
adam = fluid.optimizer.Adam(
learning_rate=0.01,
beta1=beta1,
beta2=beta2,
epsilon=epsilon)
else:
adam = paddle.optimizer.Adam(
learning_rate=0.01,
beta1=beta1,
beta2=beta2,
epsilon=epsilon)
else:
if use_fluid_api:
adam = fluid.optimizer.Adam(
learning_rate=0.01,
beta1=beta1_init,
beta2=beta2_init,
epsilon=epsilon_init)
else:
adam = paddle.optimizer.Adam(
learning_rate=0.01,
beta1=beta1_init,
beta2=beta2_init,
epsilon=epsilon_init)
adam.minimize(loss)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(10):
pred_res, loss_res = exe.run(
main_prog,
feed={"a": a_np,
"b": b_np,
"label": label_np},
fetch_list=[prediction, loss])
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(epoch, pred_res[
0], loss_res))
paddle.disable_static()
return pred_res, loss_res
def _test_with_place(self, place):
preds = []
losses = []
for use_tensor in [True, False]:
for use_fluid_api in [True, False]:
pred, loss = self._test(place, use_tensor, use_fluid_api)
preds.append(pred)
losses.append(loss)
for pred in preds:
self.assertTrue(np.allclose(pred, preds[0]))
for loss in losses:
self.assertTrue(np.allclose(loss, losses[0]))
def test_adam_api(self):
# NOTE(zhiqiu): cpu and gpu have different seeds, so they should be compared separately.
self._test_with_place(paddle.CPUPlace())
if core.is_compiled_with_cuda():
self._test_with_place(paddle.CUDAPlace(0))
if __name__ == "__main__":
unittest.main()
......@@ -58,7 +58,8 @@ class Adam(Optimizer):
beta2 (float|Tensor, optional): The exponential decay rate for the 2nd moment estimates.
It should be a float number or a Tensor with shape [1] and data type as float32.
The default value is 0.999.
epsilon (float, optional): A small float value for numerical stability.
epsilon (float|Tensor, optional): A small float value for numerical stability.
It should be a float number or a Tensor with shape [1] and data type as float32.
The default value is 1e-08.
parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
This parameter is required in dygraph mode. \
......@@ -144,12 +145,18 @@ class Adam(Optimizer):
assert beta1 is not None
assert beta2 is not None
assert epsilon is not None
if not 0 <= beta1 < 1:
raise ValueError("Invaild value of beta1, expect beta1 in [0,1).")
if not 0 <= beta2 < 1:
raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
if not 0 <= epsilon:
raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
if not isinstance(beta1, Variable):
if not 0 <= beta1 < 1:
raise ValueError(
"Invaild value of beta1, expect beta1 in [0,1).")
if not isinstance(beta2, Variable):
if not 0 <= beta2 < 1:
raise ValueError(
"Invaild value of beta2, expect beta2 in [0,1).")
if not isinstance(epsilon, Variable):
if not 0 <= epsilon:
raise ValueError(
"Invaild value of epsilon, expect epsilon >= 0.")
super(Adam, self).__init__(
learning_rate=learning_rate,
parameters=parameters,
......@@ -295,7 +302,6 @@ class Adam(Optimizer):
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"epsilon": self._epsilon,
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000,
"multi_precision": find_master
......@@ -309,6 +315,10 @@ class Adam(Optimizer):
inputs['Beta2Tensor'] = self._beta2
else:
attrs['beta2'] = self._beta2
if isinstance(self._epsilon, Variable):
inputs['EpsilonTensor'] = self._epsilon
else:
attrs['epsilon'] = self._epsilon
if find_master:
inputs["MasterParam"] = master_weight
......
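Note that the relaxed `__init__` checks above only range-validate plain Python floats; a Tensor-valued `beta1`/`beta2`/`epsilon` is accepted as-is and routed to the corresponding `*Tensor` op input by `_append_optimize_op`. A rough end-to-end sketch with the 2.0-style `paddle.optimizer.Adam`, mirroring the `use_tensor=True, use_fluid_api=False` branch of the unit test above (the toy network, names, and shapes are placeholders):

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[8, 4], dtype='float32')
    out = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.reduce_mean(out)

    # A [1] float32 tensor holding epsilon; Adam accepts it and emits an
    # adam op whose EpsilonTensor input points at this variable.
    epsilon = fluid.layers.create_global_var(
        shape=[1],
        value=1e-8,
        dtype='float32',
        persistable=True,
        name="eps")
    adam = paddle.optimizer.Adam(learning_rate=0.01, epsilon=epsilon)
    adam.minimize(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
loss_val, = exe.run(main,
                    feed={"x": np.random.rand(8, 4).astype('float32')},
                    fetch_list=[loss])
```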