未验证 提交 97a87bb3 编写于 作者: G gongweibao 提交者: GitHub

Fix transformer unittest. (#13974)

Fix transformer unittest
上级 9cb8738f
...@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE) ...@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
# TODO: fix this test
#py_test_modules(test_dist_transformer MODULES test_dist_transformer) py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif(NOT APPLE) endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
......
...@@ -35,7 +35,7 @@ import paddle ...@@ -35,7 +35,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid import core from paddle.fluid import core
from test_dist_base import TestDistRunnerBase, runtime_main from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
import paddle.compat as cpt import paddle.compat as cpt
from paddle.compat import long_type from paddle.compat import long_type
...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num): for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time() pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
if batch_id >= 5: if batch_id >= RUN_STEP:
break break
feed_list = [] feed_list = []
total_num_token = 0 total_num_token = 0
#if TrainTaskConfig.local:
# lr_rate = lr_scheduler.update_learning_rate()
#for place_id, data_buffer in enumerate(
# split_data(
# data, num_part=dev_count)):
if TrainTaskConfig.local: if TrainTaskConfig.local:
lr_rate = lr_scheduler.update_learning_rate() lr_rate = lr_scheduler.update_learning_rate()
...@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
init = True init = True
# Validate and save the model for inference. # Validate and save the model for inference.
if batch_id == 0 or batch_id == 4: if TrainTaskConfig.val_file_pattern is not None:
if TrainTaskConfig.val_file_pattern is not None: val_avg_cost, val_ppl = test()
val_avg_cost, val_ppl = test() print("[%f]" % val_avg_cost)
print("[%f]" % val_avg_cost) else:
else: assert (False)
assert (False)
#import transformer_reader as reader #import transformer_reader as reader
...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase): ...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def run_trainer(self, args): def run_trainer(self, args):
TrainTaskConfig.use_gpu = args.use_cuda TrainTaskConfig.use_gpu = args.use_cuda
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
args.is_dist, not args.sync_mode) args.is_dist, not args.sync_mode)
if args.is_dist: if args.is_dist:
......
...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase): ...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1e-5) self.check_with_place(
"dist_transformer.py", delta=1e-5, check_error_log=False)
class TestDistTransformer2x2Async(TestDistBase): class TestDistTransformer2x2Async(TestDistBase):
...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase): ...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1.0) self.check_with_place(
"dist_transformer.py", delta=1.0, check_error_log=False)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册