提交 a64d3f0d 编写于 作者: J jinyaohui

sink_graph

上级 948d07ed
......@@ -86,7 +86,10 @@ class ConfigManager {
DatasetMode dataset_mode() const { return dataset_mode_; }
void set_dataset_mode(DatasetMode mode) { dataset_mode_ = mode; }
int64_t iter_num() const { return iter_num_; }
int64_t iter_num() const {
if (dataset_mode_ == DS_NORMAL_MODE) return 1;
return iter_num_;
}
void set_iter_num(const int64_t num) { iter_num_ = num; }
std::string dataset_phase() const { return dataset_phase_; }
......
......@@ -338,6 +338,15 @@ class _Executor:
param.init_data(layout)
obj.init_parameters_data()
def _set_dataset_mode(self, args_list):
"""set dataset mode."""
# decide whether to sink based on whether the inputs is virtual or args_list is ()
if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
(args_list is not None and args_list == ()):
_set_dataset_mode_config('sink')
else:
_set_dataset_mode_config('normal')
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
"""
Compiles graph.
......@@ -368,6 +377,8 @@ class _Executor:
use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
self._set_dataset_mode(args_list)
if phase in self.compile_cache.keys():
logger.debug("%r graph has existed.", phase)
return phase, False
......@@ -396,12 +407,6 @@ class _Executor:
# the following GE init process is not needed when use vm or ms backend
if enable_ge:
# decide whether to sink based on whether the inputs is virtual or not
if args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag:
_set_dataset_mode_config('sink')
else:
_set_dataset_mode_config('normal')
self._build_data_graph(obj, params, phase)
if "export" not in phase:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册