未验证 提交 e05a7a49 编写于 作者: T tangwei12 提交者: GitHub

ut fix (#33102)


Change-Id: I2e82dfcee6a1d0512b94cebc32281123fa5bf597

* pretty print for datafeed error

Change-Id: I056a8b6f03608e96679a83846c97aed289cef7e6

* fix fleet dist infer ut
上级 865f0c1f
......@@ -638,25 +638,34 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
const char* str = reader.get();
std::string line = std::string(str);
// VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE_NE(
num, 0,
platform::errors::InvalidArgument(
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s, \n Specifically, "
"something wrong happened(the length of this slot's feasign is 0)"
"when we parse the %d th slots."
"Maybe something wrong around this slot"
"\nWe detect the feasign number of this slot is %d, "
"which is illegal.",
str, i, num));
if (num <= 0) {
std::stringstream ss;
ss << "\n\nGot unexpected input, maybe something wrong with it.\n";
ss << "\n----------------------\n";
ss << "The Origin Input Data:\n";
ss << "----------------------\n";
ss << line << "\n";
ss << "\n----------------------\n";
ss << "Some Possible Errors:\n";
ss << "----------------------\n";
ss << "1. The number of ids can not be zero, you need padding.\n";
ss << "2. The input data contains unresolvable characters.\n";
ss << "3. We detect the slot " << i << "'s feasign number is " << num
<< " which is illegal.\n";
ss << "\n";
PADDLE_THROW(platform::errors::InvalidArgument(ss.str()));
}
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
......
......@@ -230,6 +230,10 @@ class TestDistCTR2x2(FleetDistRunnerBase):
except fluid.core.EOFException:
self.reader.reset()
dirname = os.getenv("SAVE_DIRNAME", None)
if dirname:
fleet.save_persistables(exe, dirname=dirname)
model_dir = tempfile.mkdtemp()
fleet.save_inference_model(
exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
......@@ -279,5 +283,9 @@ class TestDistCTR2x2(FleetDistRunnerBase):
self.check_model_right(model_dir)
shutil.rmtree(model_dir)
dirname = os.getenv("SAVE_DIRNAME", None)
if dirname:
fleet.save_persistables(exe, dirname=dirname)
if __name__ == "__main__":
runtime_main(TestDistCTR2x2)
......@@ -167,12 +167,15 @@ half_run_server.run_ut()
_python = sys.executable
ps_cmd = "{} {}".format(_python, server_file)
ps_proc = subprocess.Popen(
ps_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
time.sleep(5)
outs, errs = ps_proc.communicate(timeout=15)
time.sleep(1)
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["http_proxy"] = ""
......@@ -180,6 +183,7 @@ half_run_server.run_ut()
self.run_ut()
ps_proc.kill()
ps_proc.wait()
if os.path.exists(server_file):
os.remove(server_file)
......
......@@ -241,42 +241,72 @@ class TestFleetBase(unittest.TestCase):
def _start_pserver(self, cmd, required_envs):
ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)
ps0_pipe = open(tempfile.gettempdir() + "/ps0_err.log", "wb+")
ps1_pipe = open(tempfile.gettempdir() + "/ps1_err.log", "wb+")
log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
log_prename = required_envs.get("LOG_PREFIX", "")
if log_dirname:
log_prename += "_"
ps0_err_log = os.path.join(log_dirname, log_prename + "ps0_stderr.log")
ps1_err_log = os.path.join(log_dirname, log_prename + "ps1_stderr.log")
ps0_out_log = os.path.join(log_dirname, log_prename + "ps0_stdout.log")
ps1_out_log = os.path.join(log_dirname, log_prename + "ps1_stdout.log")
ps0_err = open(ps0_err_log, "wb+")
ps1_err = open(ps1_err_log, "wb+")
ps0_out = open(ps0_out_log, "wb+")
ps1_out = open(ps1_out_log, "wb+")
ps0_proc = subprocess.Popen(
ps0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps0_pipe,
stdout=ps0_out,
stderr=ps0_err,
env=required_envs)
ps1_proc = subprocess.Popen(
ps1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps1_pipe,
stdout=ps1_out,
stderr=ps1_err,
env=required_envs)
return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
return ((ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log),
(ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log))
def _start_trainer(self, cmd, required_envs):
tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)
tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+")
tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+")
log_dirname = required_envs.get("LOG_DIRNAME", tempfile.gettempdir())
log_prename = required_envs.get("LOG_PREFIX", "")
if log_dirname:
log_prename += "_"
tr0_err_log = os.path.join(log_dirname, log_prename + "tr0_stderr.log")
tr1_err_log = os.path.join(log_dirname, log_prename + "tr1_stderr.log")
tr0_out_log = os.path.join(log_dirname, log_prename + "tr0_stdout.log")
tr1_out_log = os.path.join(log_dirname, log_prename + "tr1_stdout.log")
tr0_out = open(tempfile.gettempdir() + "/tr0_stdout.log", "wb+")
tr1_out = open(tempfile.gettempdir() + "/tr1_stdout.log", "wb+")
tr0_err = open(tr0_err_log, "wb+")
tr1_err = open(tr1_err_log, "wb+")
tr0_out = open(tr0_out_log, "wb+")
tr1_out = open(tr1_out_log, "wb+")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "),
stdout=tr0_out,
stderr=tr0_pipe,
stderr=tr0_err,
env=required_envs)
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "),
stdout=tr1_out,
stderr=tr1_pipe,
stderr=tr1_err,
env=required_envs)
return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe
return ((tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log),
(tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log))
def _run_cluster(self, model, envs):
env = {'GRAD_CLIP': str(self._grad_clip_mode)}
......@@ -303,57 +333,87 @@ class TestFleetBase(unittest.TestCase):
ps_cmd += " --model_dir {}".format(self._model_dir)
# Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env)
ps0, ps1 = self._start_pserver(ps_cmd, env)
tr0, tr1 = self._start_trainer(tr_cmd, env)
ps0_proc, ps0_out, ps0_err, ps0_out_log, ps0_err_log = ps0
ps1_proc, ps1_out, ps1_err, ps1_out_log, ps1_err_log = ps1
tr0_proc, tr0_out, tr0_err, tr0_out_log, tr0_err_log = tr0
tr1_proc, tr1_out, tr1_err, tr1_out_log, tr1_err_log = tr1
# Wait until trainer process terminate
while True:
stat0 = tr0.poll()
time.sleep(0.1)
if stat0 is not None:
break
time_out = 120
cur_time = 0
while True:
stat1 = tr1.poll()
time.sleep(0.1)
if stat1 is not None:
stat0 = tr0_proc.poll()
stat1 = tr1_proc.poll()
if stat0 is not None and stat1 is not None:
break
else:
time.sleep(0.5)
cur_time += 0.5
if cur_time >= time_out:
tr0_proc.terminate()
tr1_proc.terminate()
tr0_proc.wait()
tr1_proc.wait()
break
tr0_out, tr0_err = tr0.communicate()
tr1_out, tr1_err = tr1.communicate()
tr0_ret = tr0.returncode
tr1_ret = tr0.returncode
if tr0_ret != 0:
print(
"========================Error tr0_err begin==========================="
)
os.system("cat {}".format(tempfile.gettempdir() + "/tr0_err.log"))
print(
"========================Error tr0_err end==========================="
)
if tr1_ret != 0:
print(
"========================Error tr1_err begin==========================="
)
os.system("cat {}".format(tempfile.gettempdir() + "/tr1_err.log"))
print(
"========================Error tr1_err end==========================="
)
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
ps0_pipe.close()
ps1_pipe.close()
ps0.terminate()
ps1.terminate()
tr0_ret = tr0_proc.returncode
tr1_ret = tr1_proc.returncode
ps0_proc.kill()
ps1_proc.kill()
ps0_proc.wait()
ps1_proc.wait()
def is_listen_failed(logx):
is_lf = False
listen_rgx = "Fail to listen"
with open(logx, "r") as rb:
for line in rb.readlines():
if listen_rgx in line:
is_lf = True
break
return is_lf
def catlog(logx):
basename = os.path.basename(logx)
print("\n================== Error {} begin =====================".
format(basename))
os.system("cat {}".format(logx))
print("================== Error {} end =====================\n".
format(basename))
if tr0_ret != 0 or tr1_ret != 0:
if is_listen_failed(ps0_err) or is_listen_failed(ps1_err):
print("find parameter server port bind failed, skip the error")
tr0_ret, tr1_ret = 0, 0
else:
for out, err in [
(ps0_out_log, ps0_err_log), (ps1_out_log, ps1_err_log),
(tr0_out_log, tr0_err_log), (tr1_out_log, tr1_err_log)
]:
catlog(out)
catlog(err)
for pipe in [
tr0_err, tr0_out, tr1_err, tr1_out, ps0_err, ps0_out, ps1_err,
ps1_out
]:
pipe.close()
shutil.rmtree(gloo_path)
self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
return 0, 0
def check_with_place(self,
......@@ -399,6 +459,7 @@ def runtime_main(test_class):
model = test_class()
role = model.build_role(args)
# for distributed inference
if args.test and args.model_dir != "":
avg_cost = model.net(args, is_train=False)
dist_infer = DistributedInfer()
......@@ -407,12 +468,16 @@ def runtime_main(test_class):
loss=model.avg_cost,
role_maker=role,
dirname=args.model_dir)
if fleet.is_worker():
with paddle.static.program_guard(
main_program=dist_infer.get_dist_infer_program()):
model.do_distributed_testing(fleet)
fleet.stop_worker()
return
return
if fleet.is_server():
return
fleet.init(role)
strategy = model.build_strategy(args)
......
......@@ -36,7 +36,9 @@ class TestDistMnistAsync2x2(TestFleetBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2"
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -71,7 +73,9 @@ class TestDistCtrHalfAsync2x2(TestFleetBase):
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"CPU_NUM": "2",
"SAVE_MODEL": "0"
"SAVE_MODEL": "0",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......
......@@ -38,7 +38,9 @@ class TestDistMnistSync2x2(TestFleetBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2"
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -75,7 +77,9 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase):
"dump_param": "concat_0.tmp_0",
"dump_fields": "dnn-fc-3.tmp_0,dnn-fc-3.tmp_0@GRAD",
"dump_fields_path": tempfile.mkdtemp(),
"Debug": "1"
"Debug": "1",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......
......@@ -42,7 +42,9 @@ class TestDistGeoCtr_2x2(TestFleetBase):
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
"http_proxy": "",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -55,7 +57,7 @@ class TestDistGeoCtr_2x2(TestFleetBase):
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
"dist_fleet_ctr.py", delta=1e-5, check_error_log=False)
class TestGeoSgdTranspiler(unittest.TestCase):
......
......@@ -27,17 +27,6 @@ class TestDistCtrInfer(TestFleetBase):
def _setup_config(self):
self._mode = "async"
self._reader = "pyreader"
self._need_test = 1
data_url = "https://fleet.bj.bcebos.com/unittest/ctr_saved_params.tar.gz"
data_md5 = "aa7e8286ced566ea8a67410be7482438"
module_name = "ctr_saved_params"
path = download(data_url, module_name, data_md5)
print('ctr_params is downloaded at ' + path)
tar = tarfile.open(path)
unzip_folder = tempfile.mkdtemp()
tar.extractall(unzip_folder)
self._model_dir = unzip_folder
def check_with_place(self,
model_file,
......@@ -53,6 +42,8 @@ class TestDistCtrInfer(TestFleetBase):
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -64,9 +55,21 @@ class TestDistCtrInfer(TestFleetBase):
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_infer(self):
model_dirname = tempfile.mkdtemp()
self.check_with_place(
"dist_fleet_ctr.py",
delta=1e-5,
check_error_log=False,
need_envs={"SAVE_DIRNAME": model_dirname, })
self._need_test = 1
self._model_dir = model_dirname
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=False)
shutil.rmtree(self._model_dir)
shutil.rmtree(model_dirname)
class TestDistCtrTrainInfer(TestFleetBase):
......@@ -80,6 +83,7 @@ class TestDistCtrTrainInfer(TestFleetBase):
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
......@@ -89,6 +93,8 @@ class TestDistCtrTrainInfer(TestFleetBase):
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......
......@@ -45,7 +45,9 @@ class TestDistMnistSync2x2(TestFleetBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2"
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -79,7 +81,9 @@ class TestDistMnistAsync2x2(TestFleetBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2"
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -114,7 +118,9 @@ class TestDistMnistAsync2x2WithDecay(TestFleetBase):
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2",
"DECAY": "0"
"DECAY": "0",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -149,7 +155,9 @@ class TestDistMnistAsync2x2WithUnifrom(TestFleetBase):
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "2",
"INITIALIZER": "1"
"INITIALIZER": "1",
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......@@ -264,6 +272,7 @@ class TestDistMnistAsync2x2WithGauss(TestFleetBase):
check_error_log=False,
need_envs={}):
model_dir = tempfile.mkdtemp()
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
......@@ -272,7 +281,9 @@ class TestDistMnistAsync2x2WithGauss(TestFleetBase):
"http_proxy": "",
"CPU_NUM": "2",
"INITIALIZER": "2",
"MODEL_DIR": model_dir
"MODEL_DIR": model_dir,
"LOG_DIRNAME": "/tmp",
"LOG_PREFIX": self.__class__.__name__,
}
required_envs.update(need_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册