Commit ac62faa3, authored by jinyaohui

Modify the set_dataset_mode_config API param: rename the dataset modes 'graph'/'feed' to 'sink'/'normal'.

Parent 30c242d7
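In short, the rename runs end to end: the Python side now passes the strings 'sink'/'normal' instead of 'graph'/'feed', and the C++ side renames the enum values DS_GRAPH_MODE/DS_FEED_MODE to DS_SINK_MODE/DS_NORMAL_MODE. A minimal sketch of the renamed entry point follows; `_set_dataset_mode_config` itself appears in the `_Executor` hunk below, but its import path here is an assumption:

```python
# Minimal sketch, assuming _set_dataset_mode_config is exported from
# mindspore._c_expression (the exact module path is an assumption).
from mindspore._c_expression import _set_dataset_mode_config

_set_dataset_mode_config('sink')    # was 'graph': dataset sunk to the device queue
_set_dataset_mode_config('normal')  # was 'feed': data fed from the host each step
```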
@@ -67,7 +67,7 @@ if __name__ == '__main__':
     parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
-    parser.add_argument("--mode", type=str, default="graph", help="Run graph mode or feed mode, default is graph")
+    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or non-sink mode, default is sink")
     parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
     parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
@@ -150,8 +150,8 @@ if __name__ == '__main__':
     model = Model(net)

     dataset_sink_mode = False
-    if args_opt.mode == "graph":
-        print("In graph mode, one epoch return a loss.")
+    if args_opt.mode == "sink":
+        print("In sink mode, one epoch returns a loss.")
         dataset_sink_mode = True
     print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
     model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
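The flag rename above only changes the accepted strings; the sink decision itself is unchanged. A standalone sketch of the updated handling (the script name `train.py` is shown for illustration):

```python
# Sketch of the renamed flag handling from the hunks above. Runnable, e.g.:
#   python train.py --mode sink
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default="sink",
                    help="Run sink mode or non-sink mode, default is sink")
args_opt = parser.parse_args()

# "sink" pushes the whole epoch to the device, so only one loss comes back;
# any other value keeps host-side stepping with a loss per step.
dataset_sink_mode = args_opt.mode == "sink"
print(f"dataset_sink_mode = {dataset_sink_mode}")
```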
@@ -116,7 +116,7 @@ bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batc
     return transform::TransformUtil::ConvertDataType(i->type_id());
   });
-  ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_GRAPH_MODE);
+  ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE);
   ConfigManager::GetInstance().set_iter_num(size);
   ConfigManager::GetInstance().set_dataset_phase(phase);
@@ -453,8 +453,8 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr>& info, const py::
   }

   // process the first args of tensor
-  // only in Dataset Feed Mode, fp_bp graph need input tensors
-  if (ConfigManager::GetInstance().dataset_mode() == DS_FEED_MODE) {
+  // only in dataset non-sink mode, the fp_bp graph needs input tensors
+  if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) {
     for (std::size_t i = 0; i < size; i++) {
       ValuePtr converted = nullptr;
       bool succ = parse::ConvertData(args[i], &converted);
......
@@ -440,10 +440,10 @@ void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
     int64_t value = 0;
     auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
-    if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
       value = ConfigManager::GetInstance().iter_num();
     } else {
-      MS_LOG(INFO) << "Run with feed mode, the iterator number will always be 1";
+      MS_LOG(INFO) << "Run with non-sink mode, the iteration number will always be 1";
       value = 1;
       ConfigManager::GetInstance().set_iter_num(value);
     }
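The branch above sets the on-device loop constant: in sink mode the device runs `iter_num` iterations per loop, while in non-sink mode control returns to the host every step. A hedged Python paraphrase of that rule (the function name is illustrative, not the C++ API):

```python
# Paraphrase of InitLoopVar's branch: iterations_per_loop collapses to 1
# whenever the dataset is not sunk to the device.
def iterations_per_loop(sink_mode: bool, iter_num: int) -> int:
    return iter_num if sink_mode else 1

assert iterations_per_loop(True, 1875) == 1875   # DS_SINK_MODE
assert iterations_per_loop(False, 1875) == 1     # DS_NORMAL_MODE
```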
@@ -574,7 +574,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std

 void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) {
   MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input";
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     auto getnext_idx = static_cast<int64_t>(input_idx);
     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
     if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) {
@@ -866,7 +866,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
   }

   // Create dataset iterator and iterator_getnext node
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
     MS_LOG(INFO) << "Dataset param is " << param.ToString() << ".";
     // GetNext
@@ -975,7 +975,7 @@ void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) {
 }

 void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
     size_t output_num = param.ge_types().size();
     MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << ".";
@@ -1034,7 +1034,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {

   // set graph input according to the order from anf graph
   std::vector<Operator> inputs;
-  if (ConfigManager::GetInstance().dataset_mode() == DS_GRAPH_MODE) {
+  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
     inputs.push_back(*dataset_iter_getnext_);
   } else {
     auto params = anf_graph_->parameters();
......
@@ -28,7 +28,7 @@ ConfigManager& ConfigManager::GetInstance() noexcept {
 }

 void ConfigManager::SetDatasetModeConfig(const std::string& mode) {
-  static const std::map<std::string, DatasetMode> mode_map = {{"feed", DS_FEED_MODE}, {"graph", DS_GRAPH_MODE}};
+  static const std::map<std::string, DatasetMode> mode_map = {{"normal", DS_NORMAL_MODE}, {"sink", DS_SINK_MODE}};
   if (mode_map.find(mode) == mode_map.end()) {
     MS_LOG(ERROR) << "Invalid dataset mode:" << mode;
     return;
@@ -38,7 +38,7 @@ void ConfigManager::SetDatasetModeConfig(const std::string& mode) {

 void ConfigManager::ResetConfig() noexcept {
   parallel_strategy_ = ONE_DEVICE;
-  dataset_mode_ = DS_FEED_MODE;
+  dataset_mode_ = DS_NORMAL_MODE;
   dataset_param_ = DatasetGraphParam("", 0, 0, {}, {}, {});
   iter_num_ = 1;
 }
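Together these two hunks define the new vocabulary: `SetDatasetModeConfig` now accepts only the strings "normal" and "sink", and `ResetConfig` restores the non-sink default. A Python paraphrase of the validation (the integer constants mirror the C++ enum; raising instead of logging is a simplification):

```python
# Toy stand-in for ConfigManager::SetDatasetModeConfig after the rename.
DS_NORMAL_MODE, DS_SINK_MODE = 0, 1
_MODE_MAP = {"normal": DS_NORMAL_MODE, "sink": DS_SINK_MODE}  # was feed/graph

def set_dataset_mode_config(mode: str) -> int:
    if mode not in _MODE_MAP:
        # The C++ version logs MS_LOG(ERROR) and returns; raising is simpler here.
        raise ValueError(f"Invalid dataset mode: {mode}")
    return _MODE_MAP[mode]
```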
......
@@ -33,7 +33,7 @@ enum ParallelStrategy {
   DISTRIBUTION,
 };

-enum DatasetMode { DS_FEED_MODE = 0, DS_GRAPH_MODE };
+enum DatasetMode { DS_NORMAL_MODE = 0, DS_SINK_MODE };

 class DatasetGraphParam {
  public:
@@ -106,7 +106,7 @@ class ConfigManager {
   ~ConfigManager() = default;

   ParallelStrategy parallel_strategy_{ONE_DEVICE};
-  DatasetMode dataset_mode_{DS_FEED_MODE};
+  DatasetMode dataset_mode_{DS_NORMAL_MODE};
   DatasetGraphParam dataset_param_{"", 0, 0, {}, {}, {}};
   int64_t iter_num_{1};
   std::string dataset_phase_{""};
......
@@ -381,9 +381,9 @@ class _Executor:
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
             if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag:
-                _set_dataset_mode_config('graph')
+                _set_dataset_mode_config('sink')
             else:
-                _set_dataset_mode_config('feed')
+                _set_dataset_mode_config('normal')

             self._build_data_graph(obj, params, phase)
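The executor picks the mode from the compiled graph's inputs: a virtual input tensor stands in for data arriving over the device queue, so the data graph is built in sink mode; real host tensors imply normal mode. A standalone sketch of just that decision (the helper name is illustrative):

```python
# Sketch of the sink/normal decision made in _Executor.compile under GE.
def choose_dataset_mode(args_list) -> str:
    if args_list and getattr(args_list[0], "virtual_flag", False):
        return "sink"    # inputs are virtual: data comes from the device queue
    return "normal"      # real host tensors: feed data step by step
```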
......
@@ -43,7 +43,7 @@ class DynamicLossScaleUpdateCell(Cell):
     In every training step, the loss scaling value will be updated by loss scaling value/`scale_factor`
     when there is overflow. And it will be increased by loss scaling value * `scale_factor` if there is no
     overflow for a continuous `scale_window` steps. This cell is used for Graph mode training in which all
-    logic will be executed on device side(Another training mode is feed mode in which some logic will be
+    logic will be executed on device side (another training mode is non-sink mode in which some logic will be
     executed on host).

     Args:
......
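For context, a hedged usage sketch of the cell whose docstring changes above, using the constructor arguments its docstring names (the values are illustrative):

```python
# Hedged usage sketch: the loss scale shrinks by scale_factor on overflow
# and grows by the same factor after scale_window clean steps.
from mindspore import nn

manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12,
                                        scale_factor=2,
                                        scale_window=1000)
```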
@@ -24,11 +24,12 @@ from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.nn.optim import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
-from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext,_checkpoint_cb_for_save_op,\
-    LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist,\
-    _build_callbacks, CheckpointConfig, _set_cur_net
+from mindspore.train.callback import ModelCheckpoint, _check_file_name_prefix, RunContext, _checkpoint_cb_for_save_op, \
+    LossMonitor, _InternalCallbackParam, _chg_ckpt_file_name_if_same_exist, \
+    _build_callbacks, CheckpointConfig, _set_cur_net
 from mindspore.common.api import ms_function

+
 class Net(nn.Cell):
     """Net definition."""
@@ -52,6 +53,7 @@ class Net(nn.Cell):

 class LossNet(nn.Cell):
     """ LossNet definition """
+
     def __init__(self):
         super(LossNet, self).__init__()
         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
@@ -110,8 +112,8 @@ def test_save_checkpoint():
     os.remove('./test_files/test_ckpt-model.pkl')


-def test_loss_monitor_graph_model():
-    """Test lossmonitor Graph model."""
+def test_loss_monitor_sink_model():
+    """Test loss monitor sink model."""
     cb_params = _InternalCallbackParam()
     cb_params.cur_epoch_num = 4
     cb_params.cur_step_num = 2
@@ -129,8 +131,8 @@ def test_loss_monitor_graph_model():
     callbacklist.end(run_context)


-def test_Loss_Monitor_feed_feed_model():
-    """Test Loss Monitor feed feed mode."""
+def test_loss_monitor_feed_model():
+    """Test loss monitor non-sink mode."""
     cb_params = _InternalCallbackParam()
     run_context = RunContext(cb_params)
     loss_cb = LossMonitor(1)
......