diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 687de7adea8389443d0bec9d94d189af852eb9db..9e1ccc5f82752806201cc0676909b778d03dc5d4 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -25,6 +25,24 @@ OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName() +class Topology: + """A 4-D structure to describe the process group.""" + + def __init__(self, axes, dims): + pass + + +class ParallelGrid: + """Initialize each process group.""" + + def __init__(self, topology): + self.build_global_group() + self.build_mp_group() + self.build_sharding_group() + self.build_pp_group() + self.build_dp_group() + + def is_update_op(op): return 'Param' in op.input_names and 'Grad' in op.input_names and \ "LearningRate" in op.input_names