Unverified commit c1c9368f, authored by caozhou, committed by GitHub

[Auto Parallel] Update cost model (#40457)

* refactor cost model
Parent 1b491818
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import OP_COST_FACTORY
from .base_cost import Cost
from .comm_op_cost import AllreduceSumCost
from .comp_op_cost import MatmulV2OpCost
from .tensor_cost import TensorCost
from .estimate_cost import CostEstimator
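# Illustrative usage sketch (not part of this commit): with the package exposed
# in setup.py (see the setup.py hunk below), the cost model can be used as
#
#   import paddle.distributed.auto_parallel.cost as cost_model
#   cost = cost_model.Cost(time=0.5, memory=100, flops=200)
#   matmul_cost_cls = cost_model.OP_COST_FACTORY["matmul_v2"]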
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from collections import OrderedDict
import paddle
COMM_OP_TYPE = [
"send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum"
]
NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
OP_COST_FACTORY = {}
def _parse_op_to_desc(op, dist_context=None):
desc = {}
desc["op"] = op.type
vars = op.block.vars
input_desc = OrderedDict()
for input_name in op.input_names:
var_name_list = op.input(input_name)
var_desc = []
for var_name in var_name_list:
var = vars[var_name]
shape = None
if dist_context is not None:
dist_tensor = dist_context.get_dist_tensor_for_program(var)
shape = dist_tensor.local_sizes()
else:
shape = var.shape
assert shape is not None
var_desc.append((var.dtype, shape))
input_desc[input_name] = var_desc
desc["inputs"] = input_desc
output_desc = OrderedDict()
for out_name in op.output_names:
var_name_list = op.output(out_name)
var_desc = []
for var_name in var_name_list:
var = vars[var_name]
shape = None
if dist_context is not None:
dist_tensor = dist_context.get_dist_tensor_for_program(var)
shape = dist_tensor.local_sizes()
else:
shape = var.shape
assert shape is not None
var_desc.append((var.dtype, shape))
output_desc[out_name] = var_desc
desc["outputs"] = output_desc
    attr_desc = op.all_attrs()
desc["attrs"] = attr_desc
return desc
def parse_to_desc(op=None, dist_op=None, dist_context=None):
desc = None
if op is None and dist_op is not None and dist_context is not None:
desc = _parse_op_to_desc(
op=dist_op.serial_op, dist_context=dist_context)
elif op is not None and dist_op is None and dist_context is None:
desc = _parse_op_to_desc(op)
return desc
def parse_desc_to_str(desc):
def _parse_dtype(dtype):
dtype_str = ""
if dtype == paddle.float32:
dtype_str = "float32"
elif dtype == paddle.float16:
dtype_str = "float16"
elif dtype == paddle.int32:
dtype_str = "int32"
elif dtype == paddle.int64:
dtype_str = "int64"
        elif dtype == paddle.uint8:
            dtype_str = "uint8"
else:
raise TypeError("Unsupported dtype {}".format(dtype))
return dtype_str
assert isinstance(desc, dict)
desc_str_list = []
desc_str = None
dtype_str_list = []
dims_list = []
shape_list = []
desc_str_list.append(desc["op"])
inputs = desc["inputs"]
for key, item in inputs.items():
for dtype, shape in item:
dtype_str_list.append(_parse_dtype(dtype))
shape_list += list(shape)
dims = len(shape)
dims_list.append(dims)
dtype_str = "*".join(dtype_str_list)
dims_list = [str(item) for item in dims_list]
dims_str = "*".join(dims_list)
shape_list = [str(item) for item in shape_list]
shape_str = "[" + ",".join(shape_list) + "]"
desc_str_list += [dtype_str, dims_str, shape_str]
desc_str = "_".join(desc_str_list)
return desc_str
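# Illustrative sketch (not part of this commit): parse_desc_to_str builds the
# key "<op>_<input dtypes>_<input ranks>_[<flattened input shapes>]". For a
# hand-written matmul_v2 desc with two float32 inputs of shapes [4, 8] and
# [8, 16]:
#
#   desc = {
#       "op": "matmul_v2",
#       "inputs": OrderedDict([("X", [(paddle.float32, [4, 8])]),
#                              ("Y", [(paddle.float32, [8, 16])])]),
#   }
#   parse_desc_to_str(desc)  # -> "matmul_v2_float32*float32_2*2_[4,8,8,16]"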
class CommContext:
_instance = None
_has_instance = False
def __init__(self, cluster):
if CommContext._has_instance:
return
self.cluster = cluster
self._alpha_base_ring = 8.4
self._alpha_base_tree = 0
self._alpha_inter = None
        self._alpha_intra = None
self._beta = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._has_instance = True
return cls._instance
@property
def alpha_inter(self):
if self._alpha_inter is None:
            if self.cluster.alpha.inter == "NVL":
                self._alpha_inter = 3.4
            elif self.cluster.alpha.inter == "PHB":
                self._alpha_inter = 5.7
return self._alpha_inter
@property
def alpha_intra(self):
if self._alpha_intra is None:
            if self.cluster.alpha.intra == "NVL":
                self._alpha_intra = 28
            elif self.cluster.alpha.intra == "PHB":
                self._alpha_intra = 28
return self._alpha_intra
@property
def alpha_base_ring(self):
return self._alpha_base_ring
@property
def alpha_base_tree(self):
return self._alpha_base_tree
    def get_beta(self, ranks):
        # Cache and return the largest point-to-point beta among the ranks.
        key = ','.join(map(str, sorted(ranks)))
        max_beta = None
        if key in self._beta:
            max_beta = self._beta[key]
        else:
            for i in range(len(ranks)):
                for j in range(i + 1, len(ranks)):
                    beta = self.cluster.get_beta(ranks[i], ranks[j])
                    if max_beta is None or beta > max_beta:
                        max_beta = beta
            self._beta[key] = max_beta
        return max_beta
class Cost:
def __init__(self, time=0, memory=0, flops=0):
self.time = time
self.memory = memory
self.flops = flops
def _check_time(self, val):
assert val >= 0, "Time must be greater than or equal to 0."
def _check_memory(self, val):
assert isinstance(
val, int) and val >= 0, "Memory must be int and greater than 0."
def _check_flops(self, val):
assert isinstance(
val, int) and val >= 0, "FLOPs must be int and greater than 0."
@property
def time(self):
return self._time
@time.setter
def time(self, val):
self._check_time(val)
self._time = val
@property
def memory(self):
return self._memory
@memory.setter
def memory(self, val):
self._check_memory(val)
self._memory = val
@property
def flops(self):
return self._flops
@flops.setter
def flops(self, val):
self._check_flops(val)
self._flops = val
def __add__(self, rhs):
assert isinstance(rhs, Cost)
time = self.time + rhs.time
memory = self.memory + rhs.memory
flops = self.flops + rhs.flops
assert (time >= 0 and memory >= 0 and flops >= 0)
return Cost(time, memory, flops)
def __sub__(self, rhs):
assert isinstance(rhs, Cost)
time = self.time - rhs.time
memory = self.memory - rhs.memory
flops = self.flops - rhs.flops
assert (time >= 0 and memory >= 0 and flops >= 0)
return Cost(time, memory, flops)
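# Illustrative usage sketch (not part of this commit): Cost objects compose via
# "+" and "-", which is how per-op and per-tensor costs are meant to be
# aggregated, e.g.
#
#   total = Cost(time=1.5, memory=400, flops=800) + Cost(time=0.5, memory=100)
#   # total.time == 2.0, total.memory == 500, total.flops == 800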
class OpCost:
def __init__(self, op=None, op_desc=None):
assert (op is not None and op_desc is None) or (op is None and
op_desc is not None)
self._op = op
self._op_desc = op_desc
self._cost = self.calc_cost()
@property
def op(self):
return self._op
@property
def op_desc(self):
return self._op_desc
@property
def cost(self):
return self._cost
def calc_time(self):
return 0
def calc_memory(self):
return 0
def calc_flops(self):
return 0
def calc_cost(self):
time = self.calc_time()
memory = self.calc_memory()
flops = self.calc_flops()
cost = Cost(time, memory, flops)
return cost
class CommOpCost(OpCost):
OP_TYPE = "COMM"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(CommOpCost, self).__init__(op=op, op_desc=op_desc)
self._check_comm_op_type()
self._comm_context = comm_context
@property
def comm_context(self):
return self._comm_context
@classmethod
def _check_comm_op_type(cls):
if cls.OP_TYPE != "COMM":
if cls.OP_TYPE not in COMM_OP_TYPE:
raise TypeError("Please Check op type in {}, but got {}.".
format(COMM_OP_TYPE, cls.OP_TYPE))
class CompOpCost(OpCost):
OP_TYPE = "COMP"
def __init__(self, op=None, op_desc=None, cluster=None):
super(CompOpCost, self).__init__(op=op, op_desc=op_desc)
self._check_comp_op_type()
self.cluster = cluster
@classmethod
def _check_comp_op_type(cls):
if cls.OP_TYPE != "COMP":
if cls.OP_TYPE in NON_COMP_TYPE:
raise TypeError("Please Check op type not in {}, but got {}.".
format(NON_COMP_TYPE, cls.OP_TYPE))
def register_op_cost(cls):
    op_type = cls.OP_TYPE
    def register(op_type):
        OP_COST_FACTORY[op_type] = cls
    register(op_type)
    return cls
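# Illustrative sketch (not part of this commit): a new cost class only needs to
# set OP_TYPE and apply the decorator to become reachable through
# OP_COST_FACTORY. The AllgatherOpCost class here is hypothetical, not part of
# this change:
#
#   @register_op_cost
#   class AllgatherOpCost(CommOpCost):
#       OP_TYPE = "c_allgather"
#
#       def calc_time(self):
#           return 0
#
#   OP_COST_FACTORY["c_allgather"]  # -> AllgatherOpCost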
def calc_time_from_model(op=None, desc=None, cluster=None, comm_context=None):
op_type = op.type if op is not None else desc["op"]
if op_type in COMM_OP_TYPE:
op_cost = OP_COST_FACTORY[op_type](op=op,
op_desc=desc,
comm_context=comm_context)
elif op_type not in NON_COMP_TYPE:
op_cost = OP_COST_FACTORY[op_type](op=op, op_desc=desc, cluster=cluster)
time = op_cost.calc_time()
return time
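# Illustrative usage sketch (not part of this commit): calc_time_from_model
# dispatches on the op type. Once the cost package (and thus the registered
# cost classes) has been imported, a communication desc is routed to the COMM
# factory class, e.g.
#
#   desc = {"op": "c_allreduce_sum",
#           "inputs": {"X": [(paddle.float32, [100, 200])]}}
#   calc_time_from_model(desc=desc)  # -> 0 until the real formula is filled in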
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import register_op_cost, CommOpCost, OP_COST_FACTORY
@register_op_cost
class AllreduceSumCost(CommOpCost):
OP_TYPE = "c_allreduce_sum"
def __init__(self, op=None, op_desc=None, comm_context=None):
super(OP_COST_FACTORY["c_allreduce_sum"], self).__init__(
op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
# NOTE: The actual formula will be filled in the future.
return 0
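# Illustrative sketch (not part of this commit): one common candidate for the
# formula above is the alpha-beta ring-allreduce model, roughly
#
#   time = alpha + 2 * (n - 1) / n * message_bytes * beta
#
# where n is the number of ranks, alpha is the latency term (cf.
# CommContext.alpha_intra / alpha_inter) and beta is the per-byte transfer term
# (cf. CommContext.get_beta). This is only an assumption about how the
# placeholder might be filled, not what this commit implements.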
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .base_cost import Cost, register_op_cost, CompOpCost, OP_COST_FACTORY
@register_op_cost
class MatmulV2OpCost(CompOpCost):
OP_TYPE = "matmul_v2"
def __init__(self, op=None, op_desc=None, cluster=None):
super(OP_COST_FACTORY["matmul_v2"], self).__init__(
op=op, op_desc=op_desc, cluster=cluster)
    # For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden.
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
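# Illustrative sketch (not part of this commit): for inputs of shape [M, K] and
# [K, N], a common way to fill in calc_flops above is
#
#   flops = 2 * M * K * N
#
# (one multiply and one add per output element). This is an assumption about a
# future implementation, not part of this change.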
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
class CostEstimator:
def __init__(self,
program,
cluster=None,
dist_context=None,
mode="modeling"):
self._program = program
self._cluster = cluster
self._dist_context = dist_context
self._check_mode(mode)
self._mode = mode
self._global_cost = None
self._local_cost = {}
@property
def program(self):
return self._program
@property
def dist_context(self):
return self._dist_context
@property
def cluster(self):
return self._cluster
@property
def mode(self):
return self._mode
@property
def global_cost(self):
return self._global_cost
@property
def local_cost(self):
return self._local_cost
def get_op_cost(self):
return 0
def get_tensor_cost(self):
return 0
def get_global_cost(self):
return 0
def get_local_cost(self, rank=None):
return 0
def _check_mode(self, mode):
if mode not in ["modeling", "profiling"]:
raise ValueError(
"Just support modeling and profiling, but got {}".format(mode))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from functools import reduce
import paddle
from paddle.fluid.framework import Variable
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from .base_cost import Cost
class TensorCost:
def __init__(self, tensor=None, dist_tensor=None, shape=None, dtype=None):
self._check_args(tensor, dist_tensor, shape, dtype)
self._tensor = tensor
self._dist_tensor = dist_tensor
self._shape = shape
self._dtype = dtype
self._cost = self.calc_cost()
@property
def tensor(self):
return self._tensor
@property
def dist_tensor(self):
return self._dist_tensor
@property
def shape(self):
return self._shape
@property
def dtype(self):
return self._dtype
def _check_args(self, tensor, dist_tensor, shape, dtype):
if tensor is not None:
assert (shape is None and dist_tensor is None and dtype is None)
if not isinstance(tensor, Variable):
raise TypeError(
"Please check tensor type is Variable, but got {}".format(
type(tensor)))
elif dist_tensor is not None:
assert (tensor is None and shape is None)
if not isinstance(dist_tensor, DistributedTensor):
raise TypeError(
"Please check dist_tensor type is DistributedTensor, but got {}".
format(type(dist_tensor)))
elif shape is not None:
assert (tensor is None and dist_tensor is None and
dtype is not None)
if not isinstance(shape, (list, set)):
raise TypeError(
"Please check shape type is list or set, but got {}".format(
type(shape)))
elif dtype is not None:
assert (tensor is None and dist_tensor is None and
shape is not None)
@property
def cost(self):
return self._cost
def calc_cost(self):
dtype = None
shape = None
if self.dist_tensor:
shape = self.dist_tensor.local_sizes()
dtype = self.dist_tensor.serial_tensor.dtype
elif self.tensor:
shape = self.tensor.shape
dtype = self.tensor.dtype
elif self.shape and self.dtype:
shape = self.shape
dtype = self.dtype
total_count = reduce(lambda x, y: x * y, shape)
if dtype == paddle.float32 or dtype == paddle.int32:
dtype_factor = 4
        elif dtype == paddle.int64:
dtype_factor = 8
        elif dtype == paddle.uint8:
dtype_factor = 1
else:
dtype_factor = 2
memory = total_count * dtype_factor
assert memory >= 0
cost = Cost(memory=memory)
return cost
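# Illustrative usage sketch (not part of this commit): a float32 tensor of
# shape [20, 20] occupies 20 * 20 * 4 = 1600 bytes, which is also what the
# unit test below checks via TensorCost(tensor=x):
#
#   tensor_cost = TensorCost(shape=[20, 20], dtype=paddle.float32)
#   tensor_cost.cost.memory  # -> 1600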
@@ -17,4 +17,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
     py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS})
     py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS})
     py_test_modules(test_trial MODULES test_trial ENVS ${dist_ENVS})
+    py_test_modules(test_new_cost_model MODULES test_new_cost_model ENVS ${dist_ENVS})
 endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.distributed.auto_parallel.cost as cost_model
from paddle.distributed.auto_parallel.cost.base_cost import parse_to_desc
from paddle.distributed.auto_parallel.cost.base_cost import parse_desc_to_str
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_from_model
paddle.enable_static()
def check_cost(cost):
if cost.memory >= 0 and cost.flops >= 0 and cost.time >= 0:
return True
return False
class TestCost(unittest.TestCase):
def test_base_cost(self):
cost = cost_model.Cost(memory=100, flops=200, time=0.5)
self.assertTrue(check_cost(cost))
def test_comp_cost(self):
x = paddle.static.data(name="x", shape=[20, 20], dtype='float32')
y = paddle.static.data(name="y", shape=[20, 20], dtype='float32')
z = paddle.matmul(x, y)
matmul_v2_op = None
ops = paddle.static.default_main_program().global_block().ops
for op in ops:
if op.type == "matmul_v2":
matmul_v2_op = op
break
matmul_v2_cost = cost_model.OP_COST_FACTORY["matmul_v2"](
op=matmul_v2_op)
desc = parse_to_desc(op=matmul_v2_op)
desc_str = parse_desc_to_str(desc)
self.assertIsNotNone(desc_str)
self.assertTrue(check_cost(matmul_v2_cost.cost))
time = calc_time_from_model(op=matmul_v2_op)
self.assertEqual(time, matmul_v2_cost.cost.time)
tensor_cost = cost_model.TensorCost(tensor=x)
# check memory
self.assertEqual(tensor_cost.cost.memory, 1600)
def test_comm_cost(self):
desc = {}
desc["op"] = "c_allreduce_sum"
desc["inputs"] = {"X": [([100, 200], paddle.float32)]}
allreduce_cost = cost_model.OP_COST_FACTORY["c_allreduce_sum"](
op_desc=desc)
self.assertTrue(check_cost(allreduce_cost.cost))
def test_cost_estimator(self):
train_program = paddle.static.Program()
cost_estimator = cost_model.CostEstimator(train_program)
self.assertIsNotNone(cost_estimator)
if __name__ == "__main__":
unittest.main()
@@ -307,6 +307,7 @@ packages=['paddle',
          'paddle.distributed.auto_parallel',
          'paddle.distributed.auto_parallel.operators',
          'paddle.distributed.auto_parallel.tuner',
+         'paddle.distributed.auto_parallel.cost',
          'paddle.distributed.passes',
          'paddle.framework',
          'paddle.jit',
...