提交 dda4217c 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #3296 from luotao1/test_TrainerOnePass

reduce time of test_TrainerOnePass
from paddle.trainer_config_helpers import * from paddle.trainer_config_helpers import *
settings(batch_size=128, learning_method=AdaGradOptimizer(), learning_rate=1e-4) settings(batch_size=17, learning_method=AdaGradOptimizer(), learning_rate=1e-4)
file_list = 'trainer/tests/fake_file_list.list' file_list = 'trainer/tests/fake_file_list.list'
...@@ -12,7 +12,7 @@ define_py_data_sources2( ...@@ -12,7 +12,7 @@ define_py_data_sources2(
embedding = embedding_layer( embedding = embedding_layer(
input=data_layer( input=data_layer(
name="word_ids", size=65536), name="word_ids", size=8191),
size=128, size=128,
param_attr=ParamAttr(sparse_update=True)) param_attr=ParamAttr(sparse_update=True))
prediction = fc_layer(input=embedding, size=10, act=SoftmaxActivation()) prediction = fc_layer(input=embedding, size=10, act=SoftmaxActivation())
......
...@@ -7,15 +7,15 @@ def init_hook(settings, is_train, **kwargs): ...@@ -7,15 +7,15 @@ def init_hook(settings, is_train, **kwargs):
@provider( @provider(
input_types={'word_ids': integer_value(65536), input_types={'word_ids': integer_value(8191),
'label': integer_value(10)}, 'label': integer_value(10)},
min_pool_size=0, min_pool_size=0,
init_hook=init_hook) init_hook=init_hook)
def process(settings, filename): def process(settings, filename):
if settings.is_train: if settings.is_train:
data_size = 2**20
else:
data_size = 2**10 data_size = 2**10
else:
data_size = 2**5
for _ in xrange(data_size): for _ in xrange(data_size):
yield random.randint(0, 65535), random.randint(0, 9) yield random.randint(0, 8190), random.randint(0, 9)
...@@ -100,25 +100,25 @@ TEST(average_window, gpu) { ...@@ -100,25 +100,25 @@ TEST(average_window, gpu) {
} }
TEST(average_window, gpu2) { TEST(average_window, gpu2) {
FLAGS_num_passes = 100; FLAGS_num_passes = 20;
trainerOnePassTest(configFile1, true, false, 2, 0.01); trainerOnePassTest(configFile1, true, false, 2, 0.01);
FLAGS_num_passes = 1; FLAGS_num_passes = 1;
} }
TEST(average_window, gpu4) { TEST(average_window, gpu4) {
FLAGS_num_passes = 100; FLAGS_num_passes = 20;
trainerOnePassTest(configFile1, true, false, 4, 0.01); trainerOnePassTest(configFile1, true, false, 4, 0.01);
FLAGS_num_passes = 1; FLAGS_num_passes = 1;
} }
TEST(average_window_cpu, gpu2) { TEST(average_window_cpu, gpu2) {
FLAGS_num_passes = 100; FLAGS_num_passes = 20;
trainerOnePassTest(configFile1, true, false, 2, 0.01, true); trainerOnePassTest(configFile1, true, false, 2, 0.01, true);
FLAGS_num_passes = 1; FLAGS_num_passes = 1;
} }
TEST(average_window_cpu, gpu4) { TEST(average_window_cpu, gpu4) {
FLAGS_num_passes = 100; FLAGS_num_passes = 20;
trainerOnePassTest(configFile1, true, false, 4, 0.01, true); trainerOnePassTest(configFile1, true, false, 4, 0.01, true);
FLAGS_num_passes = 1; FLAGS_num_passes = 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册