未验证 提交 0f30d3a2 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix dygraph trace bug, test=develop (#21193)

上级 7269ffe3
......@@ -376,7 +376,7 @@ class TracedLayer(object):
if partial_vars is None:
return all_vars
return [all_vars[idx] for idx in feed]
return [all_vars[idx] for idx in partial_vars]
with scope_guard(self._scope):
feeded_var_names = get_feed_fetch(self._feed_names, feed)
......
......@@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
if i % 2 == 0:
outs, traced_layer = TracedLayer.trace(
transformer, [enc_inputs, dec_inputs, label, weights])
outs_static = traced_layer(enc_inputs + dec_inputs +
[label, weights])
ins_static = enc_inputs + dec_inputs + [label, weights]
outs_static = traced_layer(ins_static)
helper.assertEachVar(outs, outs_static)
if program is not None:
self.assertTrue(
......@@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_transformer')
'./infer_imperative_transformer',
feed=range(len(ins_static)),
fetch=range(len(outs_static)))
else:
outs = transformer(enc_inputs, dec_inputs, label, weights)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册