Unverified commit d0f4465d, authored by caozhou, committed by GitHub

[Auto Parallel] Add dist op cost (#44146)

* update comp cost

* add dist default op cost

* add dist fill constant batch size like op cost

* add elewise op cost

* add fill_constant_batch_size_like op cost unittest

* add unittest and remove fill_constant_batch_size_like grad op cost

* add to cmakelist

* fix unittest bug
Parent a54c6953
@@ -16,10 +16,15 @@ from .base_cost import _g_op_cost_factory
 from .base_cost import Cost
 from .base_cost import CommContext
 from .base_cost import build_comm_desc
+from .base_cost import build_comp_desc_from_op
+from .base_cost import build_comp_desc_from_dist_op
+from .base_cost import build_dp_costs
+from .base_cost import build_comp_costs_from_descs
 from .tensor_cost import TensorCost
 from .estimate_cost import CostEstimator
 from .comp_op_cost import MatmulV2OpCost
+from .comp_op_cost import FillConstantBatchSizeLikeOpCost
 from .comm_op_cost import SendOpCost
 from .comm_op_cost import RecvOpCost
......
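These re-exports from the cost package are what the distributed-operator implementations further down rely on. As a rough orientation, a sketch of the lookup they enable, assuming the patched `paddle.distributed.auto_parallel.cost` package is importable (illustrative only, not part of the diff):

```python
from paddle.distributed.auto_parallel.cost import _g_op_cost_factory
from paddle.distributed.auto_parallel.cost import MatmulV2OpCost

# _g_op_cost_factory maps a registered op-type string to its CompOpCost
# subclass, so an implementation can cost a serial op by its type name.
cost_cls = _g_op_cost_factory["matmul_v2"]
print(cost_cls is MatmulV2OpCost)  # expected: True
```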
@@ -357,24 +357,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
         return 0
 
-
-@register_op_cost
-class FillConstantBatchSizeLikeGradOpCost(CompOpCost):
-    OP_TYPE = "fill_constant_batch_size_like_grad"
-
-    def __init__(self, op=None, op_desc=None, cluster=None):
-        super(FillConstantBatchSizeLikeGradOpCost,
-              self).__init__(op=op, op_desc=op_desc, cluster=cluster)
-
-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
 @register_op_cost
 class GatherOpCost(CompOpCost):
     OP_TYPE = "gather"
......
@@ -31,6 +31,9 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
+from ..cost import _g_op_cost_factory
+from ..cost import build_comp_desc_from_dist_op, build_dp_costs
+from ..cost import build_comp_costs_from_descs
 
 __op_not_need_param_init__ = ["while", "cond"]
@@ -99,6 +102,74 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         self._forward_implemented = True
         self._backward_implemented = True
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        """Calculate the cost by the op role."""
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        processes = dist_op.dist_attr.process_mesh.processes
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res_cost = [cost_mapping]
+        return res_cost
+
+    def calc_bwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        res = []
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        dist_attr = dist_op.dist_attr
+        process_mesh = dist_attr.process_mesh
+        processes = process_mesh.processes
+        backward_op = dist_op.serial_op
+        op_type = backward_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res.append(cost_mapping)
+
+        main_block = backward_op.block
+        vars = main_block.vars
+        need_gradient_allreduce = False
+        for input_name in backward_op.desc.input_names():
+            for varname in backward_op.desc.input(input_name):
+                if "@GRAD" not in varname and not is_parameter_related(
+                        varname, main_block):
+                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
+                    mesh_shape = process_mesh.topology
+                    batch_size_axis = var_dim_mapping[0]
+                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+                        need_gradient_allreduce = True
+                        break
+
+        if need_gradient_allreduce:
+            for input_name in backward_op.desc.input_names():
+                for varname in backward_op.desc.input(input_name):
+                    if "@GRAD" not in varname and is_parameter_related(
+                            varname, main_block):
+                        var_dim_mapping = dist_attr.get_input_dims_mapping(
+                            varname)
+                        mesh_shape = process_mesh.topology
+                        batch_size_axis = var_dim_mapping[0]
+                        parallel_axis = batch_size_axis
+                        attrs = {"use_calc_stream": True}
+                        var_names = [varname + "@GRAD"]
+                        build_dp_costs(res, dist_op, ctx, var_names, attrs,
+                                       parallel_axis, cluster)
+        return res
+
     def is_input_compatible(self, dist_op):
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
......
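As a rough usage sketch of the new hook on DistributedDefaultImpl0 (assuming a `dist_op`, a `DistributedContext` `ctx`, and a `Cluster` have been built as in the new unit test at the end of this commit; the wrapper name below is hypothetical and only illustrates the role-based dispatch):

```python
def estimate_dist_op_cost(dist_impl, serial_op, dist_op, ctx, cluster):
    """Hypothetical helper: drive the role-based costing entry point above.

    calc_cost returns a list: a per-process comp-cost mapping, plus, for
    backward ops that trigger a data-parallel gradient sync, the extra
    communication costs appended by build_dp_costs.
    """
    op_role = serial_op.attr('op_role')  # forward/backward is stored as an op attribute
    return dist_impl.calc_cost(op_role, dist_op, ctx, cluster)
```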
@@ -15,7 +15,7 @@
 from .common import DistributedOperatorImplContainer
 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator_impl_container
-from .common import register_distributed_operator_impl
+from .common import register_distributed_operator_impl, is_parameter_related
 from .common import is_elementwise_op
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
@@ -32,6 +32,9 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank
 from .dist_default import DistributedDefaultImpl0
+from ..cost import _g_op_cost_factory
+from ..cost import build_comp_desc_from_dist_op, build_dp_costs
+from ..cost import build_comp_costs_from_descs
 
 
 class DistributedElementwise(DistributedOperatorImplContainer):
@@ -52,6 +55,74 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
         self._forward_implemented = False
         self._backward_implemented = False
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        """Calculate the cost by the op role."""
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            cost = self.calc_bwd_cost(dist_op, ctx, cluster)
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        processes = dist_op.dist_attr.process_mesh.processes
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res_cost = [cost_mapping]
+        return res_cost
+
+    def calc_bwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        res = []
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        dist_attr = dist_op.dist_attr
+        process_mesh = dist_attr.process_mesh
+        processes = process_mesh.processes
+        backward_op = dist_op.serial_op
+        op_type = backward_op.type
+        cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type],
+                                                   ctx, processes, desc_mapping,
+                                                   cluster)
+        res.append(cost_mapping)
+
+        main_block = backward_op.block
+        vars = main_block.vars
+        need_gradient_allreduce = False
+        for input_name in backward_op.desc.input_names():
+            for varname in backward_op.desc.input(input_name):
+                if "@GRAD" not in varname and not is_parameter_related(
+                        varname, main_block):
+                    var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
+                    mesh_shape = process_mesh.topology
+                    batch_size_axis = var_dim_mapping[0]
+                    if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
+                        need_gradient_allreduce = True
+                        break
+
+        if need_gradient_allreduce:
+            for input_name in backward_op.desc.input_names():
+                for varname in backward_op.desc.input(input_name):
+                    if "@GRAD" not in varname and is_parameter_related(
+                            varname, main_block):
+                        var_dim_mapping = dist_attr.get_input_dims_mapping(
+                            varname)
+                        mesh_shape = process_mesh.topology
+                        batch_size_axis = var_dim_mapping[0]
+                        parallel_axis = batch_size_axis
+                        attrs = {"use_calc_stream": True}
+                        var_names = [varname + "@GRAD"]
+                        build_dp_costs(res, dist_op, ctx, var_names, attrs,
+                                       parallel_axis, cluster)
+        return res
+
     def is_input_compatible(self, dist_op):
         op_desc = dist_op.serial_op.desc
         if not is_elementwise_op(op_desc.type()):
......
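The backward branch in both DistributedDefaultImpl0 and DistributedElementwiseImpl0 applies the same data-parallel rule: if any non-gradient, non-parameter input is sharded along its batch axis on a mesh dimension holding more than one process, every parameter gradient incurs an allreduce, whose cost `build_dp_costs` adds. A condensed sketch of that check (the helper name is made up; `is_parameter_related` is the utility imported from `.common` in the diff above):

```python
from paddle.distributed.auto_parallel.operators.common import is_parameter_related


def needs_grad_allreduce(dist_attr, backward_op, main_block):
    """Sketch of the detection logic used by calc_bwd_cost above."""
    mesh_shape = dist_attr.process_mesh.topology
    for input_name in backward_op.desc.input_names():
        for varname in backward_op.desc.input(input_name):
            if "@GRAD" in varname or is_parameter_related(varname, main_block):
                continue
            # dims_mapping[0] is treated as the batch axis of the input
            batch_size_axis = dist_attr.get_input_dims_mapping(varname)[0]
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                return True
    return False
```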
@@ -27,7 +27,12 @@ from paddle.fluid import core, unique_name
 from paddle.fluid.framework import _non_static_mode
 from paddle.fluid.framework import Program, Parameter, Variable, program_guard
 from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from .dist_default import DistributedDefaultImpl0
+from ..cost import FillConstantBatchSizeLikeOpCost
+from ..cost import build_comp_desc_from_dist_op, build_dp_costs
+from ..cost import build_comp_costs_from_descs
+from ..cost import AllreduceSumOpCost
 
 
 class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
@@ -47,6 +52,29 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
         self._forward_implemented = True
         self._backward_implemented = True
 
+    def calc_cost(self, op_role, dist_op, ctx, cluster):
+        cost = None
+        if int(op_role) == int(OpRole.Backward):
+            raise ValueError(
+                "The fill_constant_batch_size_like has no grad op.")
+        else:
+            cost = self.calc_fwd_cost(dist_op, ctx, cluster)
+        assert cost is not None
+        return cost
+
+    def calc_fwd_cost(self, dist_op, ctx, cluster):
+        # calc comp op cost
+        desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op,
+                                                    dist_context=ctx)
+        processes = dist_op.dist_attr.process_mesh.processes
+        op_type = dist_op.serial_op.type
+        cost_mapping = build_comp_costs_from_descs(
+            FillConstantBatchSizeLikeOpCost, ctx, processes, desc_mapping,
+            cluster)
+        res_cost = [cost_mapping]
+        return res_cost
+
     def is_input_compatible(self, dist_op):
         return True
......
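Unlike the default and elementwise implementations, this one costs the forward op with FillConstantBatchSizeLikeOpCost directly instead of going through `_g_op_cost_factory`, and it rejects the backward role because `fill_constant_batch_size_like` has no grad op (its grad cost class is removed elsewhere in this commit). A quick standalone check of the cost class, mirroring the existing comp-op-cost unit test (a sketch; the flops/time formulas are still placeholders, so only sanity bounds are checked):

```python
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import FillConstantBatchSizeLikeOpCost

cluster = Cluster()
cluster.gen_default_config_cluster(device_count=2)

op_cost = FillConstantBatchSizeLikeOpCost(cluster=cluster)
assert op_cost.flops >= 0 and op_cost.time >= 0 and op_cost.memory >= 0
```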
@@ -55,4 +55,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
   py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
   py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
+  py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
 endif()
@@ -35,7 +35,6 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import EmbeddingOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import EmbeddingGradOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantBatchSizeLikeOpCost
-from paddle.distributed.auto_parallel.cost.comp_op_cost import FillConstantBatchSizeLikeGradOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import GatherOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import GeluOpCost
 from paddle.distributed.auto_parallel.cost.comp_op_cost import GeluGradOpCost
@@ -184,11 +183,6 @@ class TestCompOpCost(unittest.TestCase):
         self.assertTrue(op_cost.time >= 0)
         self.assertTrue(op_cost.memory >= 0)
 
-        op_cost = FillConstantBatchSizeLikeGradOpCost(cluster=cluster)
-        self.assertTrue(op_cost.flops >= 0)
-        self.assertTrue(op_cost.time >= 0)
-        self.assertTrue(op_cost.memory >= 0)
-
         op_cost = GatherOpCost(cluster=cluster)
         self.assertTrue(op_cost.flops >= 0)
         self.assertTrue(op_cost.time >= 0)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import copy
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container, is_elementwise_op
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
paddle.enable_static()
def parallelizer(program_func, rank):
    from paddle.distributed.auto_parallel.completion import Completer
    from paddle.distributed.auto_parallel.partitioner import Partitioner
    from paddle.distributed.auto_parallel.dist_context import DistributedContext

    main_program, startup_program, loss = program_func()

    # complete forward
    dist_context = DistributedContext()
    completer = Completer(dist_context)
    completer.complete_forward_annotation(main_program)
    dist_context.block_state.parse_forward_blocks(main_program)

    # generate backward and complete backward
    with paddle.static.program_guard(main_program, startup_program):
        params_grads = append_backward(
            loss, None, None, None, distop_context=dist_context.dist_op_context)
    completer.complete_backward_annotation(main_program)
    dist_context.block_state.parse_backward_blocks(main_program)

    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
    # generate opt and complete opt
    with program_guard(main_program, startup_program):
        optimize_ops = copy.deepcopy(optimizer).apply_gradients(params_grads)
    completer.complete_update_annotation(main_program)

    return main_program, dist_context
class TestDistOpCost(unittest.TestCase):

    def test_dist_fill_constant_batch_size_like_op_cost(self):

        def make_program():
            main_program = paddle.static.Program()
            start_program = paddle.static.Program()
            with paddle.static.program_guard(main_program, start_program):
                x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
                x.stop_gradient = True
                label = paddle.static.data(name="label",
                                           shape=[4, 1],
                                           dtype='float32')
                label.stop_gradient = True
                auto.shard_tensor(x,
                                  dist_attr={
                                      "process_mesh": auto.ProcessMesh([0, 1]),
                                      "dims_mapping": [0, -1]
                                  })
                tmp = paddle.fluid.layers.fill_constant_batch_size_like(
                    input=x, shape=[2, 8], value=1, dtype='float32')
                weight_attr = paddle.ParamAttr()
                linear = paddle.nn.Linear(8, 8, weight_attr=weight_attr)
                linear_out = linear(x)
                gelu_out = paddle.nn.functional.gelu(linear_out)
                # default op with dp
                tmp = paddle.static.nn.layer_norm(gelu_out)
                error_cost = paddle.nn.functional.square_error_cost(tmp, label)
                loss = paddle.mean(error_cost)
            return main_program, start_program, loss

        main_program, dist_context = parallelizer(make_program, 0)
        ops = main_program.global_block().ops
        cluster = Cluster()
        cluster.gen_default_config_cluster(device_count=2)
        for idx, op in enumerate(ops):
            if op.type != "matmul_v2" and op.type != "matmul_v2_grad" and op.type != "sgd":
                dist_op = dist_context.get_dist_op_for_program(op)
                op_dist_attr = dist_op.dist_attr
                processes = op_dist_attr.process_mesh.processes
                if is_elementwise_op(op.type):
                    container = get_distributed_operator_impl_container(
                        "elementwise")
                else:
                    container = get_distributed_operator_impl_container(
                        op_dist_attr.impl_type)
                dist_impl = container.impls[op_dist_attr.impl_idx]
                dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
                                                   dist_context, cluster)
                self.assertTrue(dist_op_cost)


if __name__ == "__main__":
    unittest.main()
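Since the CMakeLists change above only registers `test_dist_op_cost` for WITH_DISTRIBUTE AND WITH_GPU builds, one way to exercise the new test outside CTest might be to load it directly with unittest (assuming such a build and that the auto_parallel unittest directory is on `sys.path`; illustrative only):

```python
import unittest

# Load only the new test case and run it with verbose output.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_dist_op_cost.TestDistOpCost")
unittest.TextTestRunner(verbosity=2).run(suite)
```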