Unverified commit f59bcb1c, authored by JZ-LIANG, committed by GitHub

[AutoParallel & Science] Miscellaneous improvements (#43139)

* adapt for 10 loss

* partitioner support optimizer
Parent ff1789ca
@@ -363,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
output_name)
# replicate op in dist program
- dist_op_desc = main_block.append_op(type='nop').desc
+ dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
@@ -371,6 +371,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
+ main_block._sync_with_cpp()
# data parallel synchronization for primitive operators
from paddle.incubate.autograd import prim_enabled
if prim_enabled():
@@ -426,6 +428,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr)
+ startup_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
......
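For readability, here is a minimal sketch of the op-replication pattern the hunks above converge on, written against Paddle's static-graph Block/OpDesc API; replicate_op is a hypothetical helper name used only for illustration, not something this commit introduces.

# Minimal sketch of the replication pattern above (assumes Paddle's static-graph API).
# Appending to the C++ block desc avoids creating a throwaway Python 'nop' operator,
# but the Python-side block must then be refreshed with _sync_with_cpp().
from paddle.distributed.auto_parallel.utils import set_dist_op_desc_original_id

def replicate_op(main_block, src_op, ctx, kwargs):  # hypothetical helper, for illustration
    dist_op_desc = main_block.desc.append_op()       # raw OpDesc, no Python wrapper yet
    dist_op_desc.copy_from(src_op.desc)              # copy type, attributes and var names
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        dist_op_desc.set_output(output_name, kwargs[output_name])
    main_block._sync_with_cpp()                      # rebuild Python operators from the desc
    return dist_op_desc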
@@ -107,13 +107,14 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
output_name)
# replicate op in dist program
- dist_op_desc = main_block.append_op(type='nop').desc
+ dist_op_desc = main_block.desc.append_op()
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
+ main_block._sync_with_cpp()
# batch dimension synchronization
var_name = src_op.output_arg_names[0]
......
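The reduce-primitive hunk is cut off before the synchronization itself. As an illustration only (not code from this commit), a typical data-parallel synchronization of a reduced value in Paddle appends a c_allreduce_sum over the process group that shares the batch dimension:

# Illustration only -- a common Paddle pattern, not taken from this commit.
# `group` is assumed to come from new_process_group(...) and `var` is the reduced output.
def allreduce_in_place(main_block, var, group):
    main_block.append_op(
        type='c_allreduce_sum',
        inputs={'X': [var]},
        outputs={'Out': [var]},
        attrs={'ring_id': group.id, 'use_calc_stream': True})
    main_block._sync_with_cpp()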
@@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
- from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op
+ from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
@@ -263,14 +263,14 @@ class Partitioner(object):
dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var})
- elif int(op.attr('op_role')) == 2:
+ elif is_optimize_op(op):
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_impl = get_distributed_operator_impl_container(
"default").get_impl(0)
dist_op_impl.backward(self._dist_context, **kinputs, **koutputs)
else:
raise NotImplementedError(
"partitioner only support forward op and backward op, but got {}".
"partitioner only support forward and backward, optimize ops, but got {}".
format(str(op)))
def _is_valid_annotated_program(self, program):
......
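The partitioner change above swaps the hard-coded check `int(op.attr('op_role')) == 2` for the `is_optimize_op` helper added in the next hunk. A standalone sketch of the resulting dispatch ladder, with placeholder handlers and the OpRole values written out as plain integers:

# Standalone sketch of the per-op dispatch after this change; handlers are placeholders.
FORWARD, BACKWARD, OPTIMIZE, LOSS = 0x0000, 0x0001, 0x0002, 0x0100   # Paddle OpRole values

def dispatch(op_role):
    if op_role == FORWARD or op_role == (FORWARD | LOSS):
        return "forward impl of the matched distributed operator"
    elif op_role & BACKWARD:
        return "backward impl, also given the grad_var_to_var mapping"
    elif op_role & OPTIMIZE:   # was: int(op.attr('op_role')) == 2
        return "default container, impl 0, backward()"
    raise NotImplementedError(
        "partitioner only supports forward, backward and optimize ops")

assert dispatch(0x0002).startswith("default")    # optimizer op
assert dispatch(0x0101).startswith("backward")   # loss-grad op carries Backward | Loss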
@@ -1099,6 +1099,11 @@ def is_backward_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
+ def is_optimize_op(op):
+ return OP_ROLE_KEY in op.attr_names and \
+ int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
......
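For reference, a small worked example of the role flags these helpers test, using the same core.op_proto_and_checker_maker.OpRole enum that is_loss_op already relies on; the numeric values are Paddle's OpRole constants, the snippet itself is illustrative and not part of the commit.

from paddle.fluid import core

OpRole = core.op_proto_and_checker_maker.OpRole
# Forward = 0x0000, Backward = 0x0001, Optimize = 0x0002, Loss = 0x0100.
# Optimizer ops can carry extra bits (e.g. LRSched), so is_optimize_op tests the bit
# with `&`; the loss op is exactly Forward | Loss, hence the equality test in is_loss_op.
assert int(OpRole.Optimize) == 2          # the value the old hard-coded check compared against
loss_role = int(OpRole.Forward) | int(OpRole.Loss)        # forward loss op
loss_grad_role = int(OpRole.Backward) | int(OpRole.Loss)  # its gradient op
assert loss_role == 0x0100 and loss_grad_role == 0x0101
assert loss_grad_role & int(OpRole.Backward)              # caught by is_backward_op, not is_loss_op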