未验证 提交 97a87bb3 编写于 作者: G gongweibao 提交者: GitHub

Fix transformer unittest. (#13974)

Fix transformer unittest
上级 9cb8738f
......@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
# TODO: fix this test
#py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
py_test_modules(test_dist_transformer MODULES test_dist_transformer)
set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif()
......
......@@ -35,7 +35,7 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from test_dist_base import TestDistRunnerBase, runtime_main
from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
import paddle.compat as cpt
from paddle.compat import long_type
......@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
if batch_id >= 5:
if batch_id >= RUN_STEP:
break
feed_list = []
total_num_token = 0
#if TrainTaskConfig.local:
# lr_rate = lr_scheduler.update_learning_rate()
#for place_id, data_buffer in enumerate(
# split_data(
# data, num_part=dev_count)):
if TrainTaskConfig.local:
lr_rate = lr_scheduler.update_learning_rate()
......@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
init = True
# Validate and save the model for inference.
if batch_id == 0 or batch_id == 4:
if TrainTaskConfig.val_file_pattern is not None:
val_avg_cost, val_ppl = test()
print("[%f]" % val_avg_cost)
else:
assert (False)
if TrainTaskConfig.val_file_pattern is not None:
val_avg_cost, val_ppl = test()
print("[%f]" % val_avg_cost)
else:
assert (False)
#import transformer_reader as reader
......@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def run_trainer(self, args):
TrainTaskConfig.use_gpu = args.use_cuda
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
args.is_dist, not args.sync_mode)
if args.is_dist:
......
......@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase):
def test_dist_train(self):
download_files()
self.check_with_place("dist_transformer.py", delta=1e-5)
self.check_with_place(
"dist_transformer.py", delta=1e-5, check_error_log=False)
class TestDistTransformer2x2Async(TestDistBase):
......@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase):
def test_dist_train(self):
download_files()
self.check_with_place("dist_transformer.py", delta=1.0)
self.check_with_place(
"dist_transformer.py", delta=1.0, check_error_log=False)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册