Commit 22b9b666 authored by Yu Yang

Add unittest to cover SgdThreadUpdater's enableBufType

Parent dd894c29
......@@ -55,6 +55,9 @@ void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
// not create parameter buf for PARAMETER_GRADIENT for sparse update in
// Parameter::enableType(). But gradient parameter buf is still used
// in SgdThreadUpdater. We need to explicitly create it.
//
// The AverageOptimizer::restore/apply method will use PARAMETER_GRADIENT
// as a temp buffer.
para->enableBufType(PARAMETER_GRADIENT);
}
}
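The comment above captures the subtlety this commit tests: for sparse updates, Parameter::enableType() skips allocating the dense PARAMETER_GRADIENT buffer, yet SgdThreadUpdater (and AverageOptimizer::restore/apply, which uses it as scratch space) still touches that buffer, so it must be created explicitly. The following self-contained sketch illustrates the lazy buffer-type pattern; ToyParameter and its members are hypothetical stand-ins for illustration, not Paddle's actual API:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for Paddle's parameter buffer types (hypothetical).
enum BufType { PARAMETER_VALUE = 0, PARAMETER_GRADIENT = 1, NUM_BUF_TYPES = 2 };

class ToyParameter {
 public:
  ToyParameter(size_t size, bool sparseUpdate)
      : size_(size), bufs_(NUM_BUF_TYPES) {
    // Mirrors the behavior described for Parameter::enableType():
    // sparse-update parameters skip the dense gradient buffer because
    // their gradients live in sparse row storage instead.
    enableBufType(PARAMETER_VALUE);
    if (!sparseUpdate) enableBufType(PARAMETER_GRADIENT);
  }

  // Idempotent: allocates the buffer only if it does not exist yet.
  void enableBufType(BufType type) {
    if (bufs_[type].empty()) bufs_[type].assign(size_, 0.0f);
  }

  std::vector<float>& getBuf(BufType type) {
    assert(!bufs_[type].empty() && "buffer was never enabled");
    return bufs_[type];
  }

 private:
  size_t size_;
  std::vector<std::vector<float>> bufs_;
};

int main() {
  ToyParameter sparsePara(/*size=*/8, /*sparseUpdate=*/true);
  // Without this explicit call, an updater that uses PARAMETER_GRADIENT
  // as a temp buffer would trip the assertion in getBuf().
  sparsePara.enableBufType(PARAMETER_GRADIENT);
  sparsePara.getBuf(PARAMETER_GRADIENT)[0] = 1.0f;  // now safe
  return 0;
}
```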
......
trainer/tests/simple_sparse_neural_network.py
from paddle.trainer_config_helpers import *
settings(batch_size=128, learning_method=AdaGradOptimizer(), learning_rate=1e-4)
file_list = 'trainer/tests/fake_file_list.list'
define_py_data_sources2(
train_list=file_list,
test_list=file_list,
module="simple_sparse_neural_network_dp",
obj="process")
# sparse_update=True makes this embedding use sparse gradient updates,
# which is exactly the path that requires SgdThreadUpdater to create the
# PARAMETER_GRADIENT buffer explicitly.
embedding = embedding_layer(
    input=data_layer(
        name="word_ids", size=65536),
    size=128,
    param_attr=ParamAttr(sparse_update=True))
prediction = fc_layer(input=embedding, size=10, act=SoftmaxActivation())
outputs(
classification_cost(
input=prediction, label=data_layer(
name='label', size=10)))
......
trainer/tests/simple_sparse_neural_network_dp.py
from paddle.trainer.PyDataProvider2 import provider, integer_sequence, integer_value
import random
def init_hook(settings, is_train, **kwargs):
    # Record whether this provider instance feeds training or testing so
    # that process() can choose the sample count accordingly.
    settings.is_train = is_train
@provider(
input_types={'word_ids': integer_value(65536),
'label': integer_value(10)},
min_pool_size=0,
init_hook=init_hook)
def process(settings, filename):
    # Generate random samples; the file list is fake, so no real data is
    # read. Training yields 2**20 samples, testing 2**10.
    if settings.is_train:
        data_size = 2**20
    else:
        data_size = 2**10
    for _ in xrange(data_size):
        yield random.randint(0, 65535), random.randint(0, 9)
......@@ -27,6 +27,9 @@ static const string& configFile1 = "trainer/tests/sample_trainer_config.conf";
static const string& configFile2 =
"trainer/tests/sample_trainer_config_parallel.conf";
static const string& configFileSimpleSparse =
"trainer/tests/simple_sparse_neural_network.py";
DECLARE_bool(use_gpu);
DECLARE_string(config);
DECLARE_int32(gpu_id);
......@@ -298,11 +301,15 @@ TEST(checkRemoteUpdater, cpuDeltaTrainerOldUpdater) {
checkRemoteParameterUpdaterTest(configFile1, false, false, 1, true, 10);
}

TEST(SgdThreadUpdater, simpleSparseNN) {
  trainerOnePassTest(configFileSimpleSparse, false, false, 1, 0.5, true);
}

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  initMain(argc, argv);
  initPython(argc, argv);
  gNumDevices = hl_get_device_count();

  FLAGS_num_passes = 1;          // train one pass
  FLAGS_saving_period = 100000;  // do not save parameters
......
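The test wires its configuration through gflags: the DECLARE_* lines reference flags DEFINE*d in other translation units, and main() overrides them programmatically (FLAGS_num_passes, FLAGS_saving_period) before the tests run. A minimal sketch of that pattern, with the definition and declaration collapsed into one file for brevity (the flag name is reused here purely for illustration):

```cpp
#include <gflags/gflags.h>
#include <iostream>

// Normally DEFINE_* lives in one translation unit (e.g. the trainer) and
// DECLARE_* in another (e.g. the test); both are shown here for brevity.
DEFINE_int32(num_passes, 10, "number of training passes");
DECLARE_int32(num_passes);

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  // A test may override a flag programmatically, as test main() above
  // does, instead of relying on the command line.
  FLAGS_num_passes = 1;  // train one pass only
  std::cout << "num_passes = " << FLAGS_num_passes << std::endl;
  return 0;
}
```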