提交 42f12412 编写于 作者: X Xiaoda Zhang

remove 'multi-subgraphs' to internal

上级 90fa4c9d
......@@ -17,7 +17,5 @@ This interface is ONLY used in Auto-parallel procedure.
"""
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
set_algo_parameters
from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
"get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
......@@ -589,7 +589,7 @@ def reset_cost_model_context():
"""Reset cost model context attributes."""
cost_model_context().reset_cost_model()
def set_multi_subgraphs(multi_subgraph=True):
def _set_multi_subgraphs(multi_subgraph=True):
"""
Set the flag of ANF graph containing multiple subgraphs.
......@@ -598,7 +598,7 @@ def set_multi_subgraphs(multi_subgraph=True):
"""
cost_model_context().set_multi_subgraphs(multi_subgraph)
def get_multi_subgraphs():
def _get_multi_subgraphs():
"""
Get the flag of ANF graph containing multiple subgraphs.
"""
......
......@@ -32,6 +32,7 @@ from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..context import ParallelMode
from ..parallel._utils import _need_to_full, _to_full_tensor
from ..parallel._cost_model_context import _set_multi_subgraphs
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
from . import amp
......@@ -166,6 +167,9 @@ class Model:
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
if self._optimizer is None:
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
_set_multi_subgraphs()
return network
def _build_eval_network(self, metrics, eval_network, eval_indexes):
......@@ -190,6 +194,9 @@ class Model:
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if self._optimizer:
self._eval_network = _VirtualDatasetCell(self._eval_network)
if self._optimizer is None:
# In this case, multiple optimizer(s) is supposed to be included in 'self._network'
_set_multi_subgraphs()
self._eval_network.set_auto_parallel()
def _build_predict_network(self):
......@@ -197,6 +204,7 @@ class Model:
self._predict_network = self._network
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._predict_network = _VirtualDatasetCell(self._network)
_set_multi_subgraphs()
self._predict_network.set_auto_parallel()
def _clear_metrics(self):
......
......@@ -22,7 +22,6 @@ from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......@@ -145,7 +144,6 @@ if __name__ == "__main__":
device_target=wide_deep_config.device_target, save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
set_multi_subgraphs()
init()
if wide_deep_config.host_device_mix == 1:
context.set_auto_parallel_context(
......
......@@ -21,7 +21,6 @@ from mindspore import Model, context
from mindspore.train.callback import TimeMonitor
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......@@ -33,7 +32,6 @@ from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
set_multi_subgraphs()
init()
......
......@@ -17,13 +17,13 @@ import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import context
from mindspore import context, Model
from mindspore.common.api import _executor
from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel import set_multi_subgraphs
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.parallel._utils import _reset_op_id as reset_op_id
......@@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
def test_double_subgraphs():
set_multi_subgraphs()
_set_multi_subgraphs()
context.set_context(save_graphs=True)
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
......@@ -120,3 +120,50 @@ def test_double_subgraphs():
'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
assert strategies == expected_strategies
class DatasetLenet():
def __init__(self, predict, label, length=3):
self.predict = predict
self.label = label
self.index = 0
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.predict
def reset(self):
self.index = 0
def get_dataset_size(self):
return 32
def get_repeat_count(self):
return 1
def create_tuple_iterator(self):
return self
def test_double_subgraphs_train():
context.set_context(save_graphs=True)
context.set_auto_parallel_context(device_num=1, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net = TrainStepWarp(NetWithLoss(Net()))
batch_ids = np.ones([8, 8, 8, 8]).astype(np.int32)
ds_train = DatasetLenet(Tensor(batch_ids), None)
model = Model(net)
model.train(1, ds_train, dataset_sink_mode=False)
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op3': [[1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/ReLU-op4': [[1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op5': [[1, 1, 1, 1], [1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Mul-op6': [[1, 1, 1, 1], [1, 1, 1, 1]],
'Default/network-NetWithLoss/net-Net/Cast-op1': [[1, 1, 1, 1]],
'Default/network-NetWithLoss/ReduceSum-op7': [[1, 1, 1, 1]]}
assert strategies == expected_strategies
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册