未验证 提交 eca8a579 编写于 作者: Z zlsh80826 提交者: GitHub

Use tempfile to place the temporary files (#42626)

* Use tempfile to place the temporary files

* Revise test_bert to use tempfile for temporary files

* Use tempfile for test_transformer

* Fix test_dataset file race
上级 1ce9c2ba
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
import time import time
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -33,14 +34,24 @@ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace( ...@@ -33,14 +34,24 @@ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
SEED = 2020 SEED = 2020
STEP_NUM = 10 STEP_NUM = 10
PRINT_STEP = 2 PRINT_STEP = 2
MODEL_SAVE_DIR = "./inference"
MODEL_SAVE_PREFIX = "./inference/bert"
MODEL_FILENAME = "bert" + INFER_MODEL_SUFFIX
PARAMS_FILENAME = "bert" + INFER_PARAMS_SUFFIX
DY_STATE_DICT_SAVE_PATH = "./bert.dygraph"
def train(bert_config, data_reader, to_static): class TestBert(unittest.TestCase):
def setUp(self):
self.bert_config = get_bert_config()
self.data_reader = get_feed_data_reader(self.bert_config)
self.temp_dir = tempfile.TemporaryDirectory()
self.model_save_dir = os.path.join(self.temp_dir.name, 'inference')
self.model_save_prefix = os.path.join(self.model_save_dir, 'bert')
self.model_filename = 'bert' + INFER_MODEL_SUFFIX
self.params_filename = 'bert' + INFER_PARAMS_SUFFIX
self.dy_state_dict_save_path = os.path.join(self.temp_dir.name,
'bert.dygraph')
def tearDown(self):
self.temp_dir.cleanup()
def train(self, bert_config, data_reader, to_static):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
fluid.default_main_program().random_seed = SEED fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED fluid.default_startup_program().random_seed = SEED
...@@ -90,47 +101,44 @@ def train(bert_config, data_reader, to_static): ...@@ -90,47 +101,44 @@ def train(bert_config, data_reader, to_static):
step_idx += 1 step_idx += 1
if step_idx == STEP_NUM: if step_idx == STEP_NUM:
if to_static: if to_static:
fluid.dygraph.jit.save(bert, MODEL_SAVE_PREFIX) fluid.dygraph.jit.save(bert, self.model_save_prefix)
else: else:
fluid.dygraph.save_dygraph(bert.state_dict(), fluid.dygraph.save_dygraph(bert.state_dict(),
DY_STATE_DICT_SAVE_PATH) self.dy_state_dict_save_path)
break break
return loss, ppl return loss, ppl
def train_dygraph(self, bert_config, data_reader):
def train_dygraph(bert_config, data_reader):
program_translator.enable(False) program_translator.enable(False)
return train(bert_config, data_reader, False) return self.train(bert_config, data_reader, False)
def train_static(bert_config, data_reader): def train_static(self, bert_config, data_reader):
program_translator.enable(True) program_translator.enable(True)
return train(bert_config, data_reader, True) return self.train(bert_config, data_reader, True)
def predict_static(self, data):
def predict_static(data):
paddle.enable_static() paddle.enable_static()
exe = fluid.Executor(place) exe = fluid.Executor(place)
# load inference model # load inference model
[inference_program, feed_target_names, [inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model( fetch_targets] = fluid.io.load_inference_model(
MODEL_SAVE_DIR, self.model_save_dir,
executor=exe, executor=exe,
model_filename=MODEL_FILENAME, model_filename=self.model_filename,
params_filename=PARAMS_FILENAME) params_filename=self.params_filename)
pred_res = exe.run(inference_program, pred_res = exe.run(inference_program,
feed=dict(zip(feed_target_names, data)), feed=dict(zip(feed_target_names, data)),
fetch_list=fetch_targets) fetch_list=fetch_targets)
return pred_res return pred_res
def predict_dygraph(self, bert_config, data):
def predict_dygraph(bert_config, data):
program_translator.enable(False) program_translator.enable(False)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
bert = PretrainModelLayer( bert = PretrainModelLayer(
config=bert_config, weight_sharing=False, use_fp16=False) config=bert_config, weight_sharing=False, use_fp16=False)
model_dict, _ = fluid.dygraph.load_dygraph(DY_STATE_DICT_SAVE_PATH) model_dict, _ = fluid.dygraph.load_dygraph(
self.dy_state_dict_save_path)
bert.set_dict(model_dict) bert.set_dict(model_dict)
bert.eval() bert.eval()
...@@ -149,10 +157,9 @@ def predict_dygraph(bert_config, data): ...@@ -149,10 +157,9 @@ def predict_dygraph(bert_config, data):
return pred_res return pred_res
def predict_dygraph_jit(self, data):
def predict_dygraph_jit(data):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
bert = fluid.dygraph.jit.load(MODEL_SAVE_PREFIX) bert = fluid.dygraph.jit.load(self.model_save_prefix)
bert.eval() bert.eval()
src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = data src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = data
...@@ -162,23 +169,16 @@ def predict_dygraph_jit(data): ...@@ -162,23 +169,16 @@ def predict_dygraph_jit(data):
return pred_res return pred_res
def predict_analysis_inference(self, data):
def predict_analysis_inference(data): output = PredictorTools(self.model_save_dir, self.model_filename,
output = PredictorTools(MODEL_SAVE_DIR, MODEL_FILENAME, PARAMS_FILENAME, self.params_filename, data)
data)
out = output() out = output()
return out return out
class TestBert(unittest.TestCase):
def setUp(self):
self.bert_config = get_bert_config()
self.data_reader = get_feed_data_reader(self.bert_config)
def test_train(self): def test_train(self):
static_loss, static_ppl = train_static(self.bert_config, static_loss, static_ppl = self.train_static(self.bert_config,
self.data_reader) self.data_reader)
dygraph_loss, dygraph_ppl = train_dygraph(self.bert_config, dygraph_loss, dygraph_ppl = self.train_dygraph(self.bert_config,
self.data_reader) self.data_reader)
self.assertTrue( self.assertTrue(
np.allclose(static_loss, dygraph_loss), np.allclose(static_loss, dygraph_loss),
...@@ -193,10 +193,10 @@ class TestBert(unittest.TestCase): ...@@ -193,10 +193,10 @@ class TestBert(unittest.TestCase):
def verify_predict(self): def verify_predict(self):
for data in self.data_reader.data_generator()(): for data in self.data_reader.data_generator()():
dygraph_pred_res = predict_dygraph(self.bert_config, data) dygraph_pred_res = self.predict_dygraph(self.bert_config, data)
static_pred_res = predict_static(data) static_pred_res = self.predict_static(data)
dygraph_jit_pred_res = predict_dygraph_jit(data) dygraph_jit_pred_res = self.predict_dygraph_jit(data)
predictor_pred_res = predict_analysis_inference(data) predictor_pred_res = self.predict_analysis_inference(data)
for dy_res, st_res, dy_jit_res, predictor_res in zip( for dy_res, st_res, dy_jit_res, predictor_res in zip(
dygraph_pred_res, static_pred_res, dygraph_jit_pred_res, dygraph_pred_res, static_pred_res, dygraph_jit_pred_res,
......
...@@ -18,8 +18,7 @@ import unittest ...@@ -18,8 +18,7 @@ import unittest
import numpy as np import numpy as np
from paddle.jit import ProgramTranslator from paddle.jit import ProgramTranslator
from test_resnet import ResNet, train, predict_dygraph_jit from test_resnet import ResNet, ResNetHelper
from test_resnet import predict_dygraph, predict_static, predict_analysis_inference
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
...@@ -31,20 +30,20 @@ class TestResnetWithPass(unittest.TestCase): ...@@ -31,20 +30,20 @@ class TestResnetWithPass(unittest.TestCase):
self.build_strategy.fuse_bn_act_ops = True self.build_strategy.fuse_bn_act_ops = True
self.build_strategy.fuse_bn_add_act_ops = True self.build_strategy.fuse_bn_add_act_ops = True
self.build_strategy.enable_addto = True self.build_strategy.enable_addto = True
self.resnet_helper = ResNetHelper()
# NOTE: for enable_addto # NOTE: for enable_addto
paddle.fluid.set_flags({"FLAGS_max_inplace_grad_add": 8}) paddle.fluid.set_flags({"FLAGS_max_inplace_grad_add": 8})
def train(self, to_static): def train(self, to_static):
program_translator.enable(to_static) program_translator.enable(to_static)
return self.resnet_helper.train(to_static, self.build_strategy)
return train(to_static, self.build_strategy)
def verify_predict(self): def verify_predict(self):
image = np.random.random([1, 3, 224, 224]).astype('float32') image = np.random.random([1, 3, 224, 224]).astype('float32')
dy_pre = predict_dygraph(image) dy_pre = self.resnet_helper.predict_dygraph(image)
st_pre = predict_static(image) st_pre = self.resnet_helper.predict_static(image)
dy_jit_pre = predict_dygraph_jit(image) dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
predictor_pre = predict_analysis_inference(image) predictor_pre = self.resnet_helper.predict_analysis_inference(image)
self.assertTrue( self.assertTrue(
np.allclose(dy_pre, st_pre), np.allclose(dy_pre, st_pre),
msg="dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre)) msg="dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre))
...@@ -69,7 +68,7 @@ class TestResnetWithPass(unittest.TestCase): ...@@ -69,7 +68,7 @@ class TestResnetWithPass(unittest.TestCase):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True}) paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try: try:
if paddle.fluid.core.is_compiled_with_mkldnn(): if paddle.fluid.core.is_compiled_with_mkldnn():
train(True, self.build_strategy) self.resnet_helper.train(True, self.build_strategy)
finally: finally:
paddle.fluid.set_flags({'FLAGS_use_mkldnn': False}) paddle.fluid.set_flags({'FLAGS_use_mkldnn': False})
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
from __future__ import print_function from __future__ import print_function
import os
import math import math
import time import time
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -39,11 +41,6 @@ epoch_num = 1 ...@@ -39,11 +41,6 @@ epoch_num = 1
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
else fluid.CPUPlace() else fluid.CPUPlace()
MODEL_SAVE_DIR = "./inference"
MODEL_SAVE_PREFIX = "./inference/resnet"
MODEL_FILENAME = "resnet" + INFER_MODEL_SUFFIX
PARAMS_FILENAME = "resnet" + INFER_PARAMS_SUFFIX
DY_STATE_DICT_SAVE_PATH = "./resnet.dygraph"
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
if fluid.is_compiled_with_cuda(): if fluid.is_compiled_with_cuda():
...@@ -212,7 +209,20 @@ def reader_decorator(reader): ...@@ -212,7 +209,20 @@ def reader_decorator(reader):
return __reader__ return __reader__
def train(to_static, build_strategy=None): class ResNetHelper:
def __init__(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.model_save_dir = os.path.join(self.temp_dir.name, 'inference')
self.model_save_prefix = os.path.join(self.model_save_dir, 'resnet')
self.model_filename = 'resnet' + INFER_MODEL_SUFFIX
self.params_filename = 'resnet' + INFER_PARAMS_SUFFIX
self.dy_state_dict_save_path = os.path.join(self.temp_dir.name,
'resnet.dygraph')
def __del__(self):
self.temp_dir.cleanup()
def train(self, to_static, build_strategy=None):
""" """
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode. Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
""" """
...@@ -231,7 +241,8 @@ def train(to_static, build_strategy=None): ...@@ -231,7 +241,8 @@ def train(to_static, build_strategy=None):
resnet = ResNet() resnet = ResNet()
if to_static: if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy) resnet = paddle.jit.to_static(
resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters()) optimizer = optimizer_setting(parameter_list=resnet.parameters())
for epoch in range(epoch_num): for epoch in range(epoch_num):
...@@ -247,8 +258,10 @@ def train(to_static, build_strategy=None): ...@@ -247,8 +258,10 @@ def train(to_static, build_strategy=None):
pred = resnet(img) pred = resnet(img)
loss = fluid.layers.cross_entropy(input=pred, label=label) loss = fluid.layers.cross_entropy(input=pred, label=label)
avg_loss = fluid.layers.mean(x=loss) avg_loss = fluid.layers.mean(x=loss)
acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1) acc_top1 = fluid.layers.accuracy(
acc_top5 = fluid.layers.accuracy(input=pred, label=label, k=5) input=pred, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=pred, label=label, k=5)
avg_loss.backward() avg_loss.backward()
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
...@@ -266,23 +279,25 @@ def train(to_static, build_strategy=None): ...@@ -266,23 +279,25 @@ def train(to_static, build_strategy=None):
total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time)) total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time))
if batch_id == 10: if batch_id == 10:
if to_static: if to_static:
fluid.dygraph.jit.save(resnet, MODEL_SAVE_PREFIX) fluid.dygraph.jit.save(resnet,
self.model_save_prefix)
else: else:
fluid.dygraph.save_dygraph(resnet.state_dict(), fluid.dygraph.save_dygraph(
DY_STATE_DICT_SAVE_PATH) resnet.state_dict(),
self.dy_state_dict_save_path)
# avoid dataloader throw abort signaal # avoid dataloader throw abort signaal
data_loader._reset() data_loader._reset()
break break
return total_loss.numpy() return total_loss.numpy()
def predict_dygraph(self, data):
def predict_dygraph(data):
program_translator.enable(False) program_translator.enable(False)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
resnet = ResNet() resnet = ResNet()
model_dict, _ = fluid.dygraph.load_dygraph(DY_STATE_DICT_SAVE_PATH) model_dict, _ = fluid.dygraph.load_dygraph(
self.dy_state_dict_save_path)
resnet.set_dict(model_dict) resnet.set_dict(model_dict)
resnet.eval() resnet.eval()
...@@ -290,16 +305,15 @@ def predict_dygraph(data): ...@@ -290,16 +305,15 @@ def predict_dygraph(data):
return pred_res.numpy() return pred_res.numpy()
def predict_static(self, data):
def predict_static(data):
paddle.enable_static() paddle.enable_static()
exe = fluid.Executor(place) exe = fluid.Executor(place)
[inference_program, feed_target_names, [inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model( fetch_targets] = fluid.io.load_inference_model(
MODEL_SAVE_DIR, self.model_save_dir,
executor=exe, executor=exe,
model_filename=MODEL_FILENAME, model_filename=self.model_filename,
params_filename=PARAMS_FILENAME) params_filename=self.params_filename)
pred_res = exe.run(inference_program, pred_res = exe.run(inference_program,
feed={feed_target_names[0]: data}, feed={feed_target_names[0]: data},
...@@ -307,35 +321,36 @@ def predict_static(data): ...@@ -307,35 +321,36 @@ def predict_static(data):
return pred_res[0] return pred_res[0]
def predict_dygraph_jit(self, data):
def predict_dygraph_jit(data):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
resnet = fluid.dygraph.jit.load(MODEL_SAVE_PREFIX) resnet = fluid.dygraph.jit.load(self.model_save_prefix)
resnet.eval() resnet.eval()
pred_res = resnet(data) pred_res = resnet(data)
return pred_res.numpy() return pred_res.numpy()
def predict_analysis_inference(self, data):
def predict_analysis_inference(data): output = PredictorTools(self.model_save_dir, self.model_filename,
output = PredictorTools(MODEL_SAVE_DIR, MODEL_FILENAME, PARAMS_FILENAME, self.params_filename, [data])
[data])
out = output() out = output()
return out return out
class TestResnet(unittest.TestCase): class TestResnet(unittest.TestCase):
def setUp(self):
self.resnet_helper = ResNetHelper()
def train(self, to_static): def train(self, to_static):
program_translator.enable(to_static) program_translator.enable(to_static)
return train(to_static) return self.resnet_helper.train(to_static)
def verify_predict(self): def verify_predict(self):
image = np.random.random([1, 3, 224, 224]).astype('float32') image = np.random.random([1, 3, 224, 224]).astype('float32')
dy_pre = predict_dygraph(image) dy_pre = self.resnet_helper.predict_dygraph(image)
st_pre = predict_static(image) st_pre = self.resnet_helper.predict_static(image)
dy_jit_pre = predict_dygraph_jit(image) dy_jit_pre = self.resnet_helper.predict_dygraph_jit(image)
predictor_pre = predict_analysis_inference(image) predictor_pre = self.resnet_helper.predict_analysis_inference(image)
self.assertTrue( self.assertTrue(
np.allclose(dy_pre, st_pre), np.allclose(dy_pre, st_pre),
msg="dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre)) msg="dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre))
...@@ -360,7 +375,7 @@ class TestResnet(unittest.TestCase): ...@@ -360,7 +375,7 @@ class TestResnet(unittest.TestCase):
fluid.set_flags({'FLAGS_use_mkldnn': True}) fluid.set_flags({'FLAGS_use_mkldnn': True})
try: try:
if paddle.fluid.core.is_compiled_with_mkldnn(): if paddle.fluid.core.is_compiled_with_mkldnn():
train(to_static=True) self.resnet_helper.train(to_static=True)
finally: finally:
fluid.set_flags({'FLAGS_use_mkldnn': False}) fluid.set_flags({'FLAGS_use_mkldnn': False})
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import logging import logging
import os import os
import time import time
import tempfile
import unittest import unittest
import numpy as np import numpy as np
...@@ -371,8 +372,21 @@ def predict_static(args, batch_generator): ...@@ -371,8 +372,21 @@ def predict_static(args, batch_generator):
class TestTransformer(unittest.TestCase): class TestTransformer(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
def tearDwon(self):
self.temp_dir.cleanup()
def prepare(self, mode='train'): def prepare(self, mode='train'):
args = util.ModelHyperParams() args = util.ModelHyperParams()
args.save_dygraph_model_path = os.path.join(
self.temp_dir.name, args.save_dygraph_model_path)
args.save_static_model_path = os.path.join(self.temp_dir.name,
args.save_static_model_path)
args.inference_model_dir = os.path.join(self.temp_dir.name,
args.inference_model_dir)
args.output_file = os.path.join(self.temp_dir.name, args.output_file)
batch_generator = util.get_feed_data_reader(args, mode) batch_generator = util.get_feed_data_reader(args, mode)
return args, batch_generator return args, batch_generator
......
...@@ -24,6 +24,7 @@ import paddle.fluid.core as core ...@@ -24,6 +24,7 @@ import paddle.fluid.core as core
import numpy as np import numpy as np
import os import os
import shutil import shutil
import tempfile
import unittest import unittest
...@@ -82,12 +83,17 @@ class TestDataset(unittest.TestCase): ...@@ -82,12 +83,17 @@ class TestDataset(unittest.TestCase):
""" """
Testcase for InMemoryDataset from create to run. Testcase for InMemoryDataset from create to run.
""" """
with open("test_run_with_dump_a.txt", "w") as f:
temp_dir = tempfile.TemporaryDirectory()
dump_a_path = os.path.join(temp_dir.name, 'test_run_with_dump_a.txt')
dump_b_path = os.path.join(temp_dir.name, 'test_run_with_dump_b.txt')
with open(dump_a_path, "w") as f:
data = "1 a 1 a 1 1 2 3 3 4 5 5 5 5 1 1\n" data = "1 a 1 a 1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 b 1 b 1 2 2 3 4 4 6 6 6 6 1 2\n" data += "1 b 1 b 1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 c 1 c 1 3 2 3 5 4 7 7 7 7 1 3\n" data += "1 c 1 c 1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data) f.write(data)
with open("test_run_with_dump_b.txt", "w") as f: with open(dump_b_path, "w") as f:
data = "1 d 1 d 1 4 2 3 3 4 5 5 5 5 1 4\n" data = "1 d 1 d 1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 e 1 e 1 5 2 3 4 4 6 6 6 6 1 5\n" data += "1 e 1 e 1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 f 1 f 1 6 2 3 5 4 7 7 7 7 1 6\n" data += "1 f 1 f 1 6 2 3 5 4 7 7 7 7 1 6\n"
...@@ -110,8 +116,7 @@ class TestDataset(unittest.TestCase): ...@@ -110,8 +116,7 @@ class TestDataset(unittest.TestCase):
parse_content=True, parse_content=True,
fea_eval=True, fea_eval=True,
candidate_size=10000) candidate_size=10000)
dataset.set_filelist( dataset.set_filelist([dump_a_path, dump_b_path])
["test_run_with_dump_a.txt", "test_run_with_dump_b.txt"])
dataset.load_into_memory() dataset.load_into_memory()
dataset.local_shuffle() dataset.local_shuffle()
...@@ -129,8 +134,7 @@ class TestDataset(unittest.TestCase): ...@@ -129,8 +134,7 @@ class TestDataset(unittest.TestCase):
except Exception as e: except Exception as e:
self.assertTrue(False) self.assertTrue(False)
os.remove("./test_run_with_dump_a.txt") temp_dir.cleanup()
os.remove("./test_run_with_dump_b.txt")
def test_dataset_config(self): def test_dataset_config(self):
""" Testcase for dataset configuration. """ """ Testcase for dataset configuration. """
......
...@@ -25,6 +25,7 @@ import random ...@@ -25,6 +25,7 @@ import random
import math import math
import os import os
import shutil import shutil
import tempfile
import unittest import unittest
import paddle.fluid.incubate.data_generator as dg import paddle.fluid.incubate.data_generator as dg
...@@ -282,7 +283,11 @@ class TestDataset(unittest.TestCase): ...@@ -282,7 +283,11 @@ class TestDataset(unittest.TestCase):
""" """
Testcase for InMemoryDataset of consistency insepection of use_var_list and data_generator. Testcase for InMemoryDataset of consistency insepection of use_var_list and data_generator.
""" """
with open("test_run_with_dump_a.txt", "w") as f:
temp_dir = tempfile.TemporaryDirectory()
dump_a_path = os.path.join(temp_dir.name, 'test_run_with_dump_a.txt')
with open(dump_a_path, "w") as f:
# data = "\n" # data = "\n"
# data += "\n" # data += "\n"
data = "2 1;1 9;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;0;40000001;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000200;10000200;10063938;10000008;10000177;20002001 20001240 20001860 20003611 20010833 20000210 20000500 20000401 20000251 20012198 20001023 20000157;20002001 20001240 20001860 20003611 20012396 20000500 20002513 20012198 20001023 20000157;10000123;30000004;0.623 0.233 0.290 0.208 0.354 49.000 0.000 0.000 0.000 -1.000 0.569 0.679 0.733 53 17 2 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000047;30000004;0.067 0.000 0.161 0.005 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.378 0.043 0 6 0 0;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;10000200;30000001;0.407 0.111 0.196 0.095 0.181 49.000 0.000 0.000 0.000 -1.000 0.306 0.538 0.355 48 8 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;10000200;30000001;0.226 0.029 0.149 0.031 0.074 49.000 0.000 0.000 0.000 -1.000 0.220 0.531 0.286 26 6 0 0;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;10063938;30000001;0.250 0.019 0.138 0.012 0.027 49.000 0.000 0.000 0.000 -1.000 0.370 0.449 0.327 7 2 0 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000003;30000002;0.056 0.000 0.139 0.003 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.346 0.059 15 3 0 0;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;10000008;30000001;0.166 0.004 0.127 0.001 0.004 49.000 0.000 0.000 0.000 -1.000 0.103 0.417 0.394 10 3 0 0;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000177;30000001;0.094 0.008 0.157 0.012 0.059 49.000 0.000 0.000 0.000 -1.000 0.051 0.382 0.142 21 0 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;10000134;30000001;0.220 0.016 0.181 0.037 0.098 49.000 0.000 0.000 0.000 -1.000 0.192 0.453 0.199 17 1 0 0;20002001 20001240 20001860 20003611 20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000638;30000001;0.000 0.000 0.000 0.000 0.000 49.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0 0 0 0;\n" data = "2 1;1 9;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;0;40000001;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000200;10000200;10063938;10000008;10000177;20002001 20001240 20001860 20003611 20010833 20000210 20000500 20000401 20000251 20012198 20001023 20000157;20002001 20001240 20001860 20003611 20012396 20000500 20002513 20012198 20001023 20000157;10000123;30000004;0.623 0.233 0.290 0.208 0.354 49.000 0.000 0.000 0.000 -1.000 0.569 0.679 0.733 53 17 2 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000047;30000004;0.067 0.000 0.161 0.005 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.378 0.043 0 6 0 0;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20000723 20000070 20003519 20000005;10000200;30000001;0.407 0.111 0.196 0.095 0.181 49.000 0.000 0.000 0.000 -1.000 0.306 0.538 0.355 48 8 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20002616 20000157 20000005;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20003519 20000005;10000200;30000001;0.226 0.029 0.149 0.031 0.074 49.000 0.000 0.000 0.000 -1.000 0.220 0.531 0.286 26 6 0 0;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20012788 20000157;20002001 20001240 20001860 20003611 20000723 20000070 20002001 20001240 20001860 20003611 20131464;10063938;30000001;0.250 0.019 0.138 0.012 0.027 49.000 0.000 0.000 0.000 -1.000 0.370 0.449 0.327 7 2 0 0;20002001 20001240 20001860 20003611 20000723;20002001 20001240 20001860 20003611 20000723;10000003;30000002;0.056 0.000 0.139 0.003 0.000 49.000 0.000 0.000 0.000 -1.000 0.000 0.346 0.059 15 3 0 0;20002001 20001240 20001860 20003611 20000623 20000251 20000157 20000723 20000070 20000001 20000057;20002001 20001240 20001860 20003611 20018820 20000157 20000723 20000070 20000001 20000057;10000008;30000001;0.166 0.004 0.127 0.001 0.004 49.000 0.000 0.000 0.000 -1.000 0.103 0.417 0.394 10 3 0 0;20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000177;30000001;0.094 0.008 0.157 0.012 0.059 49.000 0.000 0.000 0.000 -1.000 0.051 0.382 0.142 21 0 0 0;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;20002001 20001240 20001860 20003611 20000157 20001776 20000070 20000157;10000134;30000001;0.220 0.016 0.181 0.037 0.098 49.000 0.000 0.000 0.000 -1.000 0.192 0.453 0.199 17 1 0 0;20002001 20001240 20001860 20003611 20002640 20004695 20000157 20000723 20000070 20002001 20001240 20001860 20003611;20002001 20001240 20001860 20003611 20002640 20034154 20000723 20000070 20002001 20001240 20001860 20003611;10000638;30000001;0.000 0.000 0.000 0.000 0.000 49.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 0 0 0 0;\n"
...@@ -348,7 +353,7 @@ class TestDataset(unittest.TestCase): ...@@ -348,7 +353,7 @@ class TestDataset(unittest.TestCase):
generator_class = CTRDataset(mode=0) generator_class = CTRDataset(mode=0)
try: try:
dataset._check_use_var_with_data_generator( dataset._check_use_var_with_data_generator(
slot_data, generator_class, "test_run_with_dump_a.txt") slot_data, generator_class, dump_a_path)
print("case 1: check passed!") print("case 1: check passed!")
except Exception as e: except Exception as e:
print("warning: catch expected error") print("warning: catch expected error")
...@@ -360,7 +365,7 @@ class TestDataset(unittest.TestCase): ...@@ -360,7 +365,7 @@ class TestDataset(unittest.TestCase):
generator_class = CTRDataset(mode=2) generator_class = CTRDataset(mode=2)
try: try:
dataset._check_use_var_with_data_generator( dataset._check_use_var_with_data_generator(
slot_data, generator_class, "test_run_with_dump_a.txt") slot_data, generator_class, dump_a_path)
except Exception as e: except Exception as e:
print("warning: case 2 catch expected error") print("warning: case 2 catch expected error")
print(e) print(e)
...@@ -371,7 +376,7 @@ class TestDataset(unittest.TestCase): ...@@ -371,7 +376,7 @@ class TestDataset(unittest.TestCase):
generator_class = CTRDataset(mode=3) generator_class = CTRDataset(mode=3)
try: try:
dataset._check_use_var_with_data_generator( dataset._check_use_var_with_data_generator(
slot_data, generator_class, "test_run_with_dump_a.txt") slot_data, generator_class, dump_a_path)
except Exception as e: except Exception as e:
print("warning: case 3 catch expected error") print("warning: case 3 catch expected error")
print(e) print(e)
...@@ -382,7 +387,7 @@ class TestDataset(unittest.TestCase): ...@@ -382,7 +387,7 @@ class TestDataset(unittest.TestCase):
generator_class = CTRDataset(mode=4) generator_class = CTRDataset(mode=4)
try: try:
dataset._check_use_var_with_data_generator( dataset._check_use_var_with_data_generator(
slot_data, generator_class, "test_run_with_dump_a.txt") slot_data, generator_class, dump_a_path)
except Exception as e: except Exception as e:
print("warning: case 4 catch expected error") print("warning: case 4 catch expected error")
print(e) print(e)
...@@ -393,13 +398,13 @@ class TestDataset(unittest.TestCase): ...@@ -393,13 +398,13 @@ class TestDataset(unittest.TestCase):
generator_class = CTRDataset(mode=5) generator_class = CTRDataset(mode=5)
try: try:
dataset._check_use_var_with_data_generator( dataset._check_use_var_with_data_generator(
slot_data, generator_class, "test_run_with_dump_a.txt") slot_data, generator_class, dump_a_path)
except Exception as e: except Exception as e:
print("warning: case 5 catch expected error") print("warning: case 5 catch expected error")
print(e) print(e)
print("========================================") print("========================================")
os.remove("./test_run_with_dump_a.txt") temp_dir.cleanup()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册