提交 65a9c80a 编写于 作者: L liuyang_655

model_train

上级 521e351d
...@@ -378,8 +378,8 @@ class Model: ...@@ -378,8 +378,8 @@ class Model:
with _CallbackManager(callbacks) as list_callback: with _CallbackManager(callbacks) as list_callback:
if not dataset_sink_mode: if not dataset_sink_mode:
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
elif context.get_context("mode") == context.PYNATIVE_MODE: elif context.get_context("mode") == context.PYNATIVE_MODE or context.get_context("device_target") == "CPU":
logger.warning("The pynative mode cannot support dataset sink mode currently." logger.warning("The pynative mode and CPU cannot support dataset sink mode currently."
"So the training process will be performed with dataset not sink.") "So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params) self._train_process(epoch, train_dataset, list_callback, cb_params)
else: else:
......
...@@ -44,8 +44,6 @@ if __name__ == "__main__": ...@@ -44,8 +44,6 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.device_target == "CPU":
args.dataset_sink_mode = False
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), ds_train = create_dataset(os.path.join(args.data_path, "train"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册