Unverified commit 797bd40d, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Generalization for Partition and Completion (#35735)

* default dist op

* add dist_attr for dist op

* add unittest

* update input name

* update function name

* add unittest

* update CMakeLists.txt for CI

* fix dist_matmul

* fix compile error

* update matmul to matmul_v2

* unify api

* unify api

* todo

* update distop forward func

* update distop forward func

* auto parallel backward

* update dist op

* autoparallel backward

* add backward for embedding

* temp1

* temp2

* temp3

* temp4

* backward done1

* backward done2

* backward done3

* dist embedding remove mp mode

* dist matmul remove mp mode

* update dist embedding

* dist op init1

* dist op init 2

* update unittest

* context remove parallel mode

* partitioner remove parallel mode

* update unittest

* a more general method to support varying mesh in pipeline parallel

* support varying mesh in pipeline parallel

* embedding support varying mesh in pipeline parallel

* matmul support varying mesh in pipeline parallel

* default dist op support varying mesh in pipeline parallel

* dist attribute for startup program

* default dist op support varying mesh in pipeline parallel 2

* partitioner support varying mesh in pipeline parallel

* revise logic for auto completion

* revise framework.py

* revise reshard unittest

* revise unittest for parallelize

* chmod

* fixed bug for dist embedding name mapping
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Parent 127488ba
@@ -24,6 +24,7 @@ from .utils import print_program_with_distributed_attr
 from .context import get_default_distributed_context
 from .operators import find_best_compatible_distributed_operator_impl
 from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute
+from paddle.distributed.fleet.meta_optimizers.common import OpRole

 ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
@@ -600,7 +601,7 @@ def complete_annotation(program, dist_context=None):
     return program

-def complete_backward_annotation(auto_parallel_main_prog, dist_context):
+def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
     """Complete the annotation of vars and ops in the backward phase for parallel program."""

     def _is_grad_var_name(name):
@@ -608,24 +609,44 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context):
             return True
         return False

-    grad_start_idx = None
+    def _get_forward_varname_from_grad_varname(grad_var_name):
+        assert _is_grad_var_name(
+            grad_var_name), "[{}] is not a grad varnme.".format(grad_var_name)
+        return grad_var_name[:grad_var_name.find("@GRAD")]
+
+    def _get_op_by_id(ops, id):
+        for op in ops:
+            if op.desc.id() == id:
+                return op
+        return None
+
+    if dist_context is None:
+        dist_context = get_default_distributed_context()
+
+    grad_start_idx = -1
     for idx, op in enumerate(auto_parallel_main_prog.global_block().ops):
-        for var_name in op.output_arg_names:
-            # TODO: use _is_loss_op to judge
-            if "@GRAD" in var_name and op.type == "fill_constant":
-                grad_start_idx = idx
-                break
-    assert grad_start_idx is not None, "No backward procedure found in this program."
+        if int(op.attr('op_role')) == int(
+                int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
+                    core.op_proto_and_checker_maker.OpRole.Loss)):
+            assert op.type == "fill_constant"
+            grad_start_idx = idx
+            break
+
+    assert grad_start_idx >= 0, "No backward procedure found in this program."

     ops = list(auto_parallel_main_prog.global_block().ops)
     vars = auto_parallel_main_prog.global_block().vars
     for idx in range(grad_start_idx, len(ops)):

-        # complete the loss op
+        # complete the initial grad loss op
         if idx == grad_start_idx:
             grad_var = vars[ops[idx].output_arg_names[0]]
-            grad_var_name = grad_var.name
-            forward_var_name = grad_var_name[:grad_var_name.find("@GRAD")]
+            forward_var_name = _get_forward_varname_from_grad_varname(
                grad_var.name)
             forward_var = vars[forward_var_name]
+
+            # TODO complete other attribte for grad var
             tensor_attr = TensorDistributedAttribute(grad_var, dist_context)
             process_mesh = dist_context.get_tensor_distributed_attr_for_program(
                 forward_var).get_process_mesh()
@@ -635,39 +656,31 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context):
             tensor_attr.set_process_mesh(process_mesh)
             dist_context.set_tensor_distributed_attr_for_program(grad_var,
                                                                  tensor_attr)
             op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
             op_attr.set_process_mesh(process_mesh)
             dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
-
-            # in the data parallel mode, the loss op followed by scale op.
-            if ops[idx + 1].type == "scale" and grad_var_name in ops[idx + 1].input_arg_names \
-                and grad_var_name in ops[idx + 1].output_arg_names:
-                op_attr = OperatorDistributedAttribute(ops[idx + 1],
-                                                       dist_context)
-                op_attr.set_process_mesh(process_mesh)
-                dist_context.set_op_distributed_attr_for_program(ops[idx + 1],
-                                                                 op_attr)
             continue

-        # complete the annotation of the optimizer op.
-        # TODO: use _is_optimizer_op to judge
-        if "Grad" in ops[idx].input_names and "Param" in ops[idx].input_names:
-            assert len(ops[idx].input(
-                "Param")) == 1, "Only support one-to-one now."
-            assert len(ops[idx].input(
-                "Grad")) == 1, "Only support one-to-one now."
-            var = vars[ops[idx].input("Param")[0]]
-            grad_var = vars[ops[idx].input("Grad")[0]]
+        # TODO remove this when dist op handle its own grad scale
+        # in the data parallel mode, the loss op followed by scale op.
+        if ops[idx].type == "scale" and idx == grad_start_idx + 1:
+            assert grad_var.name in ops[
+                idx].input_arg_names and grad_var.name in ops[
+                    idx].output_arg_names
+            grad_var = vars[ops[idx].output_arg_names[0]]
+            forward_var_name = _get_forward_varname_from_grad_varname(
+                grad_var.name)
+            forward_var = vars[forward_var_name]
             process_mesh = dist_context.get_tensor_distributed_attr_for_program(
-                var).get_process_mesh()
-            dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
-                var).get_dims_mapping()
+                forward_var).get_process_mesh()
             op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
             op_attr.set_process_mesh(process_mesh)
-            op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
             dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
             continue

+        # TODO remove this when dist op handle its own communication
+        # TODO should distinguish the dp allreduce and mp allreduce
         # complete the c_allreduce_sum op for gradient in the data parallel mode.
         if ops[idx].type == "c_allreduce_sum" and ops[
                 idx].input_arg_names == ops[idx].output_arg_names:
@@ -679,91 +692,123 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context):
             dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
             continue
-        # complete the annotation of grad op
-        grad_op = ops[idx]
-        for i, op in enumerate(ops[:grad_start_idx]):
-            match_op = None
-            grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(op.desc,
-                                                                      set(),
-                                                                      [])
-            grad_op_input = []
-            for input_arg_name in grad_op.desc.input_arg_names():
-                if "@GRAD" in input_arg_name:
-                    name = input_arg_name[:input_arg_name.find("@GRAD") + 5]
-                    grad_op_input.append(name)
-                else:
-                    grad_op_input.append(input_arg_name)
-
-            # like sum op: the count of grad op will larger than 1
-            if len(grad_op_desc_list) > 1:
-                for grad_op_desc in grad_op_desc_list:
-                    if grad_op_input == grad_op_desc.input_arg_names() \
-                        and grad_op.desc.type() == grad_op_desc.type():
-                        match_op = op
-                        break
-            elif len(grad_op_desc_list) == 1:
-                if grad_op_input == grad_op_desc_list[0].input_arg_names() \
-                    and grad_op.desc.type() == grad_op_desc_list[0].type():
-                    match_op = op
-
-            if match_op is not None:
-                op_attr = dist_context.get_op_distributed_attr_for_program(op)
-                grad_op_attr = OperatorDistributedAttribute(grad_op,
-                                                            dist_context)
-                grad_op_attr.set_process_mesh(op_attr.get_process_mesh())
-                for var_name in grad_op.input_arg_names:
-                    if "@GRAD" in var_name:
-                        dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
-                            vars[var_name]).get_dims_mapping()
-                        grad_op_attr.set_input_dims_mapping(var_name,
-                                                            dims_mapping)
-                    else:
-                        dims_mapping = op_attr.get_input_dims_mapping(var_name)
-                        grad_op_attr.set_input_dims_mapping(var_name,
-                                                            dims_mapping)
-                dist_context.set_op_distributed_attr_for_program(grad_op,
-                                                                 grad_op_attr)
-
-                for var_name in grad_op.output_arg_names:
-                    if "@GRAD" in var_name:
-                        forward_var = vars[var_name[:var_name.find("@GRAD")]]
-                        tensor_attr = TensorDistributedAttribute(vars[var_name],
-                                                                 dist_context)
-                        process_mesh = grad_op_attr.get_process_mesh()
-                        dims_mapping = grad_op_attr.get_input_dims_mapping(
-                            forward_var.name)
-                        tensor_attr.set_process_mesh(process_mesh)
-                        tensor_attr.set_dims_mapping(dims_mapping)
-                        dist_context.set_tensor_distributed_attr_for_program(
-                            vars[var_name], tensor_attr)
-                break
-
-        # complete the annotation of sum op for multiple renamed grad var
-        if grad_op.type == "sum" and all(
-                map(_is_grad_var_name, grad_op.input_arg_names)):
-            assert len(grad_op.output_arg_names
-                       ) == 1, "The output count of sum op should be one."
-            grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
-            for var_name in grad_op.input_arg_names:
-                if "@GRAD" in var_name:
-                    forward_var = vars[var_name[:var_name.find("@GRAD")]]
-                    dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
-                        forward_var).get_dims_mapping()
-                    grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
-            for var_name in grad_op.output_arg_names:
-                forward_var = vars[var_name[:var_name.find("@GRAD")]]
-                tensor_attr = TensorDistributedAttribute(vars[var_name],
-                                                         dist_context)
-                process_mesh = dist_context.get_tensor_distributed_attr_for_program(
-                    forward_var).get_process_mesh()
-                dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
-                    forward_var).get_dims_mapping()
-                tensor_attr.set_dims_mapping(dims_mapping)
-                tensor_attr.set_process_mesh(process_mesh)
-                dist_context.set_tensor_distributed_attr_for_program(
-                    vars[var_name], tensor_attr)
-                grad_op_attr.set_process_mesh(
-                    dist_context.get_tensor_distributed_attr_for_program(
-                        forward_var).get_process_mesh())
-            dist_context.set_op_distributed_attr_for_program(grad_op,
-                                                             grad_op_attr)
+        # complete the annotation of grad op (xxx_grad op or sum op)
+        grad_op = ops[idx]
+
+        # xxx_grad op will have a corresponding forward op in gradopidx2opidx
+        dist_op_helper = dist_context.get_dist_op_helper()
+        if grad_op.desc.id() in dist_op_helper.gradopidx2opidx:
+            # TODO support the case where one forward op corresponding to multiple xxx_grad op
+            forward_op = _get_op_by_id(
+                ops[:grad_start_idx],
+                dist_op_helper.gradopidx2opidx[grad_op.desc.id()])
+            assert forward_op is not None
+
+            # op dist attr
+            forward_op_attr = dist_context.get_op_distributed_attr_for_program(
+                forward_op)
+            grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
+            grad_op_attr.set_process_mesh(forward_op_attr.get_process_mesh())
+            for var_name in grad_op.input_arg_names:
+                if "@GRAD" in var_name:
+                    dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
+                        vars[var_name]).get_dims_mapping()
+                    grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
+                else:
+                    dims_mapping = forward_op_attr.get_input_dims_mapping(
+                        var_name)
+                    # TODO fixed here
+                    if dims_mapping == None:
+                        dims_mapping = forward_op_attr.get_output_dims_mapping(
+                            var_name)
+                    assert dims_mapping is not None, "[{}]'s dims_mapping is None".format(
+                        var_name)
+                    grad_op_attr.set_input_dims_mapping(var_name, dims_mapping)
+            dist_context.set_op_distributed_attr_for_program(grad_op,
+                                                             grad_op_attr)
+
+            # var dist attr
+            for var_name in grad_op.output_arg_names:
+                if _is_grad_var_name(var_name):
+
+                    forward_var_name = _get_forward_varname_from_grad_varname(
+                        var_name)
+                    forward_var = vars[forward_var_name]
+                    tensor_attr = TensorDistributedAttribute(vars[var_name],
+                                                             dist_context)
+                    process_mesh = grad_op_attr.get_process_mesh()
+                    dims_mapping = grad_op_attr.get_input_dims_mapping(
+                        forward_var_name)
+                    tensor_attr.set_process_mesh(process_mesh)
+                    tensor_attr.set_dims_mapping(dims_mapping)
+                    dist_context.set_tensor_distributed_attr_for_program(
+                        vars[var_name], tensor_attr)
+
+        # only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx
+        else:
+            assert grad_op.type == "sum", "got unexpect op [{}]".format(
+                str(grad_op.type))
+            assert all(map(_is_grad_var_name, grad_op.input_arg_names))
+            assert len(grad_op.output_arg_names) == 1
+
+            ref_forward_var_name = _get_forward_varname_from_grad_varname(
+                grad_op.output_arg_names[0])
+            forward_var = vars[ref_forward_var_name]
+            ref_forward_var_dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
+                forward_var).get_dims_mapping()
+            ref_forward_var_process_mesh = dist_context.get_tensor_distributed_attr_for_program(
+                forward_var).get_process_mesh()
+
+            # output
+            tensor_attr = TensorDistributedAttribute(
+                vars[grad_op.output_arg_names[0]], dist_context)
+            tensor_attr.set_dims_mapping(ref_forward_var_dims_mapping)
+            tensor_attr.set_process_mesh(ref_forward_var_process_mesh)
+            dist_context.set_tensor_distributed_attr_for_program(
+                vars[grad_op.output_arg_names[0]], tensor_attr)
+
+            # op
+            grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
+            grad_op_attr.set_process_mesh(ref_forward_var_process_mesh)
+            for var_name in grad_op.input_arg_names:
+                assert _get_forward_varname_from_grad_varname(
+                    var_name) == ref_forward_var_name
+                grad_op_attr.set_input_dims_mapping(
+                    var_name, ref_forward_var_dims_mapping)
+            dist_context.set_op_distributed_attr_for_program(grad_op,
+                                                             grad_op_attr)
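A toy illustration of the mapping relied on above, using plain Python stand-ins rather than Paddle's Operator or DistributedContext classes (the ids, op types and meshes below are invented): recording gradopidx2opidx while the backward ops are generated lets completion copy the forward op's annotation directly, instead of re-deriving grad-op descriptions and matching them by input names as the removed code did.

class Op:
    # minimal stand-in: a "program" is just a list of ops, each with a unique id
    def __init__(self, id, type, process_mesh=None):
        self.id, self.type, self.process_mesh = id, type, process_mesh

forward_ops = [Op(0, "matmul_v2", process_mesh=[0, 1]), Op(1, "softmax", process_mesh=[0, 1])]
backward_ops = [Op(2, "softmax_grad"), Op(3, "matmul_v2_grad"), Op(4, "sum")]

# recorded while the backward ops were appended (analogous to gradopidx2opidx)
gradopidx2opidx = {2: 1, 3: 0}

for grad_op in backward_ops:
    if grad_op.id in gradopidx2opidx:
        forward_op = next(op for op in forward_ops if op.id == gradopidx2opidx[grad_op.id])
        grad_op.process_mesh = forward_op.process_mesh  # reuse the forward op's annotation
    else:
        # only the sum op that merges several renamed gradients has no forward counterpart
        assert grad_op.type == "sum"

print([(op.type, op.process_mesh) for op in backward_ops])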
+def complete_update_annotation(auto_parallel_main_prog, dist_context):
+    """Complete the annotation of vars and ops in the update phase for parallel program."""
+
+    if dist_context is None:
+        dist_context = get_default_distributed_context()
+
+    ops = list(auto_parallel_main_prog.global_block().ops)
+    vars = auto_parallel_main_prog.global_block().vars
+    for idx in range(len(ops)):
+
+        # complete the annotation of the optimizer op.
+        # TODO to add attribute for moment var
+        if int(ops[idx].attr('op_role')) == int(OpRole.Optimize):
+            if "Grad" in ops[idx].input_names and "Param" in ops[
+                    idx].input_names:
+                assert len(ops[idx].input(
+                    "Param")) == 1, "Only support one-to-one now."
+                assert len(ops[idx].input(
+                    "Grad")) == 1, "Only support one-to-one now."
+                param = vars[ops[idx].input("Param")[0]]
+                grad_var = vars[ops[idx].input("Grad")[0]]
+                process_mesh = dist_context.get_tensor_distributed_attr_for_program(
+                    param).get_process_mesh()
+                dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
+                    param).get_dims_mapping()
+                op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
+                op_attr.set_process_mesh(process_mesh)
+                op_attr.set_input_dims_mapping(grad_var.name, dims_mapping)
+                op_attr.set_input_dims_mapping(param.name, dims_mapping)
+                dist_context.set_op_distributed_attr_for_program(ops[idx],
+                                                                 op_attr)
+                continue
@@ -51,23 +51,8 @@ class DistributedContext:
         self._op_distributed_attr_map_for_program = {}
         self._tensor_distributed_attr_map_for_graph = {}
         self._op_distributed_attr_map_for_graph = {}
-        # The following is a hard code and will be removed in the future
-        self._data_parallel_axis = None
-        self._model_parallel_axis = None
+        self._get_dist_op_helper = DistOpHelper()
         self._process_mesh = _g_process_mesh_map.get(0, None)
-        if self._process_mesh is not None:
-            if self._process_mesh.ndim == 1:
-                self._data_parallel_axis = 0
-                self._model_parallel_axis = 0
-            elif self._process_mesh.ndim == 3:
-                self._data_parallel_axis = 1
-                self._model_parallel_axis = 2
-            else:
-                self._data_parallel_axis = 0
-                self._model_parallel_axis = 1
-        else:
-            self._data_parallel_axis = -1
-            self._model_parallel_axis = -1

     def is_initialized_for_program(self):
         return self._is_initialized_for_program
@@ -120,16 +105,9 @@ class DistributedContext:
     def set_process_mesh(self, process_mesh):
         self._process_mesh = process_mesh
-        if self._process_mesh is not None:
-            if self._process_mesh.ndim == 1:
-                self._data_parallel_axis = 0
-                self._model_parallel_axis = 0
-            else:
-                self._data_parallel_axis = 0
-                self._model_parallel_axis = 1
-        else:
-            self._data_parallel_axis = -1
-            self._model_parallel_axis = -1
+
+    def get_dist_op_helper(self):
+        return self._get_dist_op_helper

     def initialize_distributed_attr_for_program(self, program):
         if self._is_initialized_for_program:
@@ -425,10 +403,93 @@ class DistributedContext:
                     and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                 dims_mapping[i] = -1

-    def _get_data_parallel_info(self):
-        # This function is a hard code, and will be obsoleted in the future
-        return self._data_parallel_axis, self._process_mesh
-
-    def _get_model_parallel_info(self):
-        # This function is a hard code, and will be obsoleted in the future
-        return self._model_parallel_axis, self._process_mesh
+
+class DistOpHelper:
+    """
+    DistOpHelper is used to create a dist op desc in Program.
+    Every time to create a new dist op, the context should be updated for it accordingly.
+    """
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
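A minimal sketch of what prepare_forward_context computes, written with plain dicts (the slot and variable names below are made up): every input/output slot of the serial op is remapped through varname_mapping into the names used in the destination program, and the resulting dicts are what a dist op impl receives as **kwargs.

# stand-in for an op desc: name slots -> serial var names
serial_op_inputs = {"X": ["tmp_2"], "Y": ["linear_0.w_0"]}
serial_op_outputs = {"Out": ["tmp_3"]}

# serial var name -> var name in the partitioned (dist) program
varname_mapping = {"tmp_2": "tmp_2", "linear_0.w_0": "linear_0.w_0", "tmp_3": "tmp_3"}

kinputs = {slot: [varname_mapping[v] for v in names]
           for slot, names in serial_op_inputs.items()}
koutputs = {slot: [varname_mapping[v] for v in names]
            for slot, names in serial_op_outputs.items()}

print(kinputs, koutputs)   # these dicts become the kwargs passed to forward(ctx, **kwargs)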
@@ -22,3 +22,4 @@ from . import dist_matmul
 from . import dist_reshape
 from . import dist_softmax
 from . import dist_transpose
+from . import dist_default
@@ -36,10 +36,12 @@ class DistributedOperatorImpl:
         self._forward_implemented = False
         self._backward_implemented = False

-    def forward(self, dist_ctx, *args, **kwargs):
+    @staticmethod
+    def forward(dist_ctx, *args, **kwargs):
         raise NotImplementedError("Please Implement this method in Subclass.")

-    def backward(self, dist_ctx, *grad_outputs):
+    @staticmethod
+    def backward(dist_ctx, *grad_outputs, **kwargs):
         raise NotImplementedError("Please Implement this method in Subclass.")

     def get_name(self):
...
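A self-contained sketch of the static forward/backward contract introduced above, using toy registry and class names rather than Paddle's actual register_distributed_operator_impl machinery: implementations are stateless, take the context as the first argument, and receive the name-mapped tensors as keyword arguments.

_registry = {}

def register_impl(op_type, impl):
    # toy stand-in for the real registration function
    _registry.setdefault(op_type, []).append(impl)

class OperatorImplBase:
    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

    @staticmethod
    def backward(ctx, *args, **kwargs):
        raise NotImplementedError("Please Implement this method in Subclass.")

class ReplicatedImpl(OperatorImplBase):
    @staticmethod
    def forward(ctx, *args, **kwargs):
        return ("replicate forward", kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        return ("replicate backward", kwargs)

register_impl("default", ReplicatedImpl())
impl = _registry["default"][0]
print(impl.forward(None, X=["x"], Out=["out"]))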
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperator):
def __init__(self, name):
super(DistributedDefault, self).__init__()
self._name = name
register_distributed_operator("default", DistributedDefault("default"))
# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedDefaultImpl0, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.")
def is_input_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.")
def update_dims_mapping(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.")
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
varname_mapping = dist_op_helper.get_varname_mapping()
rank_id = dist_op_helper.get_rank_id()
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
assert len(kwargs[input_name]) == len(
src_op.desc.input(input_name)
), "number of tensor for input [{}] is not match".format(input_name)
for output_name in src_op.desc.output_names():
assert output_name in kwargs, "input [{}] is not given".format(
output_name)
assert len(kwargs[output_name]) == len(
src_op.desc.output(output_name)
), "number of tensor for input [{}] is not match".format(
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# param initialization sync
for varname in dist_op_desc.input_arg_names():
if startup_block.has_var(varname) and startup_block.var(
varname
).is_parameter and varname not in dist_op_helper.already_init_sync_vars:
dist_op_helper.already_init_sync_vars.add(varname)
param = startup_block.var(varname)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(
param)
process_mesh = param_dist_attr.get_process_mesh()
dims_mapping = param_dist_attr.get_dims_mapping()
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group:
rank_id = _get_corresponding_rank(process_mesh, rank_id)
# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dims_mapping:
pass
else:
group_ranks = _get_comm_group(
process_mesh.process_group, process_mesh.topology,
axis, rank_id)
sync_group = new_process_group(group_ranks)
new_op = startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': sync_group.id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
# set distributed attribute
op_attr = OperatorDistributedAttribute(new_op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr.set_output_dims_mapping(param.name,
dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
# by now the backward function only insert the gradient allreduce for dist op itself
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op))
rank_id = dist_op_helper.get_rank_id()
# check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
# we need insert gradient allreduce for the gradient of parameter in its output
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and not main_block.var(
varname).is_parameter:
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
process_mesh = dist_attr.get_process_mesh()
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group:
rank_id = _get_corresponding_rank(process_mesh, rank_id)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(
process_mesh.process_group, process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
break
if need_gradient_allreduce:
allreduce_vars = []
for input_name in backward_op.desc.input_names():
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and main_block.var(
varname).is_parameter:
assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
backward_op.desc.input(input_name))
assert varname + "@GRAD" in backward_op.desc.output_arg_names(
), "parameter's grad [{}] not found in the grad op's output".format(
varname + "@GRAD")
assert len(
backward_op.desc.output(input_name + "@GRAD")
) == 1, "parameter grad of grad op should be length 1, but got [{}]".format(
backward_op.desc.output(input_name + "@GRAD"))
allreduce_vars.append(
backward_op.desc.output(input_name + "@GRAD")[0])
if len(allreduce_vars) > 0:
for varname in allreduce_vars:
grad_var = main_block.var(varname)
allreduce_op = main_block.append_op(
type='c_allreduce_sum',
inputs={'X': [grad_var]},
outputs={'Out': [grad_var]},
attrs={
'ring_id': dp_group.id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Backward
})
scale_op = main_block.append_op(
type='scale',
inputs={'X': grad_var},
outputs={'Out': grad_var},
attrs={
'scale': 1.0 / dp_degree,
OP_ROLE_KEY: OpRole.Backward
})
dims_mapping = ctx.get_tensor_distributed_attr_for_program(
grad_var).get_dims_mapping()
process_mesh = dist_attr.get_process_mesh()
for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr.set_output_dims_mapping(grad_var.name,
dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name,
dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr)
main_block._sync_with_cpp()
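The c_allreduce_sum followed by scale(1.0 / dp_degree) appended above amounts to averaging the data-parallel gradients; a small NumPy check with made-up gradient values:

import numpy as np

# per-rank parameter gradients computed from different data shards
local_grads = [np.array([0.2, -1.0]), np.array([0.4, 0.0]),
               np.array([0.0, 0.6]), np.array([-0.2, 0.2])]
dp_degree = len(local_grads)

allreduced = np.sum(local_grads, axis=0)       # what c_allreduce_sum produces on every rank
averaged = allreduced * (1.0 / dp_degree)      # what the trailing scale op produces

print(averaged)                                # [ 0.1  -0.05]
assert np.allclose(averaged, np.mean(local_grads, axis=0))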
register_distributed_operator_impl(
"default", DistributedDefaultImpl0("replicate_parallel"))
@@ -24,12 +24,14 @@ from ..utils import is_valid_list_index
 from ..utils import compute_compatible_dim_mapping
 from ..utils import compute_compatible_dims_mapping
 from ..utils import compute_compatible_and_update_dim_mapping
+from ..attribute import OperatorDistributedAttribute
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.framework import Program, Parameter, Variable, program_guard
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..process import new_process_group
-from ..utils import _get_comm_group
+from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank


 class DistributedEmbedding(DistributedOperator):
@@ -40,6 +42,7 @@ class DistributedEmbedding(DistributedOperator):

 register_distributed_operator("lookup_table_v2",
                               DistributedEmbedding("embedding"))
+register_distributed_operator("c_embedding", DistributedEmbedding("embedding"))


 # RowParallel
@@ -48,7 +51,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         super(DistributedEmbeddingImpl, self).__init__()
         self._name = name
         self._forward_implemented = True
-        self._backward_implemented = False
+        self._backward_implemented = True

     def is_process_mesh_compatible(self, op_dist_attr):
         """ No restriction for now. """
@@ -102,127 +105,231 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         return changed
-    def forward(self, serial_op):
-        def static_handle(dst_block,
-                          src_op,
-                          op_dist_attr,
-                          input_name_mapping,
-                          output_name_mapping,
-                          rank_id=0):
-            assert len(
-                input_name_mapping
-            ) == 2, "row_parallel_embedding take 2 inputs variable but got {}".format(
-                input_name_mapping)
-            assert len(
-                output_name_mapping
-            ) == 1, "row_parallel_embedding take 2 inputs variable but got {}".format(
-                output_name_mapping)
-            assert len(
-                input_name_mapping['Ids']
-            ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
-                input_name_mapping['Ids'])
-            assert len(
-                input_name_mapping['W']
-            ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format(
-                input_name_mapping['W'])
-            assert len(
-                output_name_mapping['Out']
-            ) == 1, "row_parallel_embedding input Out take 1 variable but got {}".format(
-                input_name_mapping['Out'])
-
-            Ids_var = dst_block.var(input_name_mapping['Ids'][0])
-            Weight_var = dst_block.var(input_name_mapping['W'][0])
-            Out_var = dst_block.var(output_name_mapping['Out'][0])
-
-            # got dist attribute info
-            embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
-                Weight_var.name)[0]
-            process_mesh_shape = op_dist_attr.get_process_mesh().topology
-            process_mesh_group = op_dist_attr.get_process_mesh().process_group
-
-            # caculate embedding offset
-            # TODO generalize here, using cartisian product to allow any dimensional mesh shape
-            mesh_shape = len(process_mesh_shape)
-            assert mesh_shape <= 2, "row_parallel_embedding only support 1 or 2 dimensional process mesh, but got {}".format(
-                process_mesh_shape)
-            num_partition = process_mesh_shape[embedding_row_dim_mapping]
-            # TODO generalize here, support any mesh group
-            model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
-            )._get_model_parallel_info()
-            if mesh_shape == 1:
-                if rank_id not in process_mesh_group:
-                    assert len(
-                        process_mesh.topology
-                    ) == 2, " row_parallel_embedding process mapping only support 2 dimensional process mesh, \
                        but got {}".format(len(process_mesh.topology))
-                    rank_id = process_mesh_group[
-                        process_mesh.process_group.index(rank_id) %
-                        process_mesh_shape[0]]
-                relative_idx = process_mesh_group.index(rank_id)
-            else:
-                relative_idx = rank_id % num_partition
-
-            per_part_size = Weight_var.shape[0]
-            relative_idx = relative_idx * per_part_size
-
-            # TODO caculate ring id
-            group_ranks = _get_comm_group(process_mesh.process_group,
-                                          process_mesh.topology,
-                                          model_parallel_axis, rank_id)
-            group = new_process_group(group_ranks)
-
-            # append op
-            check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
-                                     'c_embedding')
-
-            intermediate_var_0 = dst_block.create_var(
-                name=unique_name.generate_with_ignorable_key(".".join(
-                    ["c_embedding", 'tmp'])),
-                dtype=Weight_var.dtype,
-                shape=Out_var.shape,
-                type=core.VarDesc.VarType.LOD_TENSOR,
-                persistable=False,
-                stop_gradient=Out_var.stop_gradient)
-            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
-            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
-                                          Out_var)
-
-            check_variable_and_dtype(
-                Out_var, 'tensor',
-                ['float16', 'float32', 'float64', 'int32', 'int64'],
-                'c_allreduce_sum')
-
-            c_embedding_op = dst_block.append_op(
-                type='c_embedding',
-                inputs={'Ids': [Ids_var],
-                        'W': [Weight_var]},
-                outputs={'Out': [intermediate_var_0]},
-                attrs={"start_index": relative_idx})
-
-            # use_model_parallel
-            c_allreduce_sum_op = dst_block.append_op(
-                type='c_allreduce_sum',
-                inputs={'X': [intermediate_var_0]},
-                outputs={'Out': [Out_var]},
-                attrs={
-                    'ring_id': group.id,
-                    'use_calc_stream': True,
-                    'use_model_parallel': True,
-                })
-
-            # copy serial op's dist_attr to dist op's dist_attr
-            copy_distributed_attr_for_dist_op(c_embedding_op, dst_block,
-                                              op_dist_attr)
-            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
-                                              op_dist_attr)
-
-        if in_dygraph_mode():
-            raise NotImplementedError(
-                "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
-                    "matmul", 0))
-        else:
-            return static_handle
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        """
+        kwargs: inputname_mapping & outputname_mapping
+        """
+
+        dist_op_helper = ctx.get_dist_op_helper()
+        main_block = dist_op_helper.get_dst_main_program().global_block()
+        startup_block = dist_op_helper.get_dst_startup_program().global_block()
+        src_op = dist_op_helper.get_cur_src_op()
+        rank_id = dist_op_helper.get_rank_id()
+        op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
+        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
+            str(src_op))
+
+        # check validation of inputs / outputs
+        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
+        assert 'W' in kwargs, "input [{}] is not given".format('W')
+        assert 'Out' in kwargs, "output [{}] is not given".format('Out')
+
+        assert len(
+            kwargs['Ids']
+        ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
+            kwargs['Ids'])
+        assert len(
+            kwargs['W']
+        ) == 1, "row_parallel_embedding input W take 1 variable but got {}".format(
+            kwargs['W'])
+        assert len(
+            kwargs['Out']
+        ) == 1, "row_parallel_embedding output Out take 1 variable but got {}".format(
+            kwargs['Out'])
+
+        Ids_var = main_block.var(kwargs['Ids'][0])
+        Weight_var = main_block.var(kwargs['W'][0])
+        Out_var = main_block.var(kwargs['Out'][0])
+
+        # got dist attribute info
+        embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
+            Weight_var.name)[0]
+        assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
+            embedding_row_dim_mapping)
+        process_mesh_shape = op_dist_attr.get_process_mesh().topology
+        process_mesh_group = op_dist_attr.get_process_mesh().process_group
+
+        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
+        if rank_id not in process_mesh_group:
+            rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
+                                              rank_id)
+
+        # A generalized method to caculate embedding offset using cartisian product
+        relative_idx = _get_idx_in_axis(process_mesh_group, process_mesh_shape,
+                                        embedding_row_dim_mapping, rank_id)
+
+        per_part_size = Weight_var.shape[0]
+        relative_idx = relative_idx * per_part_size
+
+        # TODO caculate ring id
+        parallel_axis = embedding_row_dim_mapping
+        group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
+                                      parallel_axis, rank_id)
+        group = new_process_group(group_ranks)
+
+        # append op
+        check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
+                                 'c_embedding')
+
+        intermediate_var_0 = main_block.create_var(
+            name=unique_name.generate_with_ignorable_key(".".join(
+                ["c_embedding", 'tmp'])),
+            dtype=Weight_var.dtype,
+            shape=Out_var.shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=Out_var.stop_gradient)
+        # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+        copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var)
+
+        check_variable_and_dtype(
+            Out_var, 'tensor',
+            ['float16', 'float32', 'float64', 'int32', 'int64'],
+            'c_allreduce_sum')
+
+        c_embedding_op = main_block.append_op(
+            type='c_embedding',
+            inputs={'Ids': [Ids_var],
+                    'W': [Weight_var]},
+            outputs={'Out': [intermediate_var_0]},
+            attrs={"start_index": relative_idx})
+
+        # use_model_parallel
+        c_allreduce_sum_op = main_block.append_op(
+            type='c_allreduce_sum',
+            inputs={'X': [intermediate_var_0]},
+            outputs={'Out': [Out_var]},
+            attrs={
+                'ring_id': group.id,
+                'use_calc_stream': True,
+                'use_model_parallel': True,
+            })
+
+        # copy serial op's dist_attr to dist op's dist_attr
+        copy_distributed_attr_for_dist_op(c_embedding_op, main_block,
+                                          op_dist_attr)
+        copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block,
+                                          op_dist_attr)
+
+        # param initialization sync
+        assert Weight_var.name not in dist_op_helper.already_init_sync_vars
+        dist_op_helper.already_init_sync_vars.add(Weight_var.name)
+        param = startup_block.var(Weight_var.name)
+        param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param)
+        process_mesh = param_dist_attr.get_process_mesh()
+        dim_mapping = param_dist_attr.get_dims_mapping()
+
+        # NOTE all not splited axis should be presented in mesh
+        for axis, size in enumerate(process_mesh.topology):
+            if size <= 1 or axis in dim_mapping:
+                pass
+            else:
+                group_ranks = _get_comm_group(process_mesh.process_group,
+                                              process_mesh.topology, axis,
+                                              rank_id)
+                sync_group = new_process_group(group_ranks)
+
+                startup_block.append_op(
+                    type='c_broadcast',
+                    inputs={'X': param},
+                    outputs={'Out': param},
+                    attrs={
+                        'ring_id': sync_group.id,
+                        'root': 0,
+                        'use_calc_stream': True,
+                        OP_ROLE_KEY: OpRole.Forward
+                    })
+        startup_block._sync_with_cpp()
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+
+        # by now the backward function only insert the gradient allreduce for dist op itself
+        dist_op_helper = ctx.get_dist_op_helper()
+        main_block = dist_op_helper.get_dst_main_program().global_block()
+        backward_op = dist_op_helper.get_cur_src_op()
+        rank_id = dist_op_helper.get_rank_id()
+        dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
+        assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
+            str(backward_op))
+
+        # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
+        if rank_id not in dist_attr.get_process_mesh().process_group:
+            rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(),
+                                              rank_id)
+
+        # check if need gradient allreduce
+        need_gradient_allreduce = False
+
+        assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
+        assert 'W' in kwargs, "input [{}] is not given".format('W')
+        assert 'Out@GRAD' in kwargs, "input [{}] is not given".format('Out')
+        assert 'W@GRAD' in kwargs, "output [{}] is not given".format('W@GRAD')
+
+        assert len(
+            kwargs['Ids']
+        ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
+            kwargs['Ids'])
+        assert len(
+            kwargs['W']
+        ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
+            kwargs['W'])
+        assert len(
+            kwargs['Out@GRAD']
+        ) == 1, "row_parallel_embedding input Ids take 1 variable but got {}".format(
+            kwargs['Out'])
+        assert len(
+            kwargs['W@GRAD']
+        ) == 1, "row_parallel_embedding output Ids take 1 variable but got {}".format(
+            kwargs['W@GRAD'])
+
+        Ids_var = main_block.var(kwargs['Ids'][0])
+        process_mesh = dist_attr.get_process_mesh()
+        var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
+        mesh_shape = process_mesh.topology
+        batch_size_axis = var_dim_mapping[0]
+        if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+            need_gradient_allreduce = True
+
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          batch_size_axis, rank_id)
+            dp_degree = len(group_ranks)
+            dp_group = new_process_group(group_ranks)
+
+        if need_gradient_allreduce:
+            W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
+            allreduce_op = main_block.append_op(
+                type='c_allreduce_sum',
+                inputs={'X': [W_Grad_var]},
+                outputs={'Out': [W_Grad_var]},
+                attrs={
+                    'ring_id': dp_group.id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Backward
+                })
+            scale_op = main_block.append_op(
+                type='scale',
+                inputs={'X': W_Grad_var},
+                outputs={'Out': W_Grad_var},
+                attrs={'scale': 1.0 / dp_degree,
+                       OP_ROLE_KEY: OpRole.Backward})
+            main_block._sync_with_cpp()
+
+            dims_mapping = ctx.get_tensor_distributed_attr_for_program(
+                W_Grad_var).get_dims_mapping()
+            process_mesh = dist_attr.get_process_mesh()
+            for op in [allreduce_op, scale_op]:
+                op_attr = OperatorDistributedAttribute(op, ctx)
+                op_attr.set_process_mesh(process_mesh)
+                op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
+                op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping)
+                ctx.set_op_distributed_attr_for_program(op, op_attr)


 register_distributed_operator_impl("lookup_table_v2",
                                    DistributedEmbeddingImpl("row_parallel"))
+register_distributed_operator_impl("c_embedding",
+                                   DistributedEmbeddingImpl("row_parallel"))
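A small, runnable illustration of the start_index computation used by the c_embedding op above, with invented sizes (in the real code per_part_size is Weight_var.shape[0], i.e. the rows already sliced onto this rank, and relative_idx comes from _get_idx_in_axis):

# vocabulary of 12 rows split across 2 ranks along the sharded mesh axis
vocab_size, num_parts = 12, 2
per_part_size = vocab_size // num_parts          # 6 weight rows live on each rank

for relative_idx in range(num_parts):            # index of the rank along the sharded axis
    start_index = relative_idx * per_part_size   # the attrs={"start_index": ...} of c_embedding
    local_rows = list(range(start_index, start_index + per_part_size))
    print(relative_idx, start_index, local_rows)

# A token id outside a rank's [start_index, start_index + per_part_size) range contributes a
# zero row locally, and the following c_allreduce_sum combines the partial lookups.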
@@ -42,7 +42,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         super(DistributedReshapeImpl0, self).__init__()
         self._name = name
         self._forward_implemented = True
-        self._backward_implemented = False
+        self._backward_implemented = True

     def is_process_mesh_compatible(self, op_dist_attr):
         """ No restriction for now. """
@@ -97,82 +97,72 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         return changed

-    def forward(self, serial_op):
-        def static_handle(dst_block,
-                          src_op,
-                          op_dist_attr,
-                          input_name_mapping,
-                          output_name_mapping,
-                          rank_id=0):
-            assert len(
-                input_name_mapping
-            ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format(
-                input_name_mapping)
-            assert len(
-                output_name_mapping
-            ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format(
-                output_name_mapping)
-            assert len(
-                input_name_mapping['X']
-            ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format(
-                input_name_mapping['X'])
-            assert len(
-                input_name_mapping['ShapeTensor']
-            ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format(
-                input_name_mapping['ShapeTensor'])
-            assert len(
-                input_name_mapping['Shape']
-            ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format(
-                input_name_mapping['Shape'])
-            assert len(
-                output_name_mapping['Out']
-            ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format(
-                input_name_mapping['Out'])
-            assert len(
-                output_name_mapping['XShape']
-            ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format(
-                input_name_mapping['XShape'])
-
-            X_var = dst_block.var(input_name_mapping['X'][0])
-            Out_var = dst_block.var(output_name_mapping['Out'][0])
-            XShape_var = dst_block.var(output_name_mapping['XShape'][0])
-            shape_list = src_op.desc.attr("shape")
-            ShapeTensor_var_list = []
-            for name in input_name_mapping['ShapeTensor']:
-                ShapeTensor_var_list.append(name)
-            Shape_var_list = []
-            for name in input_name_mapping['Shape']:
-                Shape_var_list.append(name)
-
-            # got dist attribute info
-            dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
-            process_mesh_shape = op_dist_attr.get_process_mesh().topology
-
-            # modify target shape
-            for idx, axis in enumerate(dim_mapping):
-                if axis >= 0:
-                    if len(shape_list) > idx:
-                        shape_list[idx] = shape_list[idx] // process_mesh_shape[
-                            axis]
-
-            # create op
-            new_op_desc = dst_block.desc.append_op()
-            new_op_desc.copy_from(src_op.desc)
-            new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
-            new_op_desc.set_input('Shape', Shape_var_list)
-            new_op_desc.set_input('X', [X_var.name])
-            new_op_desc.set_output('XShape', [XShape_var.name])
-            new_op_desc.set_output('Out', [Out_var.name])
-            new_op_desc._set_attr('shape', shape_list)
-
-            dst_block._sync_with_cpp()
-
-        if in_dygraph_mode():
-            raise NotImplementedError(
-                "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
-                    "matmul", 0))
-        else:
-            return static_handle
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        """
+        kwargs: inputname_mapping & outputname_mapping
+        """
+
+        dist_op_helper = ctx.get_dist_op_helper()
+        main_block = dist_op_helper.get_dst_main_program().global_block()
+        src_op = dist_op_helper.get_cur_src_op()
+        rank_id = dist_op_helper.get_rank_id()
+        op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
+        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
+            str(src_op))
+
+        # check validation of inputs / outputs
+        for input_name in src_op.desc.input_names():
+            assert input_name in kwargs, "input [{}] is not given".format(
+                input_name)
+            assert len(kwargs[input_name]) == len(
+                src_op.desc.input(input_name)
+            ), "number of tensor for input [{}] is not match".format(input_name)
+        for output_name in src_op.desc.output_names():
+            assert output_name in kwargs, "input [{}] is not given".format(
+                output_name)
+            assert len(kwargs[output_name]) == len(
+                src_op.desc.output(output_name)
+            ), "number of tensor for input [{}] is not match".format(
+                output_name)
+
+        X_var = main_block.var(kwargs['X'][0])
+        Out_var = main_block.var(kwargs['Out'][0])
+        XShape_var = main_block.var(kwargs['XShape'][0])
+        shape_list = src_op.desc.attr("shape")
+        ShapeTensor_var_list = []
+        for name in kwargs['ShapeTensor']:
+            ShapeTensor_var_list.append(name)
+        Shape_var_list = []
+        for name in kwargs['Shape']:
+            Shape_var_list.append(name)
+
+        # got dist attribute info
+        dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
+        process_mesh_shape = op_dist_attr.get_process_mesh().topology
+
+        # modify target shape
+        for idx, axis in enumerate(dim_mapping):
+            if axis >= 0:
+                if len(shape_list) > idx:
+                    shape_list[idx] = shape_list[idx] // process_mesh_shape[
+                        axis]
+
+        # create op
+        new_op_desc = main_block.desc.append_op()
+        new_op_desc.copy_from(src_op.desc)
+        new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
+        new_op_desc.set_input('Shape', Shape_var_list)
+        new_op_desc.set_input('X', [X_var.name])
+        new_op_desc.set_output('XShape', [XShape_var.name])
+        new_op_desc.set_output('Out', [Out_var.name])
+        new_op_desc._set_attr('shape', shape_list)
+
+        main_block._sync_with_cpp()
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        pass


 class DistributedReshapeImpl1(DistributedOperatorImpl):
@@ -180,7 +170,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         super(DistributedReshapeImpl1, self).__init__()
         self._name = name
         self._forward_implemented = True
-        self._backward_implemented = False
+        self._backward_implemented = True

     def is_process_mesh_compatible(self, op_dist_attr):
         """ No restriction for now. """
@@ -235,82 +225,72 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         return changed

-    def forward(self, serial_op):
-        def static_handle(dst_block,
-                          src_op,
-                          op_dist_attr,
-                          input_name_mapping,
-                          output_name_mapping,
-                          rank_id=0):
-            assert len(
-                input_name_mapping
-            ) == 3, "Dist op of Reshape take 3 inputs variable but got {}".format(
-                input_name_mapping)
-            assert len(
-                output_name_mapping
-            ) == 2, "Dist op of Reshape take 2 inputs variable but got {}".format(
-                output_name_mapping)
-            assert len(
-                input_name_mapping['X']
-            ) == 1, "Dist op of Reshape input X take 1 variable but got {}".format(
-                input_name_mapping['X'])
-            assert len(
-                input_name_mapping['ShapeTensor']
-            ) <= 1, "Dist op of Reshape input ShapeTensor take 0 or 1 variable but got {}".format(
-                input_name_mapping['ShapeTensor'])
-            assert len(
-                input_name_mapping['Shape']
-            ) <= 1, "Dist op of Reshape input Shape take 0 or 1 variable but got {}".format(
-                input_name_mapping['Shape'])
-            assert len(
-                output_name_mapping['Out']
-            ) == 1, "Dist op of Reshape input Out take 1 variable but got {}".format(
-                input_name_mapping['Out'])
-            assert len(
-                output_name_mapping['XShape']
-            ) == 1, "Dist op of Reshape input XShape take 1 variable but got {}".format(
-                input_name_mapping['XShape'])
-
-            X_var = dst_block.var(input_name_mapping['X'][0])
-            Out_var = dst_block.var(output_name_mapping['Out'][0])
-            XShape_var = dst_block.var(output_name_mapping['XShape'][0])
-            shape_list = src_op.desc.attr("shape")
-            ShapeTensor_var_list = []
-            for name in input_name_mapping['ShapeTensor']:
-                ShapeTensor_var_list.append(name)
-            Shape_var_list = []
-            for name in input_name_mapping['Shape']:
-                Shape_var_list.append(name)
-
-            # got dist attribute info
-            dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
-            process_mesh_shape = op_dist_attr.get_process_mesh().topology
-
-            # modify target shape
-            for idx, axis in enumerate(dim_mapping):
-                if axis >= 0:
-                    if len(shape_list) > idx:
-                        shape_list[idx] = shape_list[idx] // process_mesh_shape[
-                            axis]
-
-            # create op
-            new_op_desc = dst_block.desc.append_op()
-            new_op_desc.copy_from(src_op.desc)
-            new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
-            new_op_desc.set_input('Shape', Shape_var_list)
-            new_op_desc.set_input('X', [X_var.name])
-            new_op_desc.set_output('XShape', [XShape_var.name])
-            new_op_desc.set_output('Out', [Out_var.name])
-            new_op_desc._set_attr('shape', shape_list)
-
-            dst_block._sync_with_cpp()
-
-        if in_dygraph_mode():
-            raise NotImplementedError(
-                "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
-                    "matmul", 0))
-        else:
-            return static_handle
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        """
+        kwargs: inputname_mapping & outputname_mapping
+        """
+
+        dist_op_helper = ctx.get_dist_op_helper()
+        main_block = dist_op_helper.get_dst_main_program().global_block()
+        src_op = dist_op_helper.get_cur_src_op()
+        rank_id = dist_op_helper.get_rank_id()
+        op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
+        assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
+            str(src_op))
+
+        # check validation of inputs / outputs
+        for input_name in src_op.desc.input_names():
+            assert input_name in kwargs, "input [{}] is not given".format(
+                input_name)
+            assert len(kwargs[input_name]) == len(
+                src_op.desc.input(input_name)
+            ), "number of tensor for input [{}] is not match".format(input_name)
+        for output_name in src_op.desc.output_names():
+            assert output_name in kwargs, "input [{}] is not given".format(
+                output_name)
+            assert len(kwargs[output_name]) == len(
+                src_op.desc.output(output_name)
+            ), "number of tensor for input [{}] is not match".format(
+                output_name)
+
+        X_var = main_block.var(kwargs['X'][0])
+        Out_var = main_block.var(kwargs['Out'][0])
+        XShape_var = main_block.var(kwargs['XShape'][0])
+        shape_list = src_op.desc.attr("shape")
+        ShapeTensor_var_list = []
+        for name in kwargs['ShapeTensor']:
+            ShapeTensor_var_list.append(name)
+        Shape_var_list = []
+        for name in kwargs['Shape']:
+            Shape_var_list.append(name)
+
+        # got dist attribute info
+        dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
+        process_mesh_shape = op_dist_attr.get_process_mesh().topology
+
+        # modify target shape
+        for idx, axis in enumerate(dim_mapping):
+            if axis >= 0:
+                if len(shape_list) > idx:
+                    shape_list[idx] = shape_list[idx] // process_mesh_shape[
+                        axis]
+
+        # create op
+        new_op_desc = main_block.desc.append_op()
+        new_op_desc.copy_from(src_op.desc)
+        new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
+        new_op_desc.set_input('Shape', Shape_var_list)
+        new_op_desc.set_input('X', [X_var.name])
+        new_op_desc.set_output('XShape', [XShape_var.name])
+        new_op_desc.set_output('Out', [Out_var.name])
+        new_op_desc._set_attr('shape', shape_list)
+
+        main_block._sync_with_cpp()
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        pass


 register_distributed_operator_impl("reshape2",
...
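A runnable illustration of the target-shape adjustment both reshape impls perform, with made-up shapes: every output axis that is sharded on the mesh has its target extent divided by the corresponding mesh dimension before being written back with _set_attr('shape', ...).

shape_list = [8, 1024, 768]        # attr("shape") of the serial reshape2 op
dim_mapping = [0, -1, 1]           # axis 0 sharded on mesh dim 0, axis 2 sharded on mesh dim 1
process_mesh_shape = [2, 4]        # a 2 x 4 process mesh

for idx, axis in enumerate(dim_mapping):
    if axis >= 0 and len(shape_list) > idx:
        shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]

print(shape_list)   # [4, 1024, 192]: the per-rank target shape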
@@ -37,6 +37,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
     def __init__(self, name):
         super(DistributedSoftmaxImpl, self).__init__()
         self._name = name
+        self._forward_implemented = False
+        self._backward_implemented = True

     def is_process_mesh_compatible(self, op_dist_attr):
         """ No restriction for now. """
@@ -86,6 +88,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
         return changed

+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        pass
+

 register_distributed_operator_impl(
     "softmax", DistributedSoftmaxImpl("replicate_last_axis"))
@@ -37,6 +37,8 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
     def __init__(self, name):
         super(DistributedTranspose2Impl, self).__init__()
         self._name = name
+        self._forward_implemented = False
+        self._backward_implemented = True

     def is_process_mesh_compatible(self, op_dist_attr):
         """ No restriction for now. """
@@ -82,6 +84,10 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
         return changed

+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        pass
+

 register_distributed_operator_impl(
     "transpose2", DistributedTranspose2Impl("same_mapping_transpose"))
...@@ -94,10 +94,8 @@ class AutoParallelizer: ...@@ -94,10 +94,8 @@ class AutoParallelizer:
# The last step: remove all distributed attributes to be compatiable # The last step: remove all distributed attributes to be compatiable
# with inference. # with inference.
self._remove_distributed_attrs(partitioned_main_prog) self._remove_distributed_attrs(partitioned_main_prog)
complete_backward_annotation(partitioned_main_prog, self._dist_context)
make_data_unshard(partitioned_main_prog, partitioned_startup_prog) make_data_unshard(partitioned_main_prog, partitioned_startup_prog)
reshard(partitioned_main_prog, partitioned_startup_prog, rank, reshard(partitioned_main_prog, partitioned_startup_prog, rank,
self._dist_context) self._dist_context)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import threading import threading
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
from .interface import _g_process_mesh_map
def is_valid_list_index(list, index): def is_valid_list_index(list, index):
...@@ -171,7 +172,9 @@ def _get_comm_group(processes, shape, axis, rank): ...@@ -171,7 +172,9 @@ def _get_comm_group(processes, shape, axis, rank):
""" """
# NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous
# tricks to support processes mesh when it is not start with 0 or continuous # tricks to support processes mesh when it is not start with 0 or continuous
assert rank in processes, "rank [{}] is NOT in processes group {}".format(
rank, processes)
rank_relatvie = processes.index(rank) rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(shape, rank_relatvie) coordinate = _linear_idx2coordinate(shape, rank_relatvie)
coordinates_in_group = [coordinate[:] for i in range(shape[axis])] coordinates_in_group = [coordinate[:] for i in range(shape[axis])]
...@@ -189,6 +192,25 @@ def _get_comm_group(processes, shape, axis, rank): ...@@ -189,6 +192,25 @@ def _get_comm_group(processes, shape, axis, rank):
return sorted(ranks_in_group) return sorted(ranks_in_group)
def _get_idx_in_axis(processes, shape, axis, rank):
"""
Given a rank and the process mesh that rank belongs to,
compute the index of the rank along the given axis.
Example: 27 processes managed in a 3-dimensional mesh with shape [3, 3, 3];
the indices of rank 22 are:
in axis 0: 1
in axis 1: 1
in axis 2: 2
"""
# NOTE _linear_idx2coordinate assumes the process mesh starts at 0 and is contiguous;
# trick to support a process mesh that does not start at 0 or is not contiguous
rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(shape, rank_relatvie)
return coordinate[axis]
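For intuition, here is a stand-alone sketch of the same computation, assuming the row-major linearization used by _linear_idx2coordinate (function and variable names below are illustrative):

def idx_in_axis(processes, shape, axis, rank):
    # position of the rank inside the (assumed contiguous) process list
    linear = processes.index(rank)
    coordinate = []
    for extent in reversed(shape):
        # peel off the fastest-varying axis first
        linear, idx = divmod(linear, extent)
        coordinate.append(idx)
    coordinate.reverse()
    return coordinate[axis]

# rank 5 in a contiguous [2, 4] mesh sits at coordinate (1, 1), since 5 == 1 * 4 + 1
assert idx_in_axis(list(range(8)), [2, 4], 0, 5) == 1
assert idx_in_axis(list(range(8)), [2, 4], 1, 5) == 1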
def _coordinate2linear_idx(mesh_shape, coordinate): def _coordinate2linear_idx(mesh_shape, coordinate):
""" """
convert a coordinate in multidimensional mesh space into a scala idx in linear space. convert a coordinate in multidimensional mesh space into a scala idx in linear space.
...@@ -279,6 +301,27 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): ...@@ -279,6 +301,27 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
return coordinate return coordinate
def _get_corresponding_rank(target_mesh, rank):
# TODO(JZ-LIANG) a hack to support varying meshes in the pipeline parallelism case.
# We assume that all meshes are evenly divided from a parent mesh and have the same size;
# to be revised in the future.
coordinate = None
for key, mesh in _g_process_mesh_map.items():
if key == 0:
continue
if rank in mesh.process_group and mesh.topology == target_mesh.topology:
coordinate = _linear_idx2coordinate(mesh.topology,
mesh.process_group.index(rank))
break
assert coordinate is not None, "could NOT find rank [{}] in any registered mesh".format(
    rank)
return target_mesh.process_group[_coordinate2linear_idx(mesh.topology,
coordinate)]
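Since the source and target meshes are assumed to share a topology, the coordinate round-trip above amounts to mapping a rank to the process at the same flattened index of the target mesh. A tiny illustration under that assumption (hypothetical mesh contents):

def corresponding_rank(target_group, source_group, rank):
    # same coordinate inside two same-shaped meshes == same flattened index
    return target_group[source_group.index(rank)]

# e.g. two pipeline stages with process groups [0, 1] and [2, 3]: rank 1 maps to rank 3
assert corresponding_rank([2, 3], [0, 1], 1) == 3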
def _get_unshard_dist_shape(var, dist_attr): def _get_unshard_dist_shape(var, dist_attr):
var_shape = var.shape var_shape = var.shape
mapping = dist_attr.get_dims_mapping() mapping = dist_attr.get_dims_mapping()
......
...@@ -1051,7 +1051,8 @@ def _append_backward_ops_(block, ...@@ -1051,7 +1051,8 @@ def _append_backward_ops_(block,
grad_to_var, grad_to_var,
callbacks=None, callbacks=None,
input_grad_names_set=None, input_grad_names_set=None,
op_path_dict=None): op_path_dict=None,
distop_context=None):
""" """
Create all grad ops, and insert them into given block Create all grad ops, and insert them into given block
...@@ -1108,6 +1109,10 @@ def _append_backward_ops_(block, ...@@ -1108,6 +1109,10 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list) op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
if distop_context is not None:
for op_desc in grad_op_desc:
assert op_desc.id() not in distop_context.gradopidx2opidx
distop_context.gradopidx2opidx[op_desc.id()] = op.desc.id()
# Set device for grad_op according to forward Op # Set device for grad_op according to forward Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
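The new distop_context hook records, for every generated grad op desc, the id of the forward op it was derived from, so later auto-parallel passes can recover the forward/backward pairing without relying on variable-name matching alone. A minimal sketch of the bookkeeping with a plain dict and hypothetical ids:

# grad op desc id -> forward op desc id, mirroring gradopidx2opidx
grad_to_forward = {}

def record_grad_ops(grad_op_ids, forward_op_id):
    for gid in grad_op_ids:
        assert gid not in grad_to_forward
        grad_to_forward[gid] = forward_op_id

# one forward op may expand into several grad op descs
record_grad_ops([101, 102], forward_op_id=7)
assert grad_to_forward == {101: 7, 102: 7}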
...@@ -1402,7 +1407,8 @@ def append_backward(loss, ...@@ -1402,7 +1407,8 @@ def append_backward(loss,
parameter_list=None, parameter_list=None,
no_grad_set=None, no_grad_set=None,
callbacks=None, callbacks=None,
checkpoints=None): checkpoints=None,
distop_context=None):
""" """
:api_attr: Static Graph :api_attr: Static Graph
...@@ -1617,7 +1623,8 @@ def append_backward(loss, ...@@ -1617,7 +1623,8 @@ def append_backward(loss,
grad_to_var, grad_to_var,
callbacks, callbacks,
input_grad_names_set=input_grad_names_set, input_grad_names_set=input_grad_names_set,
op_path_dict=op_path_dict) op_path_dict=op_path_dict,
distop_context=distop_context, )
grad_info_map = dict() grad_info_map = dict()
......
...@@ -32,6 +32,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel) ...@@ -32,6 +32,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_auto_parallel_parallelizer)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers) list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper) list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
list(APPEND DIST_TEST_OPS test_parallel_class_center_sample) list(APPEND DIST_TEST_OPS test_parallel_class_center_sample)
...@@ -221,6 +222,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) ...@@ -221,6 +222,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_parallelizer)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
LIST(REMOVE_ITEM TEST_OPS test_mixed_precision) LIST(REMOVE_ITEM TEST_OPS test_mixed_precision)
...@@ -1002,6 +1004,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) ...@@ -1002,6 +1004,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_parallelizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
import paddle.fluid.core as core
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
def mlp_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
auto.set_pipeline_stage(1)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
predict = mlp(input)
cost = layers.cross_entropy(input=predict, label=label)
avg_cost = layers.mean(x=cost)
return avg_cost, train_program, start_program
class TestMLPAutoParallelizer(unittest.TestCase):
def test_mlp_serial(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
loss, train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
suffix = core.kAutoParallelSuffix()
for block in distributed_main_program.blocks:
for op in block.ops:
for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name)
# print_program_with_distributed_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program)
if __name__ == "__main__":
unittest.main()
...@@ -15,130 +15,16 @@ ...@@ -15,130 +15,16 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import paddle.fluid as fluid
# The following statements are used to satisfy fleet initialization
from test_parallel_dygraph_dataparallel import TestMultipleGpus
import os
if os.getenv("CUDA_VISIBLE_DEVICES", None) is None:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
import paddle.fluid.core as core
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])

class TestParallelizer(TestMultipleGpus):
    # check sharding logic as well as the accuracy with single mode
    def test_parallelizer_logic(self):
        self.run_mnist_2gpu('auto_parallel_parallelizer.py')

class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
return out
def mlp_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
auto.set_pipeline_stage(1)
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
predict = mlp(input)
cost = layers.cross_entropy(input=predict, label=label)
avg_cost = layers.mean(x=cost)
return avg_cost, train_program, start_program
class TestMLPAutoParallelizer(unittest.TestCase):
def test_mlp_serial(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
loss, train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
optimizer = fleet.distributed_optimizer(optimizer)
_, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
loss, start_program)
suffix = core.kAutoParallelSuffix()
for block in distributed_main_program.blocks:
for op in block.ops:
for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name)
# print_program_with_distributed_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -92,9 +92,9 @@ def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit): ...@@ -92,9 +92,9 @@ def check_tensor_split(prog1, varnames1, prog2, varnames2, axis, nsplit):
def initialization_check(mode, dist_context, dist_startup_prog, def initialization_check(mode, dist_context, dist_startup_prog,
serial_startup_prog, var_need_broadcast): serial_startup_prog, var_need_broadcast, process_mesh,
mp_parallel_axis, dp_parallel_axis):
if 'mp' in mode: if 'mp' in mode:
mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, mp_parallel_axis, process_mesh.topology, mp_parallel_axis,
3) 3)
...@@ -110,7 +110,6 @@ def initialization_check(mode, dist_context, dist_startup_prog, ...@@ -110,7 +110,6 @@ def initialization_check(mode, dist_context, dist_startup_prog,
return False return False
if 'dp' in mode: if 'dp' in mode:
dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, dp_parallel_axis, process_mesh.topology, dp_parallel_axis,
3) 3)
...@@ -359,9 +358,15 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -359,9 +358,15 @@ class TestMLPAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = [] var_need_broadcast = []
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=None,
dp_parallel_axis=0))
def test_mlp_mp(self): def test_mlp_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
...@@ -406,9 +411,15 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -406,9 +411,15 @@ class TestMLPAutoPartitioner(unittest.TestCase):
var_need_broadcast = sorted( var_need_broadcast = sorted(
['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=0,
dp_parallel_axis=None))
# check var and op all have dist_attr in dist_main_program # check var and op all have dist_attr in dist_main_program
self.assertTrue( self.assertTrue(
...@@ -464,9 +475,15 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -464,9 +475,15 @@ class TestMLPAutoPartitioner(unittest.TestCase):
var_need_broadcast = sorted( var_need_broadcast = sorted(
['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0']) ['layer_norm_0.b_0', 'layer_norm_0.w_0', 'linear_1.b_0'])
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=1,
dp_parallel_axis=0))
# check var and op all have dist_attr in dist_main_program # check var and op all have dist_attr in dist_main_program
self.assertTrue( self.assertTrue(
...@@ -635,9 +652,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -635,9 +652,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = [] var_need_broadcast = []
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=None,
dp_parallel_axis=0))
def test_attn_mp(self): def test_attn_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
...@@ -686,9 +709,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -686,9 +709,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = ['linear_3.b_0'] var_need_broadcast = ['linear_3.b_0']
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=0,
dp_parallel_axis=None))
# check var and op all have dist_attr in dist_main_program # check var and op all have dist_attr in dist_main_program
self.assertTrue( self.assertTrue(
...@@ -748,9 +777,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -748,9 +777,15 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
# parameter initialization # parameter initialization
var_need_broadcast = ['linear_3.b_0'] var_need_broadcast = ['linear_3.b_0']
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=1,
dp_parallel_axis=0))
# check var and op all have dist_attr in dist_main_program # check var and op all have dist_attr in dist_main_program
self.assertTrue( self.assertTrue(
...@@ -1043,9 +1078,15 @@ class TestDecoderLayerPartitioner(unittest.TestCase): ...@@ -1043,9 +1078,15 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
'layer_norm_0.w_0', 'linear_5.b_0' 'layer_norm_0.w_0', 'linear_5.b_0'
]) ])
self.assertTrue( self.assertTrue(
initialization_check(_global_parallel_strategy, dist_context, initialization_check(
dist_startup_prog, serial_startup_prog, _global_parallel_strategy,
var_need_broadcast)) dist_context,
dist_startup_prog,
serial_startup_prog,
var_need_broadcast,
_global_process_mesh,
mp_parallel_axis=1,
dp_parallel_axis=0))
# check var and op all have dist_attr in dist_main_program # check var and op all have dist_attr in dist_main_program
self.assertTrue( self.assertTrue(
...@@ -1117,7 +1158,16 @@ class TestDecoderLayerPartitioner(unittest.TestCase): ...@@ -1117,7 +1158,16 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
'fill_constant', 'gaussian_random', 'fill_constant', 'fill_constant', 'gaussian_random', 'fill_constant',
'gaussian_random', 'fill_constant', 'gaussian_random', 'gaussian_random', 'fill_constant', 'gaussian_random',
'fill_constant', 'gaussian_random', 'fill_constant', 'fill_constant', 'gaussian_random', 'fill_constant',
'gaussian_random', 'fill_constant', 'fill_constant', 'fill_constant' 'gaussian_random', 'fill_constant', 'fill_constant',
'fill_constant', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast'
] ]
self.assertTrue(dist_ops == ref_ops) self.assertTrue(dist_ops == ref_ops)
......
...@@ -521,7 +521,7 @@ class GPTModel(nn.Layer): ...@@ -521,7 +521,7 @@ class GPTModel(nn.Layer):
def __init__(self, def __init__(self,
vocab_size, vocab_size,
hidden_size=768, hidden_size=768,
num_hidden_layers=12, num_hidden_layers=4,
num_attention_heads=12, num_attention_heads=12,
intermediate_size=3072, intermediate_size=3072,
hidden_act="gelu", hidden_act="gelu",
...@@ -787,6 +787,14 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -787,6 +787,14 @@ class TestGPTPartitioner(unittest.TestCase):
dist_params_grads = partitioner.apply_backward( dist_params_grads = partitioner.apply_backward(
loss, complete_train_program, start_program, loss, complete_train_program, start_program,
auto_parallel_main_prog, auto_parallel_startup_prog) auto_parallel_main_prog, auto_parallel_startup_prog)
with open("./test_auto_parallel_partitioner_serial_main_new.txt",
"w") as fw:
fw.write(str(train_program))
with open("./test_auto_parallel_partitioner_serial_startup_new.txt",
"w") as fw:
fw.write(str(start_program))
optimizer = paddle.fluid.optimizer.AdamOptimizer( optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001, learning_rate=0.00001,
beta1=0.9, beta1=0.9,
...@@ -796,7 +804,17 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -796,7 +804,17 @@ class TestGPTPartitioner(unittest.TestCase):
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog, auto_parallel_main_prog,
auto_parallel_startup_prog) auto_parallel_startup_prog)
from paddle.distributed.auto_parallel.context import set_default_distributed_context
set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
fw.write(str(auto_parallel_main_prog))
with open("./test_auto_parallel_partitioner_startup_new.txt1",
"w") as fw:
fw.write(str(auto_parallel_startup_prog))
# with open("./test_auto_parallel_partitioner_main_completed.txt", "w") as fw:
# from paddle.distributed.auto_parallel.completion import complete_backward_annotation
# complete_backward_annotation(auto_parallel_main_prog)
# fw.write(str(auto_parallel_main_prog))
nrank = 4 nrank = 4
# col parallel # col parallel
weights = [ weights = [
...@@ -826,16 +844,20 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -826,16 +844,20 @@ class TestGPTPartitioner(unittest.TestCase):
'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_6.tmp_2', 'layer_norm_7.tmp_2', 'layer_norm_7.tmp_2',
'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2' 'layer_norm_7.tmp_2', 'layer_norm_8.tmp_2'
] ]
mp_parallel_axis, process_mesh = dist_context._get_model_parallel_info() process_mesh = _global_process_mesh
mp_parallel_axis = 1
dp_parallel_axis = 0
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, mp_parallel_axis, process_mesh.topology, mp_parallel_axis,
3) 3)
mp_ring_id = new_process_group(group_ranks).id mp_ring_id = new_process_group(group_ranks).id
dp_parallel_axis, process_mesh = dist_context._get_data_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, dp_parallel_axis, process_mesh.topology, dp_parallel_axis,
3) 3)
dp_ring_id = new_process_group(group_ranks).id dp_ring_id = new_process_group(group_ranks).id
tensor_parallel_allreduce_vars = sorted([ tensor_parallel_allreduce_vars = sorted([
op.desc.output_arg_names()[0].split("@")[0] op.desc.output_arg_names()[0].split("@")[0]
for op in auto_parallel_main_prog.global_block().ops for op in auto_parallel_main_prog.global_block().ops
......
...@@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto ...@@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP
...@@ -211,7 +210,8 @@ def check_initialization_for_dp(dist_startup_prog): ...@@ -211,7 +210,8 @@ def check_initialization_for_dp(dist_startup_prog):
if op.type == "c_broadcast": if op.type == "c_broadcast":
broadcast_varnames.append(op.output_arg_names[0]) broadcast_varnames.append(op.output_arg_names[0])
return params == need_check_params == broadcast_varnames return sorted(params) == sorted(need_check_params) == sorted(
broadcast_varnames)
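The broadcast check now compares the three name lists order-insensitively, since the order in which c_broadcast ops appear need not match the parameter order. A quick illustration with hypothetical names of why the old ==-chain could fail:

params = ['linear_0.w_0', 'linear_0.b_0']
need_check_params = ['linear_0.b_0', 'linear_0.w_0']
broadcast_varnames = ['linear_0.b_0', 'linear_0.w_0']
assert not (params == need_check_params == broadcast_varnames)   # order-sensitive comparison fails
assert sorted(params) == sorted(need_check_params) == sorted(broadcast_varnames)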
class TestMLPReshard(unittest.TestCase): class TestMLPReshard(unittest.TestCase):
...@@ -225,7 +225,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -225,7 +225,6 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0 rank_id = 0
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, 0) train_program, startup_program, dist_context, 0)
complete_backward_annotation(dist_main_prog, dist_context)
op_need_check = None op_need_check = None
for op in dist_main_prog.global_block().ops: for op in dist_main_prog.global_block().ops:
...@@ -254,7 +253,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -254,7 +253,6 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 1 rank_id = 1
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
for key in list(PROCESS_GROUP_MAP.keys()): for key in list(PROCESS_GROUP_MAP.keys()):
del PROCESS_GROUP_MAP[key] del PROCESS_GROUP_MAP[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
...@@ -277,7 +275,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -277,7 +275,6 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 0 rank_id = 0
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# send and recv should not exist in dp scene. # send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
......
...@@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto ...@@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
paddle.enable_static() paddle.enable_static()
...@@ -158,7 +157,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -158,7 +157,6 @@ class TestMLPReshard(unittest.TestCase):
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
print(dist_main_prog) print(dist_main_prog)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
print(dist_main_prog) print(dist_main_prog)
print(dist_startup_prog) print(dist_startup_prog)
......
...@@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto ...@@ -25,7 +25,6 @@ import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
paddle.enable_static() paddle.enable_static()
...@@ -187,7 +186,6 @@ class TestMLPReshard(unittest.TestCase): ...@@ -187,7 +186,6 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2 rank_id = 2
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
complete_backward_annotation(dist_main_prog, dist_context)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# check send and recv result # check send and recv result
......