diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2192139f8d5950286691a77333dd8ec35505b033..aed89c67e98d1a3dd718fb3a27b5e08ebccdcb29 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -1717,8 +1717,8 @@ to transpile() call.") lr_ops = [] block = self.origin_program.global_block() for op in block.ops: - if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int( - LR_SCHED_OP_ROLE_ATTR_VALUE): + if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) | int( + LR_SCHED_OP_ROLE_ATTR_VALUE) > 0: lr_ops.append(op) log("append lr op: ", op.type) return lr_ops