diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
index d555e9876a08cecb7c32057a61a92e0cc9dacf7b..1a7a82fbfac19b41e8b96c231ca74398f6b2214c 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
@@ -38,6 +38,7 @@ from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
 from paddle.fluid.incubate.fleet.parameter_server import version
 from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
 from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops
+from paddle.fluid.incubate.fleet.parameter_server.ir.public import _has_global_step
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, DistributedStrategy, \
     SyncStrategy, AsyncStrategy, HalfAsyncStrategy, GeoStrategy, StrategyFactory
 
@@ -161,9 +162,9 @@ class FleetTranspiler(Fleet):
 
         print(trainer_config)
 
-        lrs = _get_lr_ops(self._origin_main_program)
+        lrs = _has_global_step(_get_lr_ops(self._origin_main_program))
 
-        if len(lrs) > 0:
+        if lrs > 0:
             kwargs = {"need_global_step": "1"}
         else:
             kwargs = {"need_global_step": "0"}
@@ -186,14 +187,6 @@ class FleetTranspiler(Fleet):
         recv_ctx = fleet.compiled_config.get_communicator_recv_context(
             recv_type=1)
 
-        for name, ctx in send_ctx.items():
-            print("name: {}, ctx: {}".format(name, ctx))
-
-        print("==== = ==== =============== ====")
-
-        for name, ctx in recv_ctx.items():
-            print("name: {}, ctx: {}".format(name, ctx))
-
         from paddle.fluid.communicator import Communicator
         self._communicator = Communicator(
             trainer_config.mode, kwargs,
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
index b96eff19e9b9c5d8e78b85e61b9a69afee106546..f9889997d9e38c98c4a736a62dbc72da7029f337 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
@@ -43,6 +43,8 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundR
 OP_NAME_SCOPE = "op_namescope"
 CLIP_OP_NAME_SCOPE = "@CLIP"
 STEP_COUNTER = "@PS_STEP_COUNTER@"
+LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
+
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
@@ -62,6 +64,17 @@ def _get_lr_ops(program):
     return lr_ops
 
 
+def _has_global_step(lr_ops):
+    if len(lr_ops) > 0:
+        for idx, op in enumerate(lr_ops):
+            if op.type != 'increment':
+                continue
+            counter = op.input("X")[0]
+            if counter == LEARNING_RATE_DECAY_COUNTER:
+                return True
+    return False
+
+
 def is_sparse_op(op):
     if op.type == "lookup_table" and op.attr('is_sparse') is True and op.attr(
             'is_distributed') is False:
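
For context, here is a minimal standalone sketch (not part of the patch) of what the new `_has_global_step` helper decides: whether the program's learning-rate ops increment the `@LR_DECAY_COUNTER@` variable, which is what the trainer side now uses to set `need_global_step`. The `FakeOp` class is a hypothetical stand-in for Paddle's operator objects, included only to make the example runnable.

```python
# Standalone sketch of the patched logic; FakeOp is a hypothetical
# stand-in for Paddle's operator objects (op.type / op.input("X")).
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"


class FakeOp:
    def __init__(self, type, inputs):
        self.type = type
        self._inputs = inputs

    def input(self, name):
        # Mirrors Paddle's Operator.input(), which returns the list of
        # input variable names bound to the given slot.
        return self._inputs[name]


def _has_global_step(lr_ops):
    # Same body as the helper added to ir/public.py: look for an
    # `increment` op whose input is the LR decay counter variable.
    if len(lr_ops) > 0:
        for idx, op in enumerate(lr_ops):
            if op.type != 'increment':
                continue
            counter = op.input("X")[0]
            if counter == LEARNING_RATE_DECAY_COUNTER:
                return True
    return False


# A decaying-LR program increments the decay counter, so the trainer
# sets need_global_step = "1"; a constant-LR program does not.
decay_lr_ops = [FakeOp('increment', {"X": [LEARNING_RATE_DECAY_COUNTER]})]
const_lr_ops = [FakeOp('scale', {"X": ["learning_rate_0"]})]
assert _has_global_step(decay_lr_ops) is True
assert _has_global_step(const_lr_ops) is False
```

One reviewer-style note on the caller: `if lrs > 0:` compares the returned boolean against an integer, which works because Python booleans are integers, but a bare `if lrs:` would state the intent more directly.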