From 0f30d3a2130294574dcb672c270f88d90c2b488b Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 18 Nov 2019 10:22:24 +0800 Subject: [PATCH] fix dygraph trace bug, test=develop (#21193) --- python/paddle/fluid/dygraph/jit.py | 2 +- .../test_imperative_transformer_sorted_gradient.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 2d51556805..e3d56ad009 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -376,7 +376,7 @@ class TracedLayer(object): if partial_vars is None: return all_vars - return [all_vars[idx] for idx in feed] + return [all_vars[idx] for idx in partial_vars] with scope_guard(self._scope): feeded_var_names = get_feed_fetch(self._feed_names, feed) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py index 73adfa7301..c542235e15 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py @@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): if i % 2 == 0: outs, traced_layer = TracedLayer.trace( transformer, [enc_inputs, dec_inputs, label, weights]) - outs_static = traced_layer(enc_inputs + dec_inputs + - [label, weights]) + + ins_static = enc_inputs + dec_inputs + [label, weights] + outs_static = traced_layer(ins_static) helper.assertEachVar(outs, outs_static) if program is not None: self.assertTrue( @@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): program = traced_layer.program traced_layer.save_inference_model( - './infer_imperative_transformer') + './infer_imperative_transformer', + feed=range(len(ins_static)), + fetch=range(len(outs_static))) else: outs = transformer(enc_inputs, dec_inputs, label, weights) -- GitLab