提交 ea25eef3 编写于 作者: Q qiaolongfei

word2vec demo support sparse remote update

上级 f6c5b6fd
......@@ -2,26 +2,38 @@ import math
import paddle.v2 as paddle
dictsize = 1953
embsize = 32
hiddensize = 256
N = 5
def wordemb(inlayer):
wordemb = paddle.layer.table_projection(
wordemb = paddle.layer.embedding(
input=inlayer,
size=embsize,
param_attr=paddle.attr.Param(
name="_proj",
initial_std=0.001,
learning_rate=1,
l2_rate=0, ))
l2_rate=0,
sparse_update=True))
return wordemb
def main():
paddle.init(use_gpu=False, trainer_count=1)
# for local training
cluster_train = False
if not cluster_train:
paddle.init(use_gpu=False, trainer_count=1)
else:
paddle.init(
use_gpu=False,
trainer_count=1,
port=7164,
ports_num=1,
ports_num_for_sparse=1,
num_gradient_servers=1)
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
firstword = paddle.layer.data(
......@@ -65,11 +77,15 @@ def main():
result.metrics)
cost = paddle.layer.classification_cost(input=predictword, label=nextword)
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(
adagrad = paddle.optimizer.AdaGrad(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
trainer = paddle.trainer.SGD(cost,
parameters,
adagrad,
is_local=not cluster_train)
trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
num_passes=30,
......
......@@ -821,7 +821,7 @@ public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
int passCount,
bool userSparseUpdater);
bool useSparseUpdater);
~ParameterUpdater();
/**
......
......@@ -29,11 +29,11 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater(
}
ParameterUpdater *ParameterUpdater::createRemoteUpdater(
OptimizationConfig *config, int passCount, bool userSparseUpdater) {
OptimizationConfig *config, int passCount, bool useSparseUpdater) {
auto updater = new ParameterUpdater();
auto remoteUpdater = new paddle::RemoteParameterUpdater(
config->m->getConfig(), passCount, nullptr);
if (userSparseUpdater) {
if (useSparseUpdater) {
std::unique_ptr<paddle::ParameterUpdater> remoteUpdaterPtr;
remoteUpdaterPtr.reset(remoteUpdater);
auto sparseRemoteUpdater =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册