From d16cb8ca11628c32b697d81d1948937158d0f8e3 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Mon, 1 Apr 2019 22:04:15 +0800 Subject: [PATCH] Polish code --- paddle/fluid/imperative/layer.cc | 6 +++--- .../unittests/test_imperative_transformer.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 0310d0677b..e65ac865b2 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -123,13 +123,13 @@ class Autograd { ready_op->ApplyGrad(); for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) { - const std::vector& ingrads = it.second; + const std::vector& ingrads = it->second; for (int64_t i = ingrads.size() - 1; i >= 0; --i) { if (!ingrads[i]) continue; - if (ready_op->input_vars_[it.first][i]->IsStopGradient()) { + if (ready_op->input_vars_[it->first][i]->IsStopGradient()) { continue; } - OpBase* pre_op = ready_op->pre_ops_[it.first][i]; + OpBase* pre_op = ready_op->pre_ops_[it->first][i]; if (!pre_op) continue; dep_counts[pre_op] -= 1; diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py index 3bdf334973..32abb03e91 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py @@ -1076,20 +1076,17 @@ class TestDygraphTransformer(unittest.TestCase): 4]] = out[k] self.assertTrue( - np.allclose(static_avg_cost_value, dy_avg_cost._numpy())) + np.array_equal(static_avg_cost_value, dy_avg_cost._numpy())) self.assertTrue( - np.allclose(static_sum_cost_value, dy_sum_cost._numpy())) + np.array_equal(static_sum_cost_value, dy_sum_cost._numpy())) self.assertTrue( - np.allclose( - static_predict_value, dy_predict._numpy(), atol=1e-5)) + np.array_equal(static_predict_value, dy_predict._numpy())) self.assertTrue( - np.allclose(static_token_num_value, dy_token_num._numpy())) + np.array_equal(static_token_num_value, dy_token_num._numpy())) for key, value in six.iteritems(static_param_init): - self.assertTrue(np.allclose(value, dy_param_init[key])) + self.assertTrue(np.array_equal(value, dy_param_init[key])) for key, value in six.iteritems(static_param_updated): - self.assertTrue( - np.allclose( - value, dy_param_updated[key], atol=1e-4)) + self.assertTrue(np.array_equal(value, dy_param_updated[key])) if __name__ == '__main__': -- GitLab