Unverified commit e379455a, authored by C caozhou, committed by GitHub

【Auto Parallel】update base cost (#44095)

* update base cost

* update unittest of cost model

* add unittest
Parent 3333a439
......@@ -17,8 +17,12 @@ from functools import reduce
import paddle
from ..cluster import LinkType
from ..utils import _get_comm_group, _get_corresponding_rank
from ..process_group import get_process_group
from ..dist_tensor import DistributedTensor
from ..utils import _get_idx_in_axis
COMM_OP_TYPE = [
"send_v2", "recv_v2", "c_broadcast", "c_allgather", "c_allreduce_sum",
......@@ -28,33 +32,22 @@ NON_COMP_TYPE = ["while"] + COMM_OP_TYPE
_g_op_cost_factory = {}
def build_comp_desc_from_op(op):
"""Build the description of computation op."""
# NOTE: The desc is for serial op.
from ..reshard import get_var_with_recursion
def _parse_op_to_desc(op, dist_context=None):
desc = {}
desc["op"] = op.type
# The desc of concat op is {"op": "concat", "inputs": {"X": [(paddle.float32, [20, 20]), (paddle.float32, [20, 20])]}, "outputs": {"Out": [(paddle.float32, [20, 40])], "attrs": {"axis": -1}}}
vars = op.block.vars
desc["op"] = op.type
input_desc = OrderedDict()
for input_name in op.input_names:
var_name_list = op.input(input_name)
var_desc = []
for var_name in var_name_list:
var = vars[var_name]
shape = None
if dist_context is not None:
dist_tensor = dist_context.get_dist_tensor_for_program(var)
shape = dist_tensor.local_sizes()
else:
var = get_var_with_recursion(var_name, op.block, op.block.program)
shape = var.shape
assert shape is not None
var_desc.append((var.dtype, shape))
input_desc[input_name] = var_desc
desc["inputs"] = input_desc
......@@ -64,14 +57,8 @@ def _parse_op_to_desc(op, dist_context=None):
var_name_list = op.output(out_name)
var_desc = []
for var_name in var_name_list:
            var = get_var_with_recursion(var_name, op.block, op.block.program)
            shape = var.shape
var_desc.append((var.dtype, shape))
output_desc[out_name] = var_desc
desc["outputs"] = output_desc
......@@ -82,19 +69,101 @@ def _parse_op_to_desc(op, dist_context=None):
return desc
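For reference, here is a minimal sketch of calling build_comp_desc_from_op on a serial op; the program, variable names, and shapes below are illustrative only, not part of this PR:

import paddle
import paddle.static as static
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_op

paddle.enable_static()
main_program = static.Program()
with static.program_guard(main_program):
    x = static.data(name="x", shape=[4, 8], dtype="float32")
    y = static.data(name="y", shape=[8, 16], dtype="float32")
    out = paddle.matmul(x, y)

# The last op in the block is the matmul_v2 created above.
matmul_op = main_program.global_block().ops[-1]
desc = build_comp_desc_from_op(op=matmul_op)
# Roughly: {"op": "matmul_v2",
#           "inputs": {"X": [(paddle.float32, (4, 8))], "Y": [(paddle.float32, (8, 16))]},
#           "outputs": {"Out": [(paddle.float32, (4, 16))]},
#           "attrs": {...}}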
def build_comp_desc_from_dist_op(dist_op, dist_context):
    """Build per-process descriptions of a computation op distributed on its processes."""
    from ..reshard import get_var_with_recursion
op_descs = {}
op = dist_op.serial_op
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
assert process_mesh, "Process mesh must not be None."
processes = process_mesh.processes
for process in processes:
desc = {}
desc["op"] = op.type
attr_desc = op.all_attrs()
        # NOTE: The attrs in the desc are a replica of the serial op's attrs; there may be a bug if an attr involves a shape that needs to be partitioned.
desc["attrs"] = attr_desc
input_desc = OrderedDict()
output_desc = OrderedDict()
        # Get the partitioned shape of each input.
for input_name in op.input_names:
var_name_list = op.input(input_name)
var_desc = []
for var_name in var_name_list:
var = get_var_with_recursion(var_name, op.block,
op.block.program)
# Use op input_dims_mapping
dims_mapping = dist_attr.get_input_dims_mapping(var_name)
global_sizes = var.shape
                # NOTE: Once uneven partition is supported, shard_sizes will be obtained from dist_attr.
shard_sizes = None
topology = process_mesh.topology
shape = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, process,
shard_sizes)
var_desc.append((var.dtype, shape))
                # For special ops such as embedding and its grad op
                if op.type in [
                        "c_embedding", "lookup_table_v2", "c_embedding_grad",
                        "lookup_table_v2_grad"
                ]:
if input_name == "W":
embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
op.input(input_name)[0])[0]
relative_idx = _get_idx_in_axis(
processes, dist_attr.process_mesh.topology,
embedding_row_dim_mapping, process)
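                        # relative_idx is this rank's shard index along the
                        # row-parallel mesh axis; scaling it by per_part_size
                        # below gives the first vocab row owned by this rank.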
per_part_size = shape[0]
relative_idx = relative_idx * per_part_size
desc["attrs"]["start_index"] = relative_idx
input_desc[input_name] = var_desc
desc["inputs"] = input_desc
for out_name in op.output_names:
var_name_list = op.output(out_name)
var_desc = []
for var_name in var_name_list:
# Use op output_dims_mapping
var = get_var_with_recursion(var_name, op.block,
op.block.program)
dist_attr = dist_op.dist_attr
dims_mapping = dist_attr.get_output_dims_mapping(var_name)
process_mesh = dist_attr.process_mesh
global_sizes = var.shape
shard_sizes = None
processes = process_mesh.processes
topology = process_mesh.topology
shape = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, process,
shard_sizes)
var_desc.append((var.dtype, shape))
                # For special ops such as fill_constant_batch_size_like
                if op.type == "fill_constant_batch_size_like":
                    # Modify the shape attr according to how the output is partitioned
out_name = var_name_list[0]
dims_mapping = dist_attr.get_output_dims_mapping(out_name)
process_mesh_shape = dist_attr.process_mesh.topology
shape_list = op.attr("shape")
# Modify target shape
for idx, axis in enumerate(dims_mapping):
if axis >= 0:
shape_list[idx] = shape_list[
idx] // process_mesh_shape[axis]
desc["attrs"]["shape"] = shape_list
output_desc[out_name] = var_desc
desc["outputs"] = output_desc
op_descs[process] = desc
return op_descs
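Both loops above delegate the shard-shape math to DistributedTensor.get_local_sizes. A simplified sketch of the even-partition rule it applies when shard_sizes is None (an illustration, not Paddle's exact implementation):

def local_sizes_sketch(global_sizes, dims_mapping, topology):
    # A tensor axis mapped to mesh axis m is split evenly into topology[m]
    # parts; an axis mapped to -1 is replicated and keeps its global size.
    return [
        size if axis == -1 else size // topology[axis]
        for size, axis in zip(global_sizes, dims_mapping)
    ]

# A [512, 1024] tensor on a 2x4 mesh, sharded along mesh axis 0:
assert local_sizes_sketch([512, 1024], [0, -1], [2, 4]) == [256, 1024]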
def build_comp_desc_str_for_predict(desc):
# NOTE: The description format may change in the future.
def _parse_dtype(dtype):
dtype_str = ""
if dtype == paddle.float32:
......@@ -135,8 +204,208 @@ def parse_desc_to_str(desc):
shape_str = "[" + ",".join(shape_list) + "]"
desc_str_list += [dtype_str, dims_str, shape_str]
desc_str = "_".join(desc_str_list)
attrs = desc["attrs"]
parse_result = (desc_str, attrs)
return parse_result
def build_comm_desc_from_dist_op(op_type,
dist_op,
ctx,
var_names,
attrs=None,
parallel_axis=None,
group_ranks=None):
"""Build descriptions of communication op distributed on the processes."""
from ..reshard import get_var_with_recursion
specific_op_type = []
dist_attr = dist_op.dist_attr
assert dist_attr, "Dist attr must not be None."
process_mesh = dist_attr.process_mesh
assert process_mesh, "Process mesh must not be None."
processes = process_mesh.processes
op_descs = {}
for process in processes:
rank_id = process
desc = {}
desc["op"] = op_type
op_attrs = None
comm_group_ranks = None
if op_type not in specific_op_type:
serial_op = dist_op.serial_op
input_list = []
            # var_names usually contains just one item.
for var_name in var_names:
dist_attr = dist_op.dist_attr
has_found = False
# Find var_name in serial op input or output
for name in dist_op.serial_op.input_arg_names:
                # If a tensor is the input of multiple ops, the gradients from those ops are summed, so the name becomes varname@RENAME@block@0 and so on.
if var_name in name:
var_name = name
has_found = True
break
if not has_found:
for name in dist_op.serial_op.output_arg_names:
if var_name in name:
var_name = name
has_found = True
break
assert has_found
var = get_var_with_recursion(var_name, serial_op.block,
serial_op.block.program)
dims_mapping = dist_attr.get_input_dims_mapping(
var_name
) if var_name in dist_op.serial_op.input_arg_names else dist_attr.get_output_dims_mapping(
var_name)
global_sizes = var.shape
shard_sizes = None
topology = process_mesh.topology
shape = DistributedTensor.get_local_sizes(
global_sizes, dims_mapping, topology, processes, process,
shard_sizes)
input_list.append((var.dtype, shape))
            # NOTE: The input name of comm ops is usually "X".
desc["inputs"] = {"X": input_list}
# Get comm group by parallel_axis or the given group_ranks.
if parallel_axis is not None:
process_mesh_shape = process_mesh.topology
process_mesh_group = process_mesh.processes
comm_group_ranks = _get_comm_group(process_mesh_group,
process_mesh_shape,
parallel_axis, rank_id)
elif group_ranks is not None:
comm_group_ranks = group_ranks
else:
                raise ValueError(
                    "parallel_axis and group_ranks cannot both be None.")
if attrs is not None:
assert isinstance(attrs, dict)
op_attrs = attrs
else:
op_attrs = {}
desc["attrs"] = op_attrs
desc["group_ranks"] = comm_group_ranks
op_descs[rank_id] = desc
return op_descs
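When parallel_axis is given, _get_comm_group selects the processes that differ from rank_id only along that mesh axis. A self-contained sketch of that selection (illustrative; the real helper lives in ..utils):

import itertools

def comm_group_sketch(process_group, mesh_shape, parallel_axis, rank_id):
    # Enumerate mesh coordinates in row-major order, locate rank_id's
    # coordinate, then keep every rank that matches it on all other axes.
    coords = list(itertools.product(*[range(n) for n in mesh_shape]))
    rank_coord = coords[process_group.index(rank_id)]
    return [
        process_group[i] for i, coord in enumerate(coords)
        if all(c == r for j, (c, r) in enumerate(zip(coord, rank_coord))
               if j != parallel_axis)
    ]

# On a 2x4 mesh of ranks [0..7], the axis-0 (data-parallel) group of rank 1:
assert comm_group_sketch(list(range(8)), [2, 4], 0, 1) == [1, 5]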
def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
"""Build a comm desc directly."""
desc = {}
desc["op"] = op_type
desc["group_ranks"] = group_ranks
desc["inputs"] = {"X": [(dtype, shape)]}
desc["attrs"] = attrs
return desc
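A quick illustration of the desc this helper returns (the values are arbitrary):

import paddle
desc = build_comm_desc("c_allreduce_sum", [0, 1, 2, 3], paddle.float32,
                       [1024, 1024])
# -> {"op": "c_allreduce_sum", "group_ranks": [0, 1, 2, 3],
#     "inputs": {"X": [(paddle.float32, [1024, 1024])]}, "attrs": None}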
def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
"""Build comm costs by descriptions"""
comm_context = CommContext(cluster)
group_ranks_list = []
comm_op_cost_list = []
for process in processes:
desc = descs[process]
group_ranks = desc["group_ranks"]
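        # Ranks in one group share a single collective, so a cost is built
        # only once per distinct group_ranks.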
if group_ranks not in group_ranks_list:
group_ranks_list.append(group_ranks)
comm_op_cost = op_cost_class(op_desc=desc,
comm_context=comm_context)
comm_op_cost_list.append(comm_op_cost)
return comm_op_cost_list
def build_comp_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
"""Build comp costs by descriptions."""
costs = {}
for process in processes:
costs[process] = op_cost_class(op_desc=descs[process], cluster=cluster)
return costs
def build_dp_costs(result, dist_op, ctx, var_names, attrs, parallel_axis,
cluster):
"""DP cost contains a allreduce_sum op cost and a scale op cost"""
# The costs will be appended in the given result.
from ..reshard import get_var_with_recursion
dist_attr = dist_op.dist_attr
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
assert len(var_names) == 1
vars = dist_op.serial_op.block.vars
var_name = var_names[0]
has_found = False
for name in dist_op.serial_op.input_arg_names:
if var_name in name:
var_name = name
has_found = True
break
if not has_found:
for name in dist_op.serial_op.output_arg_names:
if var_name in name:
var_name = name
has_found = True
break
if not has_found:
return
c_allreduce_sum_descs = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
ctx,
var_names,
attrs=attrs,
parallel_axis=parallel_axis)
comm_cost_list = build_comm_costs_from_descs(
_g_op_cost_factory["c_allreduce_sum"], ctx, processes,
c_allreduce_sum_descs, cluster)
result.append(comm_cost_list)
    # The scale op runs only on the group_ranks.
for comm_cost in comm_cost_list:
group_ranks = comm_cost.group_ranks
dp_degree = len(group_ranks)
scale_costs = {}
op_type = "scale"
for rank in group_ranks:
desc = {}
desc["op"] = op_type
desc["inputs"] = {}
dims_mapping = dist_attr.get_input_dims_mapping(
var_name) if dist_attr.get_input_dims_mapping(
var_name
) is not None else dist_attr.get_output_dims_mapping(var_name)
var = get_var_with_recursion(var_name, dist_op.serial_op.block,
dist_op.serial_op.block.program)
global_sizes = var.shape
shard_sizes = None
topology = process_mesh.topology
shape = DistributedTensor.get_local_sizes(global_sizes,
dims_mapping, topology,
processes, rank,
shard_sizes)
desc["inputs"]["X"] = [(var.dtype, shape)]
attrs = {"scale": 1.0 / dp_degree}
desc["attrs"] = attrs
scale_op_cost = _g_op_cost_factory["scale"](op_desc=desc,
cluster=cluster)
scale_costs[rank] = scale_op_cost
result.append(scale_costs)
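The modeled pattern is the standard data-parallel gradient synchronization: an allreduce_sum across the dp group followed by a scale by 1.0 / dp_degree, which turns the summed gradients into a mean. A toy illustration of the numbers involved (the ranks are hypothetical):

group_ranks = [0, 4]                     # one dp group on a 2x4 mesh, axis 0
dp_degree = len(group_ranks)             # 2
scale_attr = {"scale": 1.0 / dp_degree}  # {"scale": 0.5}
# result collects the comm_cost_list followed by one scale_costs dict
# per distinct group.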
class CommContext:
......@@ -174,6 +443,8 @@ class CommContext:
# set default
self.base_ring = 8.4
self.base_tree = 0.
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
self.intra_ring = 3.4
self.intra_tree = 28
......@@ -441,6 +712,8 @@ class CommOpCost(OpCost):
@property
def comm_count(self):
from ..reshard import get_var_with_recursion
if self._comm_count is None:
dtype = None
shape = None
......@@ -448,7 +721,8 @@ class CommOpCost(OpCost):
vars = self.op.block.vars
                # NOTE: The communicated tensor's input name is "X" by default; otherwise this function should be overridden.
var_name = self.op.input("X")[0]
var = vars[var_name]
var = get_var_with_recursion(var_name, self.op.block,
self.program)
dtype = var.dtype
shape = var.shape
elif self.op_desc is not None:
......@@ -464,9 +738,10 @@ class CommOpCost(OpCost):
factor = 1
elif dtype == paddle.float16:
factor = 2
elif dtype == paddle.bool:
factor = 8
            else:
                raise ValueError(
                    "Unsupported comm dtype {}".format(dtype))
comm_count = reduce(lambda x, y: x * y, shape) * factor
self._comm_count = comm_count
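In other words, comm_count is the element count of the communicated tensor multiplied by the per-dtype factor chosen above; a worked example (not code from this PR):

from functools import reduce
shape, factor = [1024, 1024], 2  # paddle.float16 pairs with factor = 2 above
comm_count = reduce(lambda x, y: x * y, shape) * factor
assert comm_count == 2 * 1024 * 1024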
......
......@@ -51,6 +51,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
py_test_modules(test_base_cost MODULES test_base_cost ENVS ${dist_ENVS})
py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS})
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
import json
import tempfile
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CommContext
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_dist_op
from paddle.distributed.auto_parallel.cost.base_cost import build_comm_desc_from_dist_op
from paddle.distributed.auto_parallel.cost.base_cost import build_comm_costs_from_descs
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_costs_from_descs
from paddle.distributed.auto_parallel.cost.base_cost import build_dp_costs
from paddle.distributed.auto_parallel.cost import AllreduceSumOpCost
from paddle.distributed.auto_parallel.cost import _g_op_cost_factory
from test_cluster import cluster_json
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(d_model,
dim_feedforward,
weight_attr,
bias_attr=bias_attr)
self.linear1 = nn.Linear(dim_feedforward,
d_model,
weight_attr,
bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
def mlp_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(name="input",
shape=[batch_size, hidden_size],
dtype='float32')
label = static.data(name="label",
shape=[batch_size, 1],
dtype='float32')
fill_constant_out = paddle.fluid.layers.fill_constant_batch_size_like(
input=input, shape=[batch_size], value=1, dtype="int32")
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out)
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(embedding_out)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss, train_program, start_program
def get_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.process_mesh = _global_process_mesh
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
fleet._user_defined_strategy = fleet.DistributedStrategy()
fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
parallelizer = AutoParallelizer(fleet)
parallelizer._dist_context = dist_context
# serial forward & backward completion
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
dist_context.block_state.parse_forward_blocks(complete_train_program)
params_grads = parallelizer._generate_backward(complete_train_program,
startup_program,
loss,
parameter_list=None,
no_grad_set=None,
callbacks=None)
return train_program, startup_program, params_grads
class TestBaseCost(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDown(self):
self.temp_dir.cleanup()
def test_base_cost(self):
# Build cluster
cluster_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file(cluster_json_path)
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 2
train_program, startup_program, params_grads = get_prog(
train_program, startup_program, dist_context, rank_id)
for op in train_program.global_block().ops:
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op:
processes = dist_op.dist_attr.process_mesh.processes
comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context)
self.assertTrue(isinstance(comp_descs, dict) and comp_descs)
                var_names = None
                if op.input_arg_names:
                    var_names = [op.input_arg_names[0]]
comm_descs = build_comm_desc_from_dist_op("c_allreduce_sum",
dist_op,
dist_context,
var_names,
attrs=None,
parallel_axis=0,
group_ranks=None)
self.assertTrue(isinstance(comm_descs, dict) and comm_descs)
comm_descs = build_comm_desc_from_dist_op(
"c_allreduce_sum",
dist_op,
dist_context,
var_names,
attrs=None,
parallel_axis=None,
group_ranks=processes)
self.assertTrue(isinstance(comm_descs, dict) and comm_descs)
comm_costs = build_comm_costs_from_descs(
AllreduceSumOpCost, dist_context, processes, comm_descs,
cluster)
self.assertTrue(comm_costs)
comp_costs = build_comp_costs_from_descs(
_g_op_cost_factory[op.type], dist_context, processes,
comp_descs, cluster)
self.assertTrue(comp_costs)
result = []
                build_dp_costs(result, dist_op, dist_context, var_names,
                               None, 0, cluster)
self.assertTrue(result)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__":
unittest.main()
......@@ -2018,6 +2018,10 @@ class TestCluster(unittest.TestCase):
self.assertTrue(devices == [5, 6, 7, 10])
self.assertTrue(involved_machine_count == 2)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__":
unittest.main()
......@@ -154,6 +154,10 @@ class TestCommOpCost(unittest.TestCase):
comm_context=comm_context)
self.assertTrue(recv_op_cost.time > 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__":
unittest.main()
......@@ -19,8 +19,8 @@ import tempfile
import paddle
import paddle.distributed.auto_parallel.cost as cost_model
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_op
from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_str_for_predict
from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.cost import CommContext
......@@ -60,8 +60,8 @@ class TestCost(unittest.TestCase):
break
matmul_v2_cost = cost_model._g_op_cost_factory["matmul_v2"](
op=matmul_v2_op)
desc = build_comp_desc_from_op(op=matmul_v2_op)
desc_str = build_comp_desc_str_for_predict(desc)
self.assertIsNotNone(desc_str)
self.assertTrue(check_cost(matmul_v2_cost.cost))
time = calc_time_by_modeling(op=matmul_v2_op)
......@@ -92,11 +92,29 @@ class TestCost(unittest.TestCase):
op_desc=desc, comm_context=CommContext(cluster))
self.assertTrue(check_cost(allreduce_cost.cost))
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
def test_cost_estimator(self):
# Build cluster
cluster_json_path = os.path.join(self.temp_dir.name,
"auto_parallel_cluster.json")
cluster_json_object = json.loads(cluster_json)
with open(cluster_json_path, "w") as cluster_json_file:
json.dump(cluster_json_object, cluster_json_file)
cluster = Cluster()
cluster.build_from_file(cluster_json_path)
train_program = paddle.static.Program()
cost_estimator = cost_model.CostEstimator(train_program,
cluster=cluster)
self.assertIsNotNone(cost_estimator)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
if __name__ == "__main__":
unittest.main()