未验证 提交 bc5f0246 编写于 作者: T tangwei12 提交者: GitHub

large scale kv speedup (#26510)

* rename communicator meet->BatchesCounter

* fix parame recv for sparse

* geo sparse init from pserver

* optimize init from pserver

* add large scale optimizer fuse(SGD/ADAM)

* rectification init_worker and exe.run startup program
上级 d7b7dcd1
......@@ -74,8 +74,12 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} else {
recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_));
}
InitParams();
}
void AsyncCommunicator::InitParams() { RecvNoBarrier(); }
AsyncCommunicator::~AsyncCommunicator() {
running_ = false;
if (main_thread_) main_thread_->join();
......@@ -157,16 +161,18 @@ void AsyncCommunicator::MainThread() {
}
while (running_) {
int meet = Meet();
VLOG(1) << "async_meet: " << meet;
SendGlobalStep(meet);
SendByCommunicator(meet);
BarrierSend();
RecvByCommunicator();
BarrierRecv();
BarrierWeakUp();
int batches = BatchesCounter();
if (batches > 0) {
SendGlobalStep(batches);
SendByCommunicator(batches);
BarrierSend();
RecvByCommunicator();
BarrierRecv();
BarrierWeakUp();
} else {
VLOG(1) << "get nothing from sending queue, will skip send/recv";
}
}
VLOG(1) << "communicator stopped, send thread exit";
}
......@@ -187,7 +193,7 @@ void AsyncCommunicator::RecvNoBarrier() {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(iter.second, *recv_scope_, false);
recv_functor(iter.second, *recv_scope_);
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
......@@ -197,7 +203,7 @@ void AsyncCommunicator::RecvNoBarrier() {
}
}
int AsyncCommunicator::Meet() {
int AsyncCommunicator::BatchesCounter() {
auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER);
size_t merged_var_num = 0;
......@@ -316,7 +322,7 @@ void HalfAsyncCommunicator::Clean() {
}
}
int HalfAsyncCommunicator::Meet() {
int HalfAsyncCommunicator::BatchesCounter() {
while (running_) {
if (barrier_counter_.load() >= barrier_trigger_.load() &&
barrier_trigger_.load() != 0) {
......@@ -443,7 +449,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
old_scope_.reset(new Scope());
pserver_scope_.reset(new Scope());
Init();
InitParams();
}
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
......@@ -626,9 +632,7 @@ void GeoCommunicator::RecvByCommunicator() {
if (recv_ctx.is_sparse) {
RecvSparse(var_name);
} else {
VLOG(1) << "recv dense " << var_name << " begin";
RecvDense(var_name);
VLOG(1) << "recv dense " << var_name << " done";
}
};
tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task)));
......@@ -696,7 +700,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
auto &ctx = recv_varname_to_ctx_.at(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *pserver_scope_, true);
recv(ctx, *pserver_scope_);
PADDLE_ENFORCE_EQ(
var_psrever->IsInitialized(), true,
......@@ -721,7 +725,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) {
t_timestamp->data<float>());
}
void GeoCommunicator::Init() {
void GeoCommunicator::InitParams() {
std::vector<std::future<void>> tasks;
tasks.reserve(recv_varname_to_ctx_.size());
......@@ -744,12 +748,17 @@ void GeoCommunicator::Init() {
}
void GeoCommunicator::InitDense(const std::string varname) {
auto *var = old_scope_->Var(varname);
var->GetMutable<framework::LoDTensor>();
auto &ctx = recv_varname_to_ctx_.at(varname);
auto recv = distributed::ParameterRecv<float>();
recv(ctx, *old_scope_);
recv(ctx, *recv_scope_);
auto *global_var = recv_scope_->FindVar(varname);
global_var->GetMutable<framework::LoDTensor>();
auto *old_var = old_scope_->Var(varname);
old_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, old_var);
VLOG(1) << "init dense variable " << varname << " done";
}
......@@ -781,22 +790,41 @@ void GeoCommunicator::InitSparse() {
LargeScaleKV::Init(metas);
for (size_t i = 0; i < metas.size(); i++) {
auto &varname = metas[i].name;
auto &dict = dicts[i];
for (auto &meta : metas) {
auto &ctx = recv_varname_to_ctx_.at(meta.name);
auto recv = distributed::ParameterRecv<float>();
std::vector<int64_t> ids;
ids.reserve(dict);
auto *global_var = recv_scope_->FindVar(meta.name);
auto global_value = global_var->Get<framework::LoDTensor>();
auto rows = global_value.dims()[0];
auto dim1 = global_value.dims()[1];
for (auto j = 0; j < dict; ++j) {
ids.push_back(j);
}
recv(ctx, *recv_scope_);
VLOG(1) << "recv " << meta.name << " with global scope for init";
auto n_rows = global_var->Get<framework::LoDTensor>().dims()[0];
PADDLE_ENFORCE_EQ(
rows, n_rows,
platform::errors::InvalidArgument(
"global var: %s origin dim must equal recved rows", meta.name));
std::vector<int64_t> ids(rows);
std::iota(ids.begin(), ids.end(), 0);
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Init(ids);
std::vector<std::vector<std::vector<float> *>> values;
ins->Get(meta.name)->Init(ids);
ins->Get(meta.name)->Get(ids, {"Param"}, &values);
VLOG(3) << "GeoCommunicator init sparse " << varname << " with size "
<< ids.size();
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(
paddle::platform::CPUDeviceContext());
for (auto &id : ids) {
blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
values[id][0]->data());
}
}
VLOG(3) << "init sparse variable done";
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
......@@ -29,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
......@@ -279,6 +281,8 @@ class AsyncCommunicator : public Communicator {
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void InitParams();
void MainThread();
void Send(const std::vector<std::string> &var_names,
......@@ -293,7 +297,7 @@ class AsyncCommunicator : public Communicator {
virtual void RecvNoBarrier();
virtual int Meet();
virtual int BatchesCounter();
virtual void BarrierSend() {}
......@@ -350,7 +354,7 @@ class HalfAsyncCommunicator : public AsyncCommunicator {
void BarrierTriggerReset(int initial_val) override;
int Meet();
int BatchesCounter();
void BarrierWeakUp();
......@@ -435,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator {
void RecvDense(const std::string &varname);
void Init();
void InitParams();
void InitSparse();
......
......@@ -41,8 +41,67 @@ using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void RecvSelectedRows(const CommContext &rpc_ctx,
const framework::Scope &scope) {
void RecvSparseLodTensor(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<const float *> tensors;
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *local_var = local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVarNoBarrier(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name));
const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
tensors.push_back(value);
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
auto *merged_var = scope.FindVar(rpc_ctx.var_name);
if (merged_var == nullptr || !merged_var->IsInitialized()) {
PADDLE_THROW(
platform::errors::InvalidArgument("%s must initialized at first."));
}
auto dims1 = merged_var->Get<framework::LoDTensor>().dims()[1];
int64_t height = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto *splited_var = local_scope->FindVar(rpc_ctx.splited_varnames[i]);
height += splited_var->Get<framework::LoDTensor>().dims()[0];
}
PADDLE_ENFORCE_EQ(merged_var->Get<framework::LoDTensor>().dims()[0], height,
"recved var must has same dims with local var");
auto *merged_t = merged_var->GetMutable<framework::LoDTensor>();
auto *merged_d = merged_t->mutable_data<float>(cpu_place);
auto pserver_num = rpc_ctx.splited_varnames.size();
for (int x = 0; x < height; ++x) {
auto id = x % pserver_num;
auto idx = x / pserver_num;
std::memcpy(merged_d + x * dims1, tensors[id] + idx * dims1,
sizeof(float) * dims1);
}
}
template <typename T>
void RecvGeoSparseRecords(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
......@@ -84,9 +143,14 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
ids_num += recv_t.rows().size();
width = recv_t.value().dims()[1];
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids),
[&](int64_t id) { return id * pserver_num + i; });
if (rpc_ctx.is_distributed) {
std::copy(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids));
} else {
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids),
[&](int64_t id) { return id * pserver_num + i; });
}
}
auto *var = scope.FindVar(rpc_ctx.var_name);
......@@ -146,7 +210,8 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool barrier) {
const framework::Scope &scope,
bool geo_records) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
......@@ -154,18 +219,21 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
"origin_varnames.size() >= 1 is permitted"));
if (rpc_ctx.is_sparse) {
RecvSelectedRows<T>(rpc_ctx, scope);
if (geo_records) {
RecvGeoSparseRecords<T>(rpc_ctx, scope);
} else {
RecvSparseLodTensor<T>(rpc_ctx, scope);
}
} else {
RecvLodTensor<T>(rpc_ctx, scope);
}
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, true);
this->operator()(rpc_ctx, scope, false);
}
template struct ParameterRecv<float>;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h"
#include <string>
namespace paddle {
namespace operators {
class LargeScaleFuseAdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of LargeScaleFuseAdamOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of LargeScaleFuseAdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad");
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "LearningRate") {
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class LargeScaleFuseAdamOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto in_var_type = ctx->GetInputType("Grad");
PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true, platform::errors::InvalidArgument(
"The input Var's type should be LoDtensor or "
"SelectedRows, but the received type is %s",
in_var_type));
}
};
class LargeScaleFuseAdamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Grad",
"(SelectedRows) Ids's type should be SelectedRows"
"THe ids to be looked up in W.");
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddAttr<float>("beta1",
"(float, default 0.9) "
"Exponential decay rate for the "
"first moment estimates.")
.SetDefault(0.9f);
AddAttr<float>("beta2",
"(float, default 0.999) "
"exponential decay rate for the "
"second moment estimates.")
.SetDefault(0.999f);
AddAttr<float>("epsilon",
"(float, default 1.0e-8) "
"Constant for numerical stability")
.SetDefault(1.0e-8f);
AddAttr<bool>("is_entry",
"(bool)"
"sparse table need entry");
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddAttr<std::vector<std::string>>("value_names",
"(strings)"
"sparse table name");
AddComment(R"DOC(
Adam Optimizer.
This implements the Adam optimizer from Section 2 of the Adam
paper : https://arxiv.org/abs/1412.6980.
Adam is a first-order gradient-based optimization method based on
adaptive estimates of lower-order moments.
Adam updates:
$$
moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
learning\_rate = learning\_rate *
\frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_fuse_adam, ops::LargeScaleFuseAdamOp,
ops::LargeScaleFuseAdamOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LargeScaleFuseAdamOpInferVarType);
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_fuse_adam,
ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LargeScaleFuseAdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename T>
class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using paddle::framework::LoDTensor;
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(
grad_var->IsType<framework::SelectedRows>(),
platform::errors::InvalidArgument(
"in large scale optimize, gradient should only be SelectedRows"));
const auto &grad = grad_var->Get<framework::SelectedRows>();
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad.rows().size() == 0) {
return;
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows *grad_merge_ptr;
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
std::vector<int64_t> in_rows;
in_rows.reserve(grad_merge_ptr->rows().size());
std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
std::back_inserter(in_rows));
const auto *lr = learning_rate->data<T>();
auto grad_v = grad_merge_ptr->value();
auto grad_width = grad_v.dims()[1];
// auto is_entry = context.Attr<bool>("is_entry");
auto tablename = ctx.Attr<std::string>("tablename");
auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
auto *beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto *beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
auto *beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto *beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
platform::errors::InvalidArgument(
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
// update beta1 and beta2
beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
std::vector<std::vector<std::vector<float> *>> values;
std::vector<int64_t> dims;
auto *ins = distributed::LargeScaleKV::GetInstance();
auto *table = ins->Get(tablename);
table->Get(in_rows, value_names, &values);
table->Dims({"Param"}, &dims);
PADDLE_ENFORCE_EQ(dims[0], grad_width,
platform::errors::InvalidArgument(
"param_row should have the same size with grad_row"));
T lr_ = lr[0];
T beta1_pow_ = beta1_pow->data<T>()[0];
T beta2_pow_ = beta2_pow->data<T>()[0];
lr_ *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
for (size_t i = 0; i < in_rows.size(); i++) {
auto &params = values[i][0];
auto &moment_1 = values[i][1];
auto &moment_2 = values[i][2];
auto *p_data = params->data();
auto *m1_data = moment_1->data();
auto *m2_data = moment_2->data();
for (int x = 0; x < grad_width; ++x) {
auto g = grad_v.data<T>()[grad_width * i + x];
m1_data[x] = beta1 * m1_data[x] + (1 - beta1) * g;
m2_data[x] = beta2 * m2_data[x] + (1 - beta2) * g * g;
p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h"
#include <string>
namespace paddle {
namespace operators {
class LargeScaleFuseSGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of LargeScaleFuseSGDOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of LargeScaleFuseSGDOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad");
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "LearningRate") {
return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class LargeScaleFuseSGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto in_var_type = ctx->GetInputType("Grad");
PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
true, platform::errors::InvalidArgument(
"The input Var's type should be LoDtensor or "
"SelectedRows, but the received type is %s",
in_var_type));
}
};
class LargeScaleFuseSGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Grad",
"(SelectedRows) Ids's type should be SelectedRows"
"THe ids to be looked up in W.");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddAttr<bool>("is_entry",
"(bool)"
"sparse table need entry");
AddAttr<std::string>("tablename",
"(string)"
"sparse table name");
AddAttr<std::vector<std::string>>("value_names",
"(strings)"
"sparse table name");
AddComment(R"DOC(
LargeScaleFuseSGD operator
This operator implements one step of the stochastic gradient descent algorithm.
$$param\_out = param - learning\_rate * grad$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_sparse_table_fuse_sgd, ops::LargeScaleFuseSGDOp,
ops::LargeScaleFuseSGDOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LargeScaleFuseSGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_fuse_sgd,
ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LargeScaleFuseSGDOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename T>
class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(
grad_var->IsType<framework::SelectedRows>(),
platform::errors::InvalidArgument(
"in large scale optimize, gradient should only be SelectedRows"));
const auto &grad = grad_var->Get<framework::SelectedRows>();
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad.rows().size() == 0) {
return;
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows *grad_merge_ptr;
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
std::vector<int64_t> in_rows;
in_rows.reserve(grad_merge_ptr->rows().size());
std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(),
std::back_inserter(in_rows));
const auto *lr = learning_rate->data<T>();
auto grad_v = grad_merge_ptr->value();
auto grad_width = grad_v.dims()[1];
// auto is_entry = context.Attr<bool>("is_entry");
auto tablename = ctx.Attr<std::string>("tablename");
auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
std::vector<std::vector<std::vector<float> *>> values;
std::vector<int64_t> dims;
auto *ins = distributed::LargeScaleKV::GetInstance();
auto *table = ins->Get(tablename);
table->Get(in_rows, value_names, &values);
table->Dims({"Param"}, &dims);
PADDLE_ENFORCE_EQ(dims[0], grad_width,
platform::errors::InvalidArgument(
"param_row should have the same size with grad_row"));
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
std::vector<T> grads;
framework::TensorToVector(grad_v, ctx.device_context(), &grads);
blas.SCAL(grads.size(), lr[0], grads.data());
for (int x = 0; x < static_cast<int>(in_rows.size()); ++x) {
auto &params = values[x][0];
blas.VSUB(grad_width, params->data(), grads.data() + grad_width * x,
params->data());
}
}
};
} // namespace operators
} // namespace paddle
......@@ -37,12 +37,6 @@ class RecvOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
int do_not_run = Attr<int>("do_not_run");
if (do_not_run) {
VLOG(3) << "recv do not run!";
return;
}
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
......@@ -63,11 +57,10 @@ class RecvOp : public framework::OperatorBase {
if (recv_varnames.size() > 0) {
auto *communicator = distributed::Communicator::GetInstance();
if (communicator == nullptr) {
if (communicator != nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"need run fleet.init_worker first"));
"execute startup program must before fleet.init_worker"));
}
communicator->RecvNoBarrier();
} else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) {
......
......@@ -220,12 +220,12 @@ class ParameterServerRuntime(RuntimeBase):
else:
model_dirname = None
if self.role_maker._is_heter_worker():
self._init_worker()
executor = self._get_executor()
executor.run(fluid.default_startup_program())
if self.role_maker._is_heter_worker():
self._init_worker()
if self.role_maker._is_heter_worker():
return
......
......@@ -191,12 +191,14 @@ class FleetTranspiler(Fleet):
self._communicator = Communicator(
trainer_config.mode, kwargs,
trainer_config.get_communicator_flags())
self._communicator.init_with_ctx(send_ctx, recv_ctx)
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
raise ValueError(
"Communicator can only be inited once, please check")
def init_worker(self):
"""
......
......@@ -624,6 +624,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
value_dims = []
grad = None
opt_idx = -1
fuse = False
for op in block.ops:
opt_idx += 1
......@@ -631,6 +632,9 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
if op.type not in opt_value_map.keys():
continue
if op.type in ["sgd", "adam"]:
fuse = True
grad = main_program.global_block().vars[op.input("Grad")[0]]
for value in opt_value_map[op.type]:
......@@ -644,7 +648,67 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
if value_names:
break
return grad, opt_idx, value_names, value_dims, acture_names
return grad, opt_idx, value_names, value_dims, acture_names, fuse
def add_fuse_large_scale_op(block, global_block, table_name, value_names,
acture_names, grad, is_entry, opt_idx):
op = block.ops[opt_idx]
if op.type == "sgd":
grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]]
block._insert_op(
opt_idx,
type="lookup_sparse_table_fuse_sgd",
inputs={"Grad": grad,
"LearningRate": lr},
attrs={
"is_entry": is_entry,
"tablename": table_name,
"value_names": value_names
})
elif op.type == "adam":
grad = main_program.global_block().vars[op.input("Grad")[0]]
lr = main_program.global_block().vars[op.input("LearningRate")[0]]
beta1_pow = main_program.global_block().vars[op.input("Beta1Pow")[
0]]
beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[
0]]
beta1_pow_o = main_program.global_block().vars[op.output(
"Beta1PowOut")[0]]
beta2_pow_o = main_program.global_block().vars[op.output(
"Beta2PowOut")[0]]
beta1 = op.attr('beta1')
beta2 = op.attr('beta2')
epsilon = op.attr('epsilon')
block._insert_op(
opt_idx,
type="lookup_sparse_table_fuse_adam",
inputs={
"Grad": grad,
"LearningRate": lr,
"Beta1Pow": beta1_pow,
"Beta2Pow": beta2_pow
},
outputs={
"Beta1PowOut": beta1_pow_o,
"Beta2PowOut": beta2_pow_o
},
attrs={
"beta1": beta1,
"beta2": beta2,
"epsilon": epsilon,
"is_entry": is_entry,
"tablename": table_name,
"value_names": value_names
})
else:
raise ValueError("only support sgd/adam optimizer now")
def add_large_scale_op(block, global_block, table_name, value_names,
acture_names, grad, is_entry, opt_idx):
......@@ -711,24 +775,35 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
for param, blockid in param_blockid_map.items():
opt_block = program.block(blockid)
grad, opt_idx, value_names, value_dims, acture_names = \
grad, opt_idx, value_names, value_dims, acture_names, fuse = \
get_optimizer_values(opt_block)
entry_attr = get_entry_attr(param)
is_entry = False if entry_attr == "none" else True
add_large_scale_op(opt_block,
program.global_block(), param, value_names,
acture_names, grad, is_entry, opt_idx)
if fuse:
add_fuse_large_scale_op(opt_block,
program.global_block(), param,
value_names, acture_names, grad,
is_entry, opt_idx)
else:
add_large_scale_op(opt_block,
program.global_block(), param, value_names,
acture_names, grad, is_entry, opt_idx)
else:
large_scale_kv_metas = []
for param, blockid in param_blockid_map.items():
opt_block = main_program.block(blockid)
grad, _, value_names, value_dims, acture_names = \
grad, opt_idx, value_names, value_dims, acture_names, fuse = \
get_optimizer_values(opt_block)
entry_attr = get_entry_attr(param)
if fuse:
# remove origin optimzier op
opt_block._remove_op(opt_idx)
# training/infer
mode = "0"
names_str = ",".join(value_names)
......
......@@ -227,22 +227,6 @@ def init_from_server_pass(program, config):
fetch_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
recv_ctx = config.get_communicator_recv_context(recv_type=1)
recv_varnames = []
for name, ctxs in recv_ctx.items():
recv_varnames.extend(ctxs.origin_varnames())
program.global_block().append_op(
type="recv",
inputs={"X": []},
outputs={"Out": []},
attrs={
"recv_varnames": recv_varnames,
"trainer_id": config.get_role_id(),
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
program.global_block().append_op(
type="fetch_barrier",
inputs={},
......
......@@ -164,8 +164,8 @@ def train(args):
elif fleet.is_worker():
logger.info("run trainer")
fleet.init_worker()
exe.run(fleet.startup_program)
fleet.init_worker()
thread_num = 2
filelist = []
......
......@@ -163,8 +163,10 @@ class TestDistCTR2x2(FleetDistRunnerBase):
"""
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
self.reader.decorate_sample_list_generator(train_reader)
......@@ -202,8 +204,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
thread_num = 2
batch_size = 128
......
......@@ -60,8 +60,9 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id)
exe = fluid.Executor(place)
fleet.init_worker()
exe.run(fleet.startup_program)
fleet.init_worker()
batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
......@@ -104,8 +105,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
place = fluid.CUDAPlace(device_id)
exe = fluid.Executor(place)
fleet.init_worker()
exe.run(fleet.startup_program)
fleet.init_worker()
thread_num = 2
batch_size = 128
......
......@@ -152,8 +152,9 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
"""
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
self.reader.decorate_sample_list_generator(train_reader)
......@@ -176,8 +177,8 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
thread_num = int(os.getenv("CPU_NUM", 2))
batch_size = 128
......
......@@ -222,8 +222,8 @@ class TestDistSimnetBow2x2(FleetDistRunnerBase):
"""
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4
# reader
train_reader = paddle.batch(fake_simnet_reader(), batch_size=batch_size)
......
......@@ -151,8 +151,9 @@ class TestDistCTR2x2(FleetDistRunnerBase):
"""
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4
......
......@@ -30,11 +30,10 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distribu
class TestCommunicator(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
cost = fluid.layers.square_error_cost(input=x, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
......
......@@ -83,8 +83,8 @@ class TestCommunicatorGeoEnd2End(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
fleet.init_worker()
exe.run(fluid.default_startup_program())
fleet.init_worker()
train_reader = paddle.batch(self.fake_reader(), batch_size=24)
feeder = fluid.DataFeeder(place=place, feed_list=[x, z, y])
......
......@@ -71,8 +71,8 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
fleet.init_worker()
exe.run(fleet.startup_program)
fleet.init_worker()
train_reader = paddle.batch(self.fake_reader(), batch_size=24)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
......
......@@ -27,11 +27,9 @@ import paddle.distributed.fleet as fleet
class TestCommunicator(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
cost = fluid.layers.square_error_cost(input=x, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
......
......@@ -44,16 +44,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32')
y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = paddle.fluid.layers.square_error_cost(input=x, label=y)
avg_cost = paddle.fluid.layers.mean(cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
......@@ -71,7 +66,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
sends += 1
if op.type == "sgd":
sgds += 1
self.assertEqual(sends, 7)
self.assertEqual(sends, 1)
self.assertEqual(sgds, 0)
fleet.init_worker()
......@@ -89,16 +84,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
paddle.fluid.framework.switch_startup_program(startup_program)
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32')
y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = paddle.fluid.layers.square_error_cost(input=x, label=y)
avg_cost = paddle.fluid.layers.mean(cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
......
......@@ -36,16 +36,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
def test_gradient_merge_optimizer(self):
fleet.init(role_maker.PaddleCloudRoleMaker())
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32')
y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = paddle.fluid.layers.square_error_cost(input=x, label=y)
avg_cost = paddle.fluid.layers.mean(cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = False
......@@ -63,7 +58,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase):
sends += 1
if op.type == "sgd":
sgds += 1
self.assertEqual(sends, 6)
self.assertEqual(sends, 0)
self.assertEqual(sgds, 0)
fleet.init_worker()
......
......@@ -70,15 +70,13 @@ class TestPSPassWithBow(unittest.TestCase):
q = fluid.layers.data(
name="query_ids", shape=[1], dtype="int64", lod_level=1)
# embedding
q_emb = fluid.layers.embedding(
q_emb = fluid.contrib.layers.sparse_embedding(
input=q,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr),
is_sparse=is_sparse)
learning_rate=emb_lr))
q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
# vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
......@@ -97,15 +95,13 @@ class TestPSPassWithBow(unittest.TestCase):
pt = fluid.layers.data(
name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
# embedding
pt_emb = fluid.layers.embedding(
pt_emb = fluid.contrib.layers.sparse_embedding(
input=pt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr),
is_sparse=is_sparse)
learning_rate=emb_lr))
pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim])
# vsum
pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
......@@ -123,15 +119,13 @@ class TestPSPassWithBow(unittest.TestCase):
nt = fluid.layers.data(
name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
# embedding
nt_emb = fluid.layers.embedding(
nt_emb = fluid.contrib.layers.sparse_embedding(
input=nt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr),
is_sparse=is_sparse)
learning_rate=emb_lr))
nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
# vsum
nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
......@@ -167,7 +161,7 @@ class TestPSPassWithBow(unittest.TestCase):
fleet.init(role)
loss, acc, _ = self.net()
optimizer = fluid.optimizer.SGD(base_lr)
optimizer = fluid.optimizer.Adam(base_lr)
strategy = StrategyFactory.create_async_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)
......
......@@ -168,12 +168,13 @@ class TestPSPassWithBow(unittest.TestCase):
fleet.init(role)
loss, acc, _ = self.net()
optimizer = fluid.optimizer.SGD(
optimizer = fluid.optimizer.Adagrad(
learning_rate=fluid.layers.exponential_decay(
learning_rate=base_lr,
decay_steps=500,
decay_rate=0.969,
staircase=True))
strategy = StrategyFactory.create_async_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
# For Net
base_lr = 0.2
emb_lr = base_lr * 3
dict_dim = 1500
emb_dim = 128
hid_dim = 128
margin = 0.1
sample_rate = 1
batch_size = 4
class TestPSPassWithBow(unittest.TestCase):
def net(self):
def get_acc(cos_q_nt, cos_q_pt, batch_size):
cond = fluid.layers.less_than(cos_q_nt, cos_q_pt)
cond = fluid.layers.cast(cond, dtype='float64')
cond_3 = fluid.layers.reduce_sum(cond)
acc = fluid.layers.elementwise_div(
cond_3,
fluid.layers.fill_constant(
shape=[1], value=batch_size * 1.0, dtype='float64'),
name="simnet_acc")
return acc
def get_loss(cos_q_pt, cos_q_nt):
loss_op1 = fluid.layers.elementwise_sub(
fluid.layers.fill_constant_batch_size_like(
input=cos_q_pt,
shape=[-1, 1],
value=margin,
dtype='float32'),
cos_q_pt)
loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt)
loss_op3 = fluid.layers.elementwise_max(
fluid.layers.fill_constant_batch_size_like(
input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'),
loss_op2)
avg_cost = fluid.layers.mean(loss_op3)
return avg_cost
is_distributed = False
is_sparse = True
# query
q = fluid.layers.data(
name="query_ids", shape=[1], dtype="int64", lod_level=1)
# embedding
q_emb = fluid.contrib.layers.sparse_embedding(
input=q,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr))
q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
# vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
q_ss = fluid.layers.softsign(q_sum)
# fc layer after conv
q_fc = fluid.layers.fc(
input=q_ss,
size=hid_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__q_fc__",
learning_rate=base_lr))
# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
# pt
pt = fluid.layers.data(
name="pos_title_ids", shape=[1], dtype="int64", lod_level=1)
# embedding
pt_emb = fluid.contrib.layers.sparse_embedding(
input=pt,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr))
pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim])
# vsum
pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
pt_ss = fluid.layers.softsign(pt_sum)
# fc layer
pt_fc = fluid.layers.fc(
input=pt_ss,
size=hid_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__fc__",
learning_rate=base_lr),
bias_attr=fluid.ParamAttr(name="__fc_b__"))
# nt
nt = fluid.layers.data(
name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
# embedding
nt_emb = fluid.contrib.layers.sparse_embedding(
input=nt,
size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__emb__",
learning_rate=emb_lr))
nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
# vsum
nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
nt_ss = fluid.layers.softsign(nt_sum)
# fc layer
nt_fc = fluid.layers.fc(
input=nt_ss,
size=hid_dim,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01),
name="__fc__",
learning_rate=base_lr),
bias_attr=fluid.ParamAttr(name="__fc_b__"))
cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc)
cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc)
# loss
avg_cost = get_loss(cos_q_pt, cos_q_nt)
# acc
acc = get_acc(cos_q_nt, cos_q_pt, batch_size)
return [avg_cost, acc, cos_q_pt]
def test(self):
endpoints = [
"127.0.0.1:36004", "127.0.0.1:36005", "127.0.0.1:36006",
"127.0.0.1:36007"
]
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.SERVER,
worker_num=2,
server_endpoints=endpoints)
fleet.init(role)
loss, acc, _ = self.net()
optimizer = fluid.optimizer.Adagrad(base_lr)
strategy = StrategyFactory.create_async_strategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestLookupTableFuseOp(unittest.TestCase):
def test_fuse(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
def check_with_place(self, place):
scope = fluid.global_scope()
scope.var("LearningRate").get_tensor().set([0.01], place)
scope.var("Ids").get_tensor().set([i for i in range(100)], place)
init_program = fluid.Program()
lr = init_program.global_block().create_var(
name="LearningRate",
persistable=True,
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
dtype="float32")
ids = init_program.global_block().create_var(
name="Ids",
persistable=True,
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
shape=[100],
dtype="int64")
output = init_program.global_block().create_var(
name="output",
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
shape=[100, 8],
dtype="float32")
metas = []
metas.append(
"embedding_1.block0:Param,Moment1,Moment2:8,8,8:0:embedding_1@GRAD.block0:embedding_1.block0,embedding_1_moment1_0,embedding_1_moment2_0,kSparseIDs@embedding_1.block0:uniform_random&0&-0.5&0.5,fill_constant&0.0,fill_constant&0.0:none"
)
metas.append(
"embedding_2.block0:Param:8:0:embedding_2@GRAD.block0:embedding_2.block0,kSparseIDs@embedding_2.block0:uniform_random&0&-0.5&0.5:none"
)
init_program.global_block().append_op(
type="lookup_sparse_table_init",
inputs=None,
outputs=None,
attrs={"large_scale_metas": metas})
init_program.global_block().append_op(
type="lookup_sparse_table_read",
inputs={"Ids": ids},
outputs={"Out": output},
attrs={
"tablename": "embedding_1.block0",
"init": True,
"value_names": ["Param"],
})
init_program.global_block().append_op(
type="lookup_sparse_table_read",
inputs={"Ids": ids},
outputs={"Out": output},
attrs={
"tablename": "embedding_2.block0",
"init": True,
"value_names": ["Param"],
})
executor = fluid.Executor(place)
executor.run(init_program)
training_program = fluid.Program()
scope.var('Beta1Pow').get_tensor().set(
np.array([0]).astype("float32"), place)
scope.var('Beta2Pow').get_tensor().set(
np.array([0]).astype("float32"), place)
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 8
w_selected_rows = scope.var('Grad').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
lr = training_program.global_block().create_var(
name="LearningRate",
persistable=True,
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
dtype="float32")
grads = training_program.global_block().create_var(
name="Grad",
persistable=True,
type=fluid.core.VarDesc.VarType.SELECTED_ROWS,
shape=[100, 8],
dtype="float32")
beta1 = training_program.global_block().create_var(
name="Beta1Pow",
persistable=True,
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
dtype="float32")
beta2 = training_program.global_block().create_var(
name="Beta2Pow",
persistable=True,
type=fluid.core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
dtype="float32")
training_program.global_block().append_op(
type="lookup_sparse_table_fuse_adam",
inputs={
"Grad": grads,
"LearningRate": lr,
"Beta1Pow": beta1,
"Beta2Pow": beta2,
},
outputs={"Beta1PowOut": beta1,
"Beta2PowOut": beta2},
attrs={
"is_entry": False,
"tablename": "embedding_1.block0",
"value_names": ["Param", "Moment1", "Moment2"],
})
training_program.global_block().append_op(
type="lookup_sparse_table_fuse_sgd",
inputs={"Grad": grads,
"LearningRate": lr},
attrs={
"is_entry": False,
"tablename": "embedding_2.block0",
"value_names": ["Param"],
})
executor.run(training_program)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册