Commit 7d4f4818 authored by mindspore-ci-bot, committed by Gitee

!5017 remove internal interface in wide&deep

Merge pull request !5017 from yao_yf/wide_and_deep_no_internal_interface
@@ -17,5 +17,7 @@ This interface is ONLY used in Auto-parallel procedure.
"""
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
set_algo_parameters
from ._cost_model_context import set_multi_subgraphs, get_multi_subgraphs
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
__all__ = ["set_multi_subgraphs", "get_multi_subgraphs",
"get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
@@ -479,7 +479,6 @@ set_cost_model_context_func_map = {
"costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
"costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().set_multi_subgraphs,
"run_phase": cost_model_context().set_run_phase,
"costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
@@ -501,7 +500,6 @@ get_cost_model_context_func_map = {
"costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
"costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
"costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
"multi_subgraphs": cost_model_context().get_multi_subgraphs,
"run_phase": cost_model_context().get_run_phase,
"costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm,
"costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times,
@@ -538,7 +536,6 @@ def set_cost_model_context(**kwargs):
costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
run_phase (int): A parameter indicating which phase is running: training (0) or inference (1). Default: 0.
costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
0: bypass allreduce fusion;
@@ -591,3 +588,18 @@ def get_cost_model_context(attr_key):
def reset_cost_model_context():
"""Reset cost model context attributes."""
cost_model_context().reset_cost_model()
def set_multi_subgraphs(multi_subgraph=True):
    """
    Set the flag indicating whether the ANF graph contains multiple subgraphs.

    Args:
        multi_subgraph (bool): Whether the ANF graph contains multiple subgraphs. Default: True.
    """
    cost_model_context().set_multi_subgraphs(multi_subgraph)


def get_multi_subgraphs():
    """
    Get the flag indicating whether the ANF graph contains multiple subgraphs.
    """
    return cost_model_context().get_multi_subgraphs()
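Since the getter returns the stored cost-model flag (note the return above), a script can save and restore the setting around a multi-subgraph compilation. A small sketch; the training code in between is elided and the helper names come from this diff:

    from mindspore.parallel import set_multi_subgraphs, get_multi_subgraphs

    previous = get_multi_subgraphs()   # remember the current flag
    set_multi_subgraphs(True)          # mark the graph as containing multiple subgraphs
    # ... build and train the multi-subgraph wide&deep network here ...
    set_multi_subgraphs(previous)      # restore the original setting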
@@ -14,7 +14,7 @@
# ============================================================================
"""wide and deep model"""
import numpy as np
from mindspore import nn
from mindspore import nn, context
from mindspore import Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
@@ -22,10 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL, LazyAdam
# from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
@@ -142,7 +139,7 @@ class WideDeepModel(nn.Cell):
self.batch_size = config.batch_size
host_device_mix = bool(config.host_device_mix)
parameter_server = bool(config.parameter_server)
parallel_mode = _get_parallel_mode()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if is_auto_parallel:
self.batch_size = self.batch_size * get_group_size()
@@ -259,7 +256,7 @@ class NetWithLossClass(nn.Cell):
super(NetWithLossClass, self).__init__(auto_prefix=False)
host_device_mix = bool(config.host_device_mix)
parameter_server = bool(config.parameter_server)
parallel_mode = _get_parallel_mode()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
self.network = network
@@ -312,7 +309,7 @@ class TrainStepWrap(nn.Cell):
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False):
super(TrainStepWrap, self).__init__()
parallel_mode = _get_parallel_mode()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.network = network
self.network.set_train()
@@ -351,12 +348,11 @@ class TrainStepWrap(nn.Cell):
self.reducer_flag = False
self.grad_reducer_w = None
self.grad_reducer_d = None
parallel_mode = _get_parallel_mode()
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
mean = context.get_auto_parallel_context("mirror_mean")
degree = context.get_auto_parallel_context("device_num")
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
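The change above swaps the internal helpers (_get_parallel_mode, _get_mirror_mean, _get_device_num) for the public context.get_auto_parallel_context getters. A condensed sketch of the same pattern, wrapped in a hypothetical build_grad_reducer helper purely for illustration:

    from mindspore import context
    from mindspore.train.parallel_utils import ParallelMode
    from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

    def build_grad_reducer(optimizer):
        # Only data-parallel and hybrid-parallel modes need an explicit gradient reducer.
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            return None
        mean = context.get_auto_parallel_context("mirror_mean")
        degree = context.get_auto_parallel_context("device_num")
        return DistributedGradReducer(optimizer.parameters, mean, degree)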
......
@@ -22,7 +22,7 @@ from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -127,7 +127,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
set_multi_subgraphs()
if wide_deep_config.device_target == "Ascend":
init("hccl")
elif wide_deep_config.device_target == "GPU":
......
@@ -16,7 +16,7 @@
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import nn, context
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore.ops import functional as F
from mindspore.ops import composite as C
@@ -24,7 +24,6 @@ from mindspore.ops import operations as P
from mindspore.nn import Dropout, Flatten
from mindspore.nn.optim import Adam, FTRL
from mindspore.common.initializer import Uniform, initializer
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -552,13 +551,13 @@ class TrainStepWrap(nn.Cell):
self.reducer_flag = False
self.grad_reducer_w = None
self.grad_reducer_d = None
parallel_mode = _get_parallel_mode()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
if parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
mean = context.get_auto_parallel_context("mirror_mean")
degree = context.get_auto_parallel_context("device_num")
self.grad_reducer_w = DistributedGradReducer(
self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(
......
@@ -21,7 +21,7 @@ from mindspore import Model, context
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
@@ -33,7 +33,7 @@ from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
cost_model_context.set_cost_model_context(multi_subgraphs=True)
set_multi_subgraphs()
init()
......
@@ -23,7 +23,7 @@ from mindspore.nn.optim import Adam, FTRL
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_multi_subgraphs
from mindspore.parallel._utils import _reset_op_id as reset_op_id
@@ -103,7 +103,7 @@ class TrainStepWarp(nn.Cell):
def test_double_subgraphs():
cost_model_context.set_cost_model_context(multi_subgraphs=True)
set_multi_subgraphs()
context.set_context(save_graphs=True)
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="auto_parallel")
......