提交 c3990215 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1597 add _set_dataset_mode_config to vm

Merge pull request !1597 from jinyaohui/sink_graph
...@@ -86,7 +86,10 @@ class ConfigManager { ...@@ -86,7 +86,10 @@ class ConfigManager {
DatasetMode dataset_mode() const { return dataset_mode_; } DatasetMode dataset_mode() const { return dataset_mode_; }
void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; } void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
int64_t iter_num() const { return iter_num_; } int64_t iter_num() const {
if (dataset_mode_ == DS_NORMAL_MODE) return 1;
return iter_num_;
}
void set_iter_num(const int64_t num) { iter_num_ = num; } void set_iter_num(const int64_t num) { iter_num_ = num; }
std::string dataset_phase() const { return dataset_phase_; } std::string dataset_phase() const { return dataset_phase_; }
......
...@@ -341,6 +341,15 @@ class _Executor: ...@@ -341,6 +341,15 @@ class _Executor:
param.init_data(layout, set_sliced=True) param.init_data(layout, set_sliced=True)
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
def _set_dataset_mode(self, args_list):
"""set dataset mode."""
# decide whether to sink based on whether the inputs is virtual or args_list is ()
if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
(args_list is not None and args_list == ()):
_set_dataset_mode_config('sink')
else:
_set_dataset_mode_config('normal')
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
""" """
Compiles graph. Compiles graph.
...@@ -371,6 +380,8 @@ class _Executor: ...@@ -371,6 +380,8 @@ class _Executor:
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE) use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
self._set_dataset_mode(args_list)
if phase in self.compile_cache.keys(): if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase) logger.debug("%r graph has existed.", phase)
return phase, False return phase, False
...@@ -399,12 +410,6 @@ class _Executor: ...@@ -399,12 +410,6 @@ class _Executor:
# the following GE init process is not needed when use vm or ms backend # the following GE init process is not needed when use vm or ms backend
if enable_ge: if enable_ge:
# decide whether to sink based on whether the inputs is virtual or not
if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag:
_set_dataset_mode_config('sink')
else:
_set_dataset_mode_config('normal')
self._build_data_graph(obj, params, phase) self._build_data_graph(obj, params, phase)
if "export" not in phase: if "export" not in phase:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册