diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc index c1b3d66bac4c703ce78b247aadc2975bb146b5b0..c35d7d49e31f6ca11e2b37a455af430aac50a232 100644 --- a/paddle/operators/fetch_op.cc +++ b/paddle/operators/fetch_op.cc @@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase { // FIXME(yuyang18): Should we assume the fetch operator always generate // CPU outputs? dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx); + dst_item.set_lod(src_item.lod()); VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; } diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 0fdc21ef5133d17b33860a0e095574d3136b2fd1..0f8c61a2ab5a9d046f80c5ead89e0e4e8f7422ce 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -333,20 +333,31 @@ class OpTest(unittest.TestCase): type(sub_out)) for sub_out_name, expect in sub_out: idx = find_actual(sub_out_name, fetch_list) - actual = outs[idx] + actual_t = np.array(outs[idx]) + expect_t = expect[0] \ + if isinstance(expect, tuple) else expect self.assertTrue( np.allclose( - actual, expect, atol=atol), + actual_t, expect_t, atol=atol), "Output (" + sub_out_name + ") has diff at " + str(place)) + if isinstance(expect, tuple): + self.assertListEqual( + actual_t.lod(), expect[1], "Output (" + sub_out_name + + ") has different lod at " + str(place)) else: idx = find_actual(out_name, fetch_list) - actual = outs[idx] + actual_t = outs[idx] expect = self.outputs[out_name] + expect_t = expect[0] if isinstance(expect, tuple) else expect self.assertTrue( np.allclose( - actual, expect, atol=atol), + actual_t, expect_t, atol=atol), "Output (" + out_name + ") has diff at " + str(place)) + if isinstance(expect, tuple): + self.assertListEqual(actual_t.lod(), expect[1], + "Output (" + out_name + + ") has different lod at " + str(place)) def check_output(self, atol=1e-5): places = [core.CPUPlace()] diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index bcce8d32c944a39e6d6aad4c99f8aa152222c3c1..93a4e450e916716e27573d192bace73f271733de 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -155,7 +155,11 @@ class TestLstmOp(OpTest): 'Weight': w, 'Bias': b } - self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} + self.outputs = { + 'Hidden': (h, self.lod), + 'Cell': (c, self.lod), + 'BatchGate': g_sort + } self.attrs = { 'usePeepholes': True, 'isReverse': self.is_reverse,