未验证 提交 12fc8c82 编写于 作者: C Chengmo 提交者: GitHub

Fix k8s_datasplit & save_model (#167)

* fix save inference
上级 d4a280b5
...@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/" ...@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/"
# 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193 # 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193
PADDLE_PADDLEREC_ROLE=WORKER PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=K8S
use_python3=<$ USE_PYTHON3 $> use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $> CPU_NUM=<$ CPU_NUM $>
GLOG_v=0 GLOG_v=0
......
...@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $> ...@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $>
thirdparty_path=<$ THIRDPARTY_PATH $> thirdparty_path=<$ THIRDPARTY_PATH $>
PADDLE_PADDLEREC_ROLE=WORKER PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=MPI
use_python3=<$ USE_PYTHON3 $> use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $> CPU_NUM=<$ CPU_NUM $>
GLOG_v=0 GLOG_v=0
......
...@@ -107,6 +107,7 @@ class Trainer(object): ...@@ -107,6 +107,7 @@ class Trainer(object):
self.device = Device.GPU self.device = Device.GPU
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
self._place = fluid.CUDAPlace(gpu_id) self._place = fluid.CUDAPlace(gpu_id)
print("PaddleRec run on device GPU: {}".format(gpu_id))
self._exe = fluid.Executor(self._place) self._exe = fluid.Executor(self._place)
elif device == "CPU": elif device == "CPU":
self.device = Device.CPU self.device = Device.CPU
...@@ -146,6 +147,7 @@ class Trainer(object): ...@@ -146,6 +147,7 @@ class Trainer(object):
elif engine.upper() == "CLUSTER": elif engine.upper() == "CLUSTER":
self.engine = EngineMode.CLUSTER self.engine = EngineMode.CLUSTER
self.is_fleet = True self.is_fleet = True
self.which_cluster_type()
else: else:
raise ValueError("Not Support Engine {}".format(engine)) raise ValueError("Not Support Engine {}".format(engine))
self._context["is_fleet"] = self.is_fleet self._context["is_fleet"] = self.is_fleet
...@@ -165,6 +167,14 @@ class Trainer(object): ...@@ -165,6 +167,14 @@ class Trainer(object):
self._context["is_pslib"] = (fleet_mode.upper() == "PSLIB") self._context["is_pslib"] = (fleet_mode.upper() == "PSLIB")
self._context["fleet_mode"] = fleet_mode self._context["fleet_mode"] = fleet_mode
def which_cluster_type(self):
    """Record which kind of cluster this run targets in the context.

    Reads the ``PADDLEREC_CLUSTER_TYPE`` environment variable, using
    ``"MPI"`` when it is unset, and stores either ``"K8S"`` or ``"MPI"``
    under ``self._context["cluster_type"]``. The comparison is
    case-insensitive and ignores surrounding whitespace; any value other
    than K8S (including an empty string) falls back to ``"MPI"``.
    """
    cluster_type = os.getenv("PADDLEREC_CLUSTER_TYPE", "MPI")
    print("PADDLEREC_CLUSTER_TYPE: {}".format(cluster_type))
    # Strip before comparing so a value such as "K8S " coming from a
    # hand-edited job config is not silently treated as MPI.
    is_k8s = cluster_type.strip().upper() == "K8S"
    self._context["cluster_type"] = "K8S" if is_k8s else "MPI"
def which_executor_mode(self): def which_executor_mode(self):
executor_mode = envs.get_runtime_environ("train.trainer.executor_mode") executor_mode = envs.get_runtime_environ("train.trainer.executor_mode")
if executor_mode.upper() not in ["TRAIN", "INFER"]: if executor_mode.upper() not in ["TRAIN", "INFER"]:
......
...@@ -123,10 +123,21 @@ class QueueDataset(DatasetBase): ...@@ -123,10 +123,21 @@ class QueueDataset(DatasetBase):
os.path.join(train_data_path, x) os.path.join(train_data_path, x)
for x in os.listdir(train_data_path) for x in os.listdir(train_data_path)
] ]
file_list.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount afs, split files for every node
need_split_files = True
if need_split_files:
file_list = split_files(file_list, context["fleet"].worker_index(), file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
print("File_list: {}".format(file_list)) print("File_list: {}".format(file_list))
dataset.set_filelist(file_list) dataset.set_filelist(file_list)
for model_dict in context["phases"]: for model_dict in context["phases"]:
if model_dict["dataset_name"] == dataset_name: if model_dict["dataset_name"] == dataset_name:
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import os import os
import time import time
import warnings
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -284,6 +285,7 @@ class RunnerBase(object): ...@@ -284,6 +285,7 @@ class RunnerBase(object):
return (epoch_id + 1) % epoch_interval == 0 return (epoch_id + 1) % epoch_interval == 0
def save_inference_model(): def save_inference_model():
# get global env
name = "runner." + context["runner_name"] + "." name = "runner." + context["runner_name"] + "."
save_interval = int( save_interval = int(
envs.get_global_env(name + "save_inference_interval", -1)) envs.get_global_env(name + "save_inference_interval", -1))
...@@ -296,18 +298,44 @@ class RunnerBase(object): ...@@ -296,18 +298,44 @@ class RunnerBase(object):
if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \ if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
len(feed_varnames) == 0 or len(fetch_varnames) == 0: len(feed_varnames) == 0 or len(fetch_varnames) == 0:
return return
fetch_vars = [
fluid.default_main_program().global_block().vars[varname] # check feed var exist
for varname in fetch_varnames for var_name in feed_varnames:
] if var_name not in fluid.default_main_program().global_block(
).vars:
raise ValueError(
"Feed variable: {} not in default_main_program, global block has follow vars: {}".
format(var_name,
fluid.default_main_program().global_block()
.vars.keys()))
# check fetch var exist
fetch_vars = []
for var_name in fetch_varnames:
if var_name not in fluid.default_main_program().global_block(
).vars:
raise ValueError(
"Fetch variable: {} not in default_main_program, global block has follow vars: {}".
format(var_name,
fluid.default_main_program().global_block()
.vars.keys()))
else:
fetch_vars.append(fluid.default_main_program()
.global_block().vars[var_name])
dirname = envs.get_global_env(name + "save_inference_path", None) dirname = envs.get_global_env(name + "save_inference_path", None)
assert dirname is not None assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
context["fleet"].save_inference_model( warnings.warn(
context["exe"], dirname, feed_varnames, fetch_vars) "Save inference model in cluster training is not recommended! Using save checkpoint instead.",
category=UserWarning,
stacklevel=2)
if context["fleet"].worker_index() == 0:
context["fleet"].save_inference_model(
context["exe"], dirname, feed_varnames, fetch_vars)
else: else:
fluid.io.save_inference_model(dirname, feed_varnames, fluid.io.save_inference_model(dirname, feed_varnames,
fetch_vars, context["exe"]) fetch_vars, context["exe"])
...@@ -323,7 +351,8 @@ class RunnerBase(object): ...@@ -323,7 +351,8 @@ class RunnerBase(object):
return return
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
context["fleet"].save_persistables(context["exe"], dirname) if context["fleet"].worker_index() == 0:
context["fleet"].save_persistables(context["exe"], dirname)
else: else:
fluid.io.save_persistables(context["exe"], dirname) fluid.io.save_persistables(context["exe"], dirname)
......
...@@ -39,9 +39,21 @@ def dataloader_by_name(readerclass, ...@@ -39,9 +39,21 @@ def dataloader_by_name(readerclass,
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
print("need_split_files: {}".format(need_split_files))
if need_split_files:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
print("file_list : {}".format(files)) print("file_list : {}".format(files))
reader = reader_class(yaml_file) reader = reader_class(yaml_file)
...@@ -81,10 +93,20 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context): ...@@ -81,10 +93,20 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
if need_split_files:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env(name + "sparse_slots", "#") sparse = get_global_env(name + "sparse_slots", "#")
if sparse == "": if sparse == "":
...@@ -135,10 +157,20 @@ def slotdataloader(readerclass, train, yaml_file, context): ...@@ -135,10 +157,20 @@ def slotdataloader(readerclass, train, yaml_file, context):
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER: if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
if need_split_files:
files = split_files(files, context["fleet"].worker_index(), files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num()) context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env("sparse_slots", "#", namespace) sparse = get_global_env("sparse_slots", "#", namespace)
if sparse == "": if sparse == "":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册