diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 520c094798f2a0464dd11950bf657372fcfb73c8..4ba7a164c348c4d922aec4a5ecff909f3edac0cb 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -1531,8 +1531,16 @@ class OutScaleForTrainingPass(object): attrs=attrs, inputs=ins, outputs=outs) + + next_op_node = None + if len(in_node.outputs) > 0: + next_op_node = in_node.outputs[0] + graph.link_to(in_node, scale_op_node) graph.link_to(scale_op_node, scale_node) + if next_op_node: + graph.link_to(scale_node, next_op_node) + if not self._is_test: graph.link_to(state_in_node, scale_op_node) graph.link_to(accum_in_node, scale_op_node)