提交 2e783663 编写于 作者: D dangqingqing

Enable to output LoD in fetch_op and check output LoD in the op unit test.

上级 fa72e544
...@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase { ...@@ -52,6 +52,7 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate // FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs? // CPU outputs?
dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx); dst_item.CopyFrom(src_item, platform::CPUPlace(), dev_ctx);
dst_item.set_lod(src_item.lod());
VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name; VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
} }
......
...@@ -333,20 +333,31 @@ class OpTest(unittest.TestCase): ...@@ -333,20 +333,31 @@ class OpTest(unittest.TestCase):
type(sub_out)) type(sub_out))
for sub_out_name, expect in sub_out: for sub_out_name, expect in sub_out:
idx = find_actual(sub_out_name, fetch_list) idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx] actual_t = np.array(outs[idx])
expect_t = expect[0] \
if isinstance(expect, tuple) else expect
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
actual, expect, atol=atol), actual_t, expect_t, atol=atol),
"Output (" + sub_out_name + ") has diff at " + "Output (" + sub_out_name + ") has diff at " +
str(place)) str(place))
if isinstance(expect, tuple):
self.assertListEqual(
actual_t.lod(), expect[1], "Output (" + sub_out_name
+ ") has different lod at " + str(place))
else: else:
idx = find_actual(out_name, fetch_list) idx = find_actual(out_name, fetch_list)
actual = outs[idx] actual_t = outs[idx]
expect = self.outputs[out_name] expect = self.outputs[out_name]
expect_t = expect[0] if isinstance(expect, tuple) else expect
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
actual, expect, atol=atol), actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place)) "Output (" + out_name + ") has diff at " + str(place))
if isinstance(expect, tuple):
self.assertListEqual(actual_t.lod(), expect[1],
"Output (" + out_name +
") has different lod at " + str(place))
def check_output(self, atol=1e-5): def check_output(self, atol=1e-5):
places = [core.CPUPlace()] places = [core.CPUPlace()]
......
...@@ -155,7 +155,11 @@ class TestLstmOp(OpTest): ...@@ -155,7 +155,11 @@ class TestLstmOp(OpTest):
'Weight': w, 'Weight': w,
'Bias': b 'Bias': b
} }
self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort} self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
'BatchGate': g_sort
}
self.attrs = { self.attrs = {
'usePeepholes': True, 'usePeepholes': True,
'isReverse': self.is_reverse, 'isReverse': self.is_reverse,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册