diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index 5f2b6df493b8fe709e0c6a74b16f870873fedaef..8f1a4de86de0d9c4d053c0f3d203d174d3a63d4f 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -342,7 +342,7 @@ class DotPowParser(AscendParserBase): y = self._get_ge_input(self.op.input_arg_names[1]) pow = core.GEOperatorFactory.create_operator( "dotpow" + self._accumulated_op_id(), - "Pow").set_input("x1", x1).set_input("x2", y) + "Pow").set_input("x1", x).set_input("x2", y) return [pow], [[0]] @@ -918,15 +918,15 @@ class ScatterParser(AscendParserBase): scatter_value = core.GEOperatorFactory.create_operator( "scatter" + self._accumulated_op_id(), "TensorScatterAdd").set_input( - "x", x_var).set_input("indices", index_var).set_input( - "updates", updatesi_var) + "x", x).set_input("indices", index).set_input("updates", + updates) else: scatter_value = core.GEOperatorFactory.create_operator( "scatter" + self._accumulated_op_id(), "TensorScatterUpdate").set_input( - "x", x_var).set_input("indices", index_var).set_input( - "updates", updates_var) - return [x_var, index_var, updates_var, scatter_value], [[-1]] + "x", x).set_input("indices", index).set_input("updates", + updates) + return [x, index, updates, scatter_value], [[-1]] class CastParser(AscendParserBase):