From 2e783663fa52edd66d66adcebbe2e75ecb2e04d9 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 24 Oct 2017 18:56:56 +0800 Subject: [PATCH] Enable to output LoD in fetch_op and check output LoD in the op unit test. --- paddle/operators/fetch_op.cc | 1 + python/paddle/v2/framework/tests/op_test.py | 19 +++++++++++++++---- .../paddle/v2/framework/tests/test_lstm_op.py | 6 +++++- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index c1b3d66bac..c35d7d49e3 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx); + dst_item.set_lod(src_item.lod()); VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; } diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 0fdc21ef51..0f8c61a2ab 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -333,20 +333,31 @@ class OpTest(unittest.TestCase): type(sub_out)) for sub_out_name, expect in sub_out: idx = find_actual(sub_out_name, fetch_list) - actual = outs[idx] + actual_t = np.array(outs[idx]) + expect_t = expect[0] \ + if isinstance(expect, tuple) else expect self.assertTrue( np.allclose( - actual, expect, atol=atol), + actual_t, expect_t, atol=atol), "Output (" + sub_out_name + ") has diff at " + str(place)) + if isinstance(expect, tuple): + self.assertListEqual( + actual_t.lod(), expect[1], "Output (" + sub_out_name + + ") has different lod at " + str(place)) else: idx = find_actual(out_name, fetch_list) - actual = outs[idx] + actual_t = outs[idx] expect = self.outputs[out_name] + expect_t = expect[0] if isinstance(expect, tuple) else expect self.assertTrue( np.allclose( - actual, expect, atol=atol), + actual_t, expect_t, atol=atol), "Output (" + out_name + ") has diff at " + str(place)) + if isinstance(expect, tuple): + self.assertListEqual(actual_t.lod(), expect[1], + "Output (" + out_name + + ") has different lod at " + str(place)) def check_output(self, atol=1e-5): places = [core.CPUPlace()] diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index bcce8d32c9..93a4e450e9 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -155,7 +155,11 @@ class TestLstmOp(OpTest): 'Weight': w, 'Bias': b } - self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} + self.outputs = { + 'Hidden': (h, self.lod), + 'Cell': (c, self.lod), + 'BatchGate': g_sort + } self.attrs = { 'usePeepholes': True, 'isReverse': self.is_reverse, -- GitLab