diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index aed89c67e98d1a3dd718fb3a27b5e08ebccdcb29..28d7df8e45d4f3209fb6205b76f8d1be8e73392b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( ) +OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched @@ -1717,8 +1718,10 @@ to transpile() call.") lr_ops = [] block = self.origin_program.global_block() for op in block.ops: - if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) | int( - LR_SCHED_OP_ROLE_ATTR_VALUE) > 0: + role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) + if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ + role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ + int(OPT_OP_ROLE_ATTR_VALUE): lr_ops.append(op) log("append lr op: ", op.type) return lr_ops