未验证 提交 0f30d3a2 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix dygraph trace bug, test=develop (#21193)

上级 7269ffe3
...@@ -376,7 +376,7 @@ class TracedLayer(object): ...@@ -376,7 +376,7 @@ class TracedLayer(object):
if partial_vars is None: if partial_vars is None:
return all_vars return all_vars
return [all_vars[idx] for idx in feed] return [all_vars[idx] for idx in partial_vars]
with scope_guard(self._scope): with scope_guard(self._scope):
feeded_var_names = get_feed_fetch(self._feed_names, feed) feeded_var_names = get_feed_fetch(self._feed_names, feed)
......
...@@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): ...@@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
if i % 2 == 0: if i % 2 == 0:
outs, traced_layer = TracedLayer.trace( outs, traced_layer = TracedLayer.trace(
transformer, [enc_inputs, dec_inputs, label, weights]) transformer, [enc_inputs, dec_inputs, label, weights])
outs_static = traced_layer(enc_inputs + dec_inputs +
[label, weights]) ins_static = enc_inputs + dec_inputs + [label, weights]
outs_static = traced_layer(ins_static)
helper.assertEachVar(outs, outs_static) helper.assertEachVar(outs, outs_static)
if program is not None: if program is not None:
self.assertTrue( self.assertTrue(
...@@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): ...@@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
program = traced_layer.program program = traced_layer.program
traced_layer.save_inference_model( traced_layer.save_inference_model(
'./infer_imperative_transformer') './infer_imperative_transformer',
feed=range(len(ins_static)),
fetch=range(len(outs_static)))
else: else:
outs = transformer(enc_inputs, dec_inputs, label, weights) outs = transformer(enc_inputs, dec_inputs, label, weights)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册