提交 1141db81 编写于 作者: Q Qiao Longfei

update test_adam_op

test=develop
上级 96604fda
......@@ -358,6 +358,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
VLOG(3) << "lazy_mode :" << lazy_mode;
if (lazy_mode) {
std::vector<int64_t> id_vector;
size_t row_count = grad_merge.rows().size();
......
......@@ -219,14 +219,25 @@ def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad,
moment2_out = np.zeros(shape=[height, row_numel])
param_out = np.zeros(shape=[height, row_numel])
for idx, row_id in enumerate(rows):
def update_row(row_id, update_value):
moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1
) * np_grad[idx]
) * update_value
moment2_out[row_id] = beta2 * moment2[row_id] + (
1 - beta2) * np.square(np_grad[idx])
1 - beta2) * np.square(update_value)
lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / (
np.sqrt(moment2_out[row_id]) + epsilon))
if lazy_mode:
for idx, row_id in enumerate(rows):
update_row(row_id, np_grad[idx])
else:
for row_id in range(param_out.shape[0]):
update_value = np.zeros(np_grad[0].shape).astype("float32")
if row_id in rows:
update_value = np_grad[rows.index(row_id)]
update_row(row_id, update_value)
return param_out, moment1_out, moment2_out
......@@ -249,6 +260,7 @@ class TestSparseAdamOp(unittest.TestCase):
'Beta2Pow': np.array([beta2**10]).astype("float32"),
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
grad_selected_rows = scope.var('Grad').get_selected_rows()
......@@ -286,7 +298,7 @@ class TestSparseAdamOp(unittest.TestCase):
op_args[s] = s
for s in self.outputs:
var = scope.var(s).get_tensor()
var.set(self.outputs[s], place)
var.set(self.init_output, place)
op_args[s] = s
for k in self.attrs:
op_args[k] = self.attrs[k]
......@@ -300,13 +312,9 @@ class TestSparseAdamOp(unittest.TestCase):
actual = np.array(out_var)
actual = actual.reshape([actual.size])
np_array = np_array.reshape([np_array.size])
for idx, row_id in enumerate(self.rows):
j = 0
while j < self.row_numel:
pos = row_id * self.row_numel + j
self.assertLess((actual[pos] - np_array[pos]) / actual[pos],
0.00001)
j += 1
for i in range(np_array.size):
self.assertLess((actual[i] - np_array[i]), 0.00001)
def test_sparse_adam(self):
places = [core.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册