Unverified · Commit 8e4ed662 authored by tangwei12, committed by GitHub

fix decay global counter (#26387)

* fix decay global counter

* remove unused print, test=distp0
Parent ce7d5263
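Editor's summary (not part of the original commit message): before this change, the trainer asked for a global step counter whenever the program contained any learning-rate op. The fix adds a `_has_global_step` helper that returns True only when one of those ops is an `increment` op feeding the `@LR_DECAY_COUNTER@` variable, i.e. when a decay schedule actually needs the counter; it also removes leftover debug prints around the communicator setup.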
@@ -38,6 +38,7 @@
 from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
 from paddle.fluid.incubate.fleet.parameter_server import version
 from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
 from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops
+from paddle.fluid.incubate.fleet.parameter_server.ir.public import _has_global_step
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, DistributedStrategy, \
     SyncStrategy, AsyncStrategy, HalfAsyncStrategy, GeoStrategy, StrategyFactory
@@ -161,9 +162,9 @@ class FleetTranspiler(Fleet):
         print(trainer_config)

-        lrs = _get_lr_ops(self._origin_main_program)
+        lrs = _has_global_step(_get_lr_ops(self._origin_main_program))

-        if len(lrs) > 0:
+        if lrs > 0:
             kwargs = {"need_global_step": "1"}
         else:
             kwargs = {"need_global_step": "0"}
@@ -186,14 +187,6 @@ class FleetTranspiler(Fleet):
         recv_ctx = fleet.compiled_config.get_communicator_recv_context(
             recv_type=1)
-        for name, ctx in send_ctx.items():
-            print("name: {}, ctx: {}".format(name, ctx))
-
-        print("==== = ==== =============== ====")
-        for name, ctx in recv_ctx.items():
-            print("name: {}, ctx: {}".format(name, ctx))
-
         from paddle.fluid.communicator import Communicator
         self._communicator = Communicator(
             trainer_config.mode, kwargs,
@@ -43,6 +43,8 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin
 OP_NAME_SCOPE = "op_namescope"
 CLIP_OP_NAME_SCOPE = "@CLIP"
 STEP_COUNTER = "@PS_STEP_COUNTER@"
+LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
+
 OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
 RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
 RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
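For context, a hedged sketch (not part of the commit; the program and variable names here are illustrative): fluid's learning-rate decay schedules keep their step count in a variable named `@LR_DECAY_COUNTER@` and advance it once per step with an `increment` op, which is exactly the pattern `_has_global_step` looks for in the hunk below.

```python
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    x = fluid.data(name="x", shape=[None, 13], dtype="float32")
    y = fluid.data(name="y", shape=[None, 1], dtype="float32")
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))
    # The decay schedule creates an @LR_DECAY_COUNTER@ variable plus an
    # increment op that advances it once per step.
    sgd = fluid.optimizer.SGD(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=0.1, decay_steps=100, decay_rate=0.9))
    sgd.minimize(loss)

for op in main_program.global_block().ops:
    if op.type == "increment":
        print(op.input("X")[0])  # expected: @LR_DECAY_COUNTER@
```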
@@ -62,6 +64,17 @@ def _get_lr_ops(program):
     return lr_ops


+def _has_global_step(lr_ops):
+    if len(lr_ops) > 0:
+        for idx, op in enumerate(lr_ops):
+            if op.type != 'increment':
+                continue
+            counter = op.input("X")[0]
+            if counter == LEARNING_RATE_DECAY_COUNTER:
+                return True
+    return False
+
+
 def is_sparse_op(op):
     if op.type == "lookup_table" and op.attr('is_sparse') is True and op.attr(
             'is_distributed') is False:
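Putting the two helpers together, a minimal usage sketch (an editor's addition, not from the commit; `main_program` is assumed to be the decayed-SGD program from the earlier snippet):

```python
from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
    _get_lr_ops, _has_global_step)

lr_ops = _get_lr_ops(main_program)  # learning-rate schedule ops in the program
need_global_step = _has_global_step(lr_ops)
kwargs = {"need_global_step": "1" if need_global_step else "0"}
print(kwargs)  # {'need_global_step': '1'} for a program with an LR decay schedule
```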