提交 f946aea1 编写于 作者: L lichenever

fix grpah mode loop sink bug in auto parallel

上级 976226f9
......@@ -34,8 +34,8 @@ namespace parallel {
#define OPERATOR_TO_OPERATOR_CONNECTOR "-"
#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
#define DEFAULT_COST_MODEL_ALPHA 1.0
#define DEFAULT_COST_MODEL_BETA 65.0
#define DEFAULT_COST_MODEL_GAMMA 0.02
#define DEFAULT_COST_MODEL_BETA 260.0
#define DEFAULT_COST_MODEL_GAMMA 0.001
#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0
......
......@@ -375,6 +375,10 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
return false;
}
// get_next is not in the forward graph, we need mark the get_next as the forward node
if (prim->name() == GET_NEXT) {
return true;
}
if ((prim->name() == CAST)) {
if ((!attrs.count(STRATEGY)) && (cnode->operator_info() == nullptr)) {
return false;
......
......@@ -88,7 +88,7 @@ class _DatasetIter:
# times the batch dimension of tensors for run
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
def __iter__(self):
self.ind = 0
......
......@@ -80,9 +80,9 @@ def test_common_parameter():
_executor.compile(net, x, y, z, w, phase='train')
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
'Default/network-Net/MatMul-op9': [[1, 1], [1, 8]],
'Default/network-Net/Cast-op10': [[1, 8]],
'Default/network-Net/MatMul-op0': [[1, 1], [1, 8]],
'Default/network-Net/Cast-op11': [[1, 8]]}
expected_strategies = {'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op8': [[8, 1], [1, 1]],
'Default/network-Net/Cast-op7': [[1, 1]],
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
'Default/network-Net/Cast-op9': [[1, 1]]}
assert strategies == expected_strategies
......@@ -86,9 +86,9 @@ def test_two_matmul():
costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha")
assert costmodel_alpha == 1.0
costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta")
assert costmodel_beta == 65.0
assert costmodel_beta == 260.0
costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma")
assert costmodel_gamma == 0.02
assert costmodel_gamma == 0.001
costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold")
assert costmodel_communi_threshold == 2048.0
costmodel_communi_const = cost_model_context.get_cost_model_context("costmodel_communi_const")
......@@ -138,3 +138,4 @@ def test_two_matmul():
expected_strategies = {'Default/network-Net/MatMul-op2': [[16, 1], [1, 1]],
'Default/network-Net/MatMul-op3': [[16, 1], [1, 1]]}
assert strategies == expected_strategies
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册