diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 2d51556805b6724cb7a2117a860dd2771c1f55b9..e3d56ad009f4b3dc6e03f4c1aba05924836a903d 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -376,7 +376,7 @@ class TracedLayer(object): if partial_vars is None: return all_vars - return [all_vars[idx] for idx in feed] + return [all_vars[idx] for idx in partial_vars] with scope_guard(self._scope): feeded_var_names = get_feed_fetch(self._feed_names, feed) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py index 73adfa73016d5e956997f095e91df34c2405ad27..c542235e15f0586f4ea6ecd04e99ce23c24105e4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer_sorted_gradient.py @@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): if i % 2 == 0: outs, traced_layer = TracedLayer.trace( transformer, [enc_inputs, dec_inputs, label, weights]) - outs_static = traced_layer(enc_inputs + dec_inputs + - [label, weights]) + + ins_static = enc_inputs + dec_inputs + [label, weights] + outs_static = traced_layer(ins_static) helper.assertEachVar(outs, outs_static) if program is not None: self.assertTrue( @@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase): program = traced_layer.program traced_layer.save_inference_model( - './infer_imperative_transformer') + './infer_imperative_transformer', + feed=range(len(ins_static)), + fetch=range(len(outs_static))) else: outs = transformer(enc_inputs, dec_inputs, label, weights)