未验证 提交 51ba2a0f 编写于 作者: C caozhou 提交者: GitHub

add op cost interface (#56803)

上级 d53972fd
......@@ -429,6 +429,7 @@ class Cluster:
# This property only be valid when the cluster consists of machines,
# which have the same number accelerators.
self._num_devices_per_machine = None
self._gpu_model = None
def gen_default_config_cluster(
self,
......@@ -451,6 +452,7 @@ class Cluster:
dcu_models = ["DCU"]
all_gpu_models = gpu_models + xpu_models + dcu_models
self._num_devices_per_machine = device_count
self._gpu_model = gpu_model
def _convert_to_type(gpu_model):
type = None
......
......@@ -22,6 +22,7 @@ from .base_cost import build_comp_desc_from_dist_op
from .base_cost import build_comm_desc_from_dist_op
from .base_cost import build_comm_costs_from_descs
from .base_cost import build_comp_costs_from_descs
from .base_cost import calc_time_by_cost_model
from .comp_op_cost import EmbeddingOpCost
from .comp_op_cost import EmbeddingGradOpCost
......
......@@ -19,7 +19,7 @@ import numpy as np
import paddle
from paddle.utils.flops import flops
from ..cluster import LinkType
from ..cluster import LinkType, get_default_cluster
from ..dist_tensor import DistributedTensor
from ..process_group import get_process_group
from ..utils import _get_comm_group, _get_idx_in_axis
......@@ -785,9 +785,12 @@ class CommOpCost(OpCost):
if self.op is not None:
vars = self.op.block.vars
# NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overrided
var_name = self.op.input("X")[0]
try:
var_name = self.op.input("X")[0]
except:
var_name = self.op.output("Out")[0]
var = get_var_with_recursion(
var_name, self.op.block, self.program
var_name, self.op.block, self.op.block.program
)
dtype = var.dtype
shape = var.shape
......@@ -838,7 +841,7 @@ class CommOpCost(OpCost):
if self.op_desc is not None:
self._group_ranks = self.op_desc["group_ranks"]
elif self.op is not None:
ring_id = self.op.attrs("ring_id")
ring_id = self.op.attr("ring_id")
process_group = get_process_group(ring_id)
if process_group is None:
raise ValueError(
......@@ -921,3 +924,57 @@ def calc_time_by_modeling(op=None, desc=None, cluster=None):
)
time = op_cost.calc_time()
return time
def calc_time_by_cost_model(op, cluster=None):
"""Calc op time by cost model and the unit is microsecond."""
if not isinstance(op, paddle.fluid.framework.Operator):
raise TypeError(
"OP must be paddle.fluid.framework.Operator, but got {}.".format(
type(op)
)
)
if not cluster:
cluster = get_default_cluster()
time = 0.0
op_type = op.type
# calc comp op time by flops
if op_type not in NON_COMP_TYPE:
attrs = op.all_attrs()
# build comp op inputs desc to calc flops.
# for example, a matmul op inputs desc will be {"X": [(1024, 1024)], "Y": [(1024, 1024)]}
inputs = {}
for input_name in op.input_names:
var_names = op.input(input_name)
inputs[input_name] = []
for var_name in var_names:
var = op.block._var_recursive(var_name)
inputs[input_name].append(var.shape)
# the time of grad operator is twice than its forward operator empirically
if "_grad" in op_type:
op_type = op_type[: len(op_type) - 5]
flops_count = 2 * flops(op_type, inputs, attrs)
else:
flops_count = flops(op_type, inputs, attrs)
if cluster._gpu_model == "V100":
time = flops_count * 2.9e-7 * 2.6
elif cluster._gpu_model == "A100":
time = flops_count * 2.9e-7
else:
raise ValueError(
"Only A100 and V100 gpu has been supported currently."
)
# calc comm op time by communication modeling formula
elif op_type in COMM_OP_TYPE:
op_cost = _g_op_cost_factory[op_type](
op=op, comm_context=CommContext(cluster)
)
time = op_cost.calc_time()
else:
raise ValueError(f"The {op_type} has not been supported now.")
return time
......@@ -162,6 +162,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner)
py_test_modules(test_dist_tensor MODULES test_dist_tensor)
py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
py_test_modules(test_cost_interface MODULES test_cost_interface)
# End of unittests WITH single card WITHOUT timeout
endif()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.static.cluster import Cluster
from paddle.distributed.auto_parallel.static.completion import Completer
from paddle.distributed.auto_parallel.static.cost import calc_time_by_cost_model
from paddle.distributed.auto_parallel.static.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.static.parallelizer import (
AutoParallelizer,
)
from paddle.distributed.auto_parallel.static.partitioner import Partitioner
from paddle.distributed.auto_parallel.static.reshard import Resharder
from paddle.distributed.fleet import auto
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh(
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], dim_names=["x", "y", "z"]
)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02,
):
super().__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
param = paddle.create_parameter([1024, 4096], paddle.float32)
auto.shard_tensor(param, PP_MESH_1, [None, "y"])
out = paddle.matmul(out, param)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(
train_program, start_program
), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32'
)
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32'
)
auto.shard_tensor(input, PP_MESH_0, ["x", None])
auto.shard_tensor(label, PP_MESH_1, ["x", None])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02,
)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.process_mesh = _global_process_mesh
loss, train_program, startup_program = mlp_forward(
train_program, startup_program
)
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.optimizer.Adam()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program
)
dist_context.block_state.parse_forward_blocks(complete_train_program)
params_grads = parallelizer._generate_backward(
complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None,
)
# logical partition
partitioner = Partitioner(dist_context, rank_id)
(
auto_parallel_main_prog,
auto_parallel_startup_prog,
dist_params_grads,
) = partitioner.partition(
complete_train_program, startup_program, params_grads
)
partitioned_optimize_ops = parallelizer._apply_optimize(
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads
)
return (
auto_parallel_main_prog,
auto_parallel_startup_prog,
dist_params_grads,
)
class TestCostInterface(unittest.TestCase):
def test_cost_interface(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 2
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id
)
resharder = Resharder(
dist_main_prog,
dist_startup_prog,
rank_id,
dist_context,
dist_params_grads,
)
resharder.reshard()
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
for op in dist_main_prog.global_block().ops:
time = calc_time_by_cost_model(op, cluster)
assert time > -1
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册