From bc5f0246a807728593c889d7924c921e88ffe643 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 23 Sep 2020 14:26:37 +0800 Subject: [PATCH] large scale kv speedup (#26510) * rename communicator meet->BatchesCounter * fix parame recv for sparse * geo sparse init from pserver * optimize init from pserver * add large scale optimizer fuse(SGD/ADAM) * rectification init_worker and exe.run startup program --- .../operators/distributed/communicator.cc | 94 ++++++---- .../operators/distributed/communicator.h | 10 +- .../operators/distributed/parameter_recv.cc | 86 ++++++++- .../lookup_sparse_table_fuse_adam_op.cc | 153 ++++++++++++++++ .../lookup_sparse_table_fuse_adam_op.h | 142 +++++++++++++++ .../lookup_sparse_table_fuse_sgd_op.cc | 120 ++++++++++++ .../lookup_sparse_table_fuse_sgd_op.h | 105 +++++++++++ .../operators/distributed_ops/recv_op.cc | 11 +- .../fleet/runtime/parameter_server_runtime.py | 6 +- .../distribute_transpiler/__init__.py | 4 +- .../fleet/parameter_server/ir/pserver_pass.py | 87 ++++++++- .../fleet/parameter_server/ir/trainer_pass.py | 16 -- .../incubate/fleet/tests/fleet_deep_ctr.py | 2 +- .../fluid/tests/unittests/dist_fleet_ctr.py | 6 +- .../tests/unittests/dist_fleet_ctr_ps_gpu.py | 5 +- .../tests/unittests/dist_fleet_heter_ctr.py | 5 +- .../tests/unittests/dist_fleet_simnet_bow.py | 2 +- .../dist_fleet_sparse_embedding_ctr.py | 3 +- .../unittests/test_communicator_async.py | 5 +- .../tests/unittests/test_communicator_geo.py | 2 +- .../unittests/test_communicator_half_async.py | 2 +- .../tests/unittests/test_communicator_sync.py | 6 +- .../test_dist_fleet_a_sync_optimizer_async.py | 30 +-- .../test_dist_fleet_a_sync_optimizer_sync.py | 15 +- .../tests/unittests/test_dist_fleet_ps4.py | 20 +- .../tests/unittests/test_dist_fleet_ps5.py | 3 +- .../tests/unittests/test_dist_fleet_ps6.py | 168 +++++++++++++++++ .../test_dist_lookup_sparse_table_fuse_ops.py | 171 ++++++++++++++++++ 28 files changed, 1137 insertions(+), 142 deletions(-) create mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.cc create mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h create mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.cc create mode 100644 paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_ps6.py create mode 100644 python/paddle/fluid/tests/unittests/test_dist_lookup_sparse_table_fuse_ops.py diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index b2cc9390fa..a0ac82a6f4 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -74,8 +74,12 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, } else { recv_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); } + + InitParams(); } +void AsyncCommunicator::InitParams() { RecvNoBarrier(); } + AsyncCommunicator::~AsyncCommunicator() { running_ = false; if (main_thread_) main_thread_->join(); @@ -157,16 +161,18 @@ void AsyncCommunicator::MainThread() { } while (running_) { - int meet = Meet(); - - VLOG(1) << "async_meet: " << meet; - - SendGlobalStep(meet); - SendByCommunicator(meet); - BarrierSend(); - RecvByCommunicator(); - BarrierRecv(); - BarrierWeakUp(); + int batches = BatchesCounter(); + + if (batches > 0) { + SendGlobalStep(batches); + SendByCommunicator(batches); + BarrierSend(); + RecvByCommunicator(); + BarrierRecv(); + BarrierWeakUp(); + } else { + VLOG(1) << "get nothing from sending queue, will skip send/recv"; + } } VLOG(1) << "communicator stopped, send thread exit"; } @@ -187,7 +193,7 @@ void AsyncCommunicator::RecvNoBarrier() { auto &var_name = iter.first; VLOG(4) << "recv var " << var_name; auto recv_functor = distributed::ParameterRecv(); - recv_functor(iter.second, *recv_scope_, false); + recv_functor(iter.second, *recv_scope_); }; task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task))); } @@ -197,7 +203,7 @@ void AsyncCommunicator::RecvNoBarrier() { } } -int AsyncCommunicator::Meet() { +int AsyncCommunicator::BatchesCounter() { auto &step_queue = send_varname_to_queue_.at(STEP_COUNTER); size_t merged_var_num = 0; @@ -316,7 +322,7 @@ void HalfAsyncCommunicator::Clean() { } } -int HalfAsyncCommunicator::Meet() { +int HalfAsyncCommunicator::BatchesCounter() { while (running_) { if (barrier_counter_.load() >= barrier_trigger_.load() && barrier_trigger_.load() != 0) { @@ -443,7 +449,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, old_scope_.reset(new Scope()); pserver_scope_.reset(new Scope()); - Init(); + InitParams(); } void GeoCommunicator::Send(const std::vector &var_names, @@ -626,9 +632,7 @@ void GeoCommunicator::RecvByCommunicator() { if (recv_ctx.is_sparse) { RecvSparse(var_name); } else { - VLOG(1) << "recv dense " << var_name << " begin"; RecvDense(var_name); - VLOG(1) << "recv dense " << var_name << " done"; } }; tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task))); @@ -696,7 +700,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) { auto &ctx = recv_varname_to_ctx_.at(varname); auto recv = distributed::ParameterRecv(); - recv(ctx, *pserver_scope_, true); + recv(ctx, *pserver_scope_); PADDLE_ENFORCE_EQ( var_psrever->IsInitialized(), true, @@ -721,7 +725,7 @@ void GeoCommunicator::RecvDense(const std::string &varname) { t_timestamp->data()); } -void GeoCommunicator::Init() { +void GeoCommunicator::InitParams() { std::vector> tasks; tasks.reserve(recv_varname_to_ctx_.size()); @@ -744,12 +748,17 @@ void GeoCommunicator::Init() { } void GeoCommunicator::InitDense(const std::string varname) { - auto *var = old_scope_->Var(varname); - var->GetMutable(); - auto &ctx = recv_varname_to_ctx_.at(varname); auto recv = distributed::ParameterRecv(); - recv(ctx, *old_scope_); + recv(ctx, *recv_scope_); + + auto *global_var = recv_scope_->FindVar(varname); + global_var->GetMutable(); + + auto *old_var = old_scope_->Var(varname); + old_var->GetMutable(); + + framework::CopyVariable(*global_var, old_var); VLOG(1) << "init dense variable " << varname << " done"; } @@ -781,22 +790,41 @@ void GeoCommunicator::InitSparse() { LargeScaleKV::Init(metas); - for (size_t i = 0; i < metas.size(); i++) { - auto &varname = metas[i].name; - auto &dict = dicts[i]; + for (auto &meta : metas) { + auto &ctx = recv_varname_to_ctx_.at(meta.name); + auto recv = distributed::ParameterRecv(); - std::vector ids; - ids.reserve(dict); + auto *global_var = recv_scope_->FindVar(meta.name); + auto global_value = global_var->Get(); + auto rows = global_value.dims()[0]; + auto dim1 = global_value.dims()[1]; - for (auto j = 0; j < dict; ++j) { - ids.push_back(j); - } + recv(ctx, *recv_scope_); + VLOG(1) << "recv " << meta.name << " with global scope for init"; + + auto n_rows = global_var->Get().dims()[0]; + + PADDLE_ENFORCE_EQ( + rows, n_rows, + platform::errors::InvalidArgument( + "global var: %s origin dim must equal recved rows", meta.name)); + + std::vector ids(rows); + std::iota(ids.begin(), ids.end(), 0); auto *ins = distributed::LargeScaleKV::GetInstance(); - ins->Get(varname)->Init(ids); + std::vector *>> values; + + ins->Get(meta.name)->Init(ids); + ins->Get(meta.name)->Get(ids, {"Param"}, &values); - VLOG(3) << "GeoCommunicator init sparse " << varname << " with size " - << ids.size(); + auto blas = math::GetBlas( + paddle::platform::CPUDeviceContext()); + + for (auto &id : ids) { + blas.VCOPY(dim1, global_value.data() + id * dim1, + values[id][0]->data()); + } } VLOG(3) << "init sparse variable done"; diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 2f6da150d1..4a9a9eb170 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/distributed/communicator_common.h" #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/large_scale_kv.h" @@ -279,6 +281,8 @@ class AsyncCommunicator : public Communicator { const RpcCtxMap &recv_varname_to_ctx, Scope *recv_scope) override; + void InitParams(); + void MainThread(); void Send(const std::vector &var_names, @@ -293,7 +297,7 @@ class AsyncCommunicator : public Communicator { virtual void RecvNoBarrier(); - virtual int Meet(); + virtual int BatchesCounter(); virtual void BarrierSend() {} @@ -350,7 +354,7 @@ class HalfAsyncCommunicator : public AsyncCommunicator { void BarrierTriggerReset(int initial_val) override; - int Meet(); + int BatchesCounter(); void BarrierWeakUp(); @@ -435,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator { void RecvDense(const std::string &varname); - void Init(); + void InitParams(); void InitSparse(); diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc index 5409ec5498..3b8479c91b 100644 --- a/paddle/fluid/operators/distributed/parameter_recv.cc +++ b/paddle/fluid/operators/distributed/parameter_recv.cc @@ -41,8 +41,67 @@ using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; template -void RecvSelectedRows(const CommContext &rpc_ctx, - const framework::Scope &scope) { +void RecvSparseLodTensor(const CommContext &rpc_ctx, + const framework::Scope &scope) { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto cpu_place = platform::CPUPlace(); + auto &cpu_ctx = *pool.Get(cpu_place); + + distributed::RPCClient *rpc_client = + distributed::RPCClient::GetInstance(rpc_ctx.trainer_id); + + std::unique_ptr local_scope = scope.NewTmpScope(); + std::vector tensors; + std::vector rets; + for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { + auto &recv_var_name = rpc_ctx.splited_varnames[i]; + auto *local_var = local_scope->Var(recv_var_name); + VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; + // sparse param in recv_scope is LoDTensor + rets.push_back(rpc_client->AsyncGetVarNoBarrier( + rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name, + recv_var_name)); + + const auto *value = local_var->Get().data(); + tensors.push_back(value); + } + + for (size_t i = 0; i < rets.size(); i++) { + PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout( + "internal error in RPCClient")); + } + + auto *merged_var = scope.FindVar(rpc_ctx.var_name); + + if (merged_var == nullptr || !merged_var->IsInitialized()) { + PADDLE_THROW( + platform::errors::InvalidArgument("%s must initialized at first.")); + } + auto dims1 = merged_var->Get().dims()[1]; + int64_t height = 0; + for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) { + auto *splited_var = local_scope->FindVar(rpc_ctx.splited_varnames[i]); + height += splited_var->Get().dims()[0]; + } + + PADDLE_ENFORCE_EQ(merged_var->Get().dims()[0], height, + "recved var must has same dims with local var"); + + auto *merged_t = merged_var->GetMutable(); + auto *merged_d = merged_t->mutable_data(cpu_place); + + auto pserver_num = rpc_ctx.splited_varnames.size(); + for (int x = 0; x < height; ++x) { + auto id = x % pserver_num; + auto idx = x / pserver_num; + std::memcpy(merged_d + x * dims1, tensors[id] + idx * dims1, + sizeof(float) * dims1); + } +} + +template +void RecvGeoSparseRecords(const CommContext &rpc_ctx, + const framework::Scope &scope) { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto cpu_place = platform::CPUPlace(); auto &cpu_ctx = *pool.Get(cpu_place); @@ -84,9 +143,14 @@ void RecvSelectedRows(const CommContext &rpc_ctx, ids_num += recv_t.rows().size(); width = recv_t.value().dims()[1]; - std::transform(recv_t.rows().begin(), recv_t.rows().end(), - std::back_inserter(all_ids), - [&](int64_t id) { return id * pserver_num + i; }); + if (rpc_ctx.is_distributed) { + std::copy(recv_t.rows().begin(), recv_t.rows().end(), + std::back_inserter(all_ids)); + } else { + std::transform(recv_t.rows().begin(), recv_t.rows().end(), + std::back_inserter(all_ids), + [&](int64_t id) { return id * pserver_num + i; }); + } } auto *var = scope.FindVar(rpc_ctx.var_name); @@ -146,7 +210,8 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) { template void ParameterRecv::operator()(const CommContext &rpc_ctx, - const framework::Scope &scope, bool barrier) { + const framework::Scope &scope, + bool geo_records) { VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name; PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1, @@ -154,18 +219,21 @@ void ParameterRecv::operator()(const CommContext &rpc_ctx, "origin_varnames.size() >= 1 is permitted")); if (rpc_ctx.is_sparse) { - RecvSelectedRows(rpc_ctx, scope); + if (geo_records) { + RecvGeoSparseRecords(rpc_ctx, scope); + } else { + RecvSparseLodTensor(rpc_ctx, scope); + } } else { RecvLodTensor(rpc_ctx, scope); } VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name; } - template void ParameterRecv::operator()(const CommContext &rpc_ctx, const framework::Scope &scope) { - this->operator()(rpc_ctx, scope, true); + this->operator()(rpc_ctx, scope, false); } template struct ParameterRecv; diff --git a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.cc b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.cc new file mode 100644 index 0000000000..e53ce8cc67 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h" + +#include +namespace paddle { +namespace operators { + +class LargeScaleFuseAdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of LargeScaleFuseAdamOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("LearningRate"), + "Input(LearningRate) of LargeScaleFuseAdamOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + + PADDLE_ENFORCE_NE(framework::product(lr_dims), 0, + "Maybe the Input variable LearningRate has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function."); + + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 element"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad"); + return framework::OpKernelType(data_type, ctx.device_context()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const { + if (var_name == "LearningRate") { + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +class LargeScaleFuseAdamOpInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto in_var_type = ctx->GetInputType("Grad"); + PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS || + in_var_type == framework::proto::VarType::LOD_TENSOR, + true, platform::errors::InvalidArgument( + "The input Var's type should be LoDtensor or " + "SelectedRows, but the received type is %s", + in_var_type)); + } +}; + +class LargeScaleFuseAdamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Grad", + "(SelectedRows) Ids's type should be SelectedRows" + "THe ids to be looked up in W."); + + AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); + AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); + AddInput("LearningRate", "(Tensor) Learning rate of SGD"); + AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); + AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator"); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate for the " + "first moment estimates.") + .SetDefault(0.9f); + + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the " + "second moment estimates.") + .SetDefault(0.999f); + + AddAttr("epsilon", + "(float, default 1.0e-8) " + "Constant for numerical stability") + .SetDefault(1.0e-8f); + + AddAttr("is_entry", + "(bool)" + "sparse table need entry"); + + AddAttr("tablename", + "(string)" + "sparse table name"); + + AddAttr>("value_names", + "(strings)" + "sparse table name"); + + AddComment(R"DOC( +Adam Optimizer. + +This implements the Adam optimizer from Section 2 of the Adam +paper : https://arxiv.org/abs/1412.6980. +Adam is a first-order gradient-based optimization method based on +adaptive estimates of lower-order moments. + +Adam updates: + +$$ +moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\ +moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\ +learning\_rate = learning\_rate * + \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\ +param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon} +$$ + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + lookup_sparse_table_fuse_adam, ops::LargeScaleFuseAdamOp, + ops::LargeScaleFuseAdamOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::LargeScaleFuseAdamOpInferVarType); + +REGISTER_OP_CPU_KERNEL( + lookup_sparse_table_fuse_adam, + ops::LargeScaleFuseAdamOpKernel); diff --git a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h new file mode 100644 index 0000000000..89b8d54a46 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_adam_op.h @@ -0,0 +1,142 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include // for sqrt in CPU and CUDA +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/distributed/large_scale_kv.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace paddle { +namespace operators { + +template +class LargeScaleFuseAdamOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override; +}; + +template +class LargeScaleFuseAdamOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using paddle::framework::LoDTensor; + + const auto *learning_rate = ctx.Input("LearningRate"); + const auto *grad_var = ctx.InputVar("Grad"); + + PADDLE_ENFORCE( + grad_var->IsType(), + platform::errors::InvalidArgument( + "in large scale optimize, gradient should only be SelectedRows")); + + const auto &grad = grad_var->Get(); + + // for distributed training, a sparse var may be empty, + // just skip updating. + if (grad.rows().size() == 0) { + return; + } + + framework::SelectedRows tmp_grad_merge; + const framework::SelectedRows *grad_merge_ptr; + math::scatter::MergeAdd merge_func; + merge_func(ctx.template device_context(), grad, + &tmp_grad_merge, true); + grad_merge_ptr = &tmp_grad_merge; + + std::vector in_rows; + in_rows.reserve(grad_merge_ptr->rows().size()); + std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(), + std::back_inserter(in_rows)); + + const auto *lr = learning_rate->data(); + auto grad_v = grad_merge_ptr->value(); + auto grad_width = grad_v.dims()[1]; + + // auto is_entry = context.Attr("is_entry"); + auto tablename = ctx.Attr("tablename"); + auto value_names = ctx.Attr>("value_names"); + + auto *beta1_pow = ctx.Input("Beta1Pow"); + auto *beta2_pow = ctx.Input("Beta2Pow"); + auto *beta1_pow_out = ctx.Output("Beta1PowOut"); + auto *beta2_pow_out = ctx.Output("Beta2PowOut"); + T epsilon = static_cast(ctx.Attr("epsilon")); + T beta1 = static_cast(ctx.Attr("beta1")); + T beta2 = static_cast(ctx.Attr("beta2")); + + PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1, + platform::errors::InvalidArgument( + "beta1 pow output size should be 1, but received " + "value is:%d.", + beta1_pow_out->numel())); + + PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1, + platform::errors::InvalidArgument( + "beta2 pow output size should be 1, but received " + "value is:%d.", + beta2_pow_out->numel())); + + // update beta1 and beta2 + beta1_pow_out->mutable_data(ctx.GetPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(ctx.GetPlace())[0] = + beta2 * beta2_pow->data()[0]; + + std::vector *>> values; + std::vector dims; + + auto *ins = distributed::LargeScaleKV::GetInstance(); + auto *table = ins->Get(tablename); + table->Get(in_rows, value_names, &values); + table->Dims({"Param"}, &dims); + + PADDLE_ENFORCE_EQ(dims[0], grad_width, + platform::errors::InvalidArgument( + "param_row should have the same size with grad_row")); + + T lr_ = lr[0]; + T beta1_pow_ = beta1_pow->data()[0]; + T beta2_pow_ = beta2_pow->data()[0]; + + lr_ *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_); + + for (size_t i = 0; i < in_rows.size(); i++) { + auto ¶ms = values[i][0]; + auto &moment_1 = values[i][1]; + auto &moment_2 = values[i][2]; + + auto *p_data = params->data(); + auto *m1_data = moment_1->data(); + auto *m2_data = moment_2->data(); + + for (int x = 0; x < grad_width; ++x) { + auto g = grad_v.data()[grad_width * i + x]; + m1_data[x] = beta1 * m1_data[x] + (1 - beta1) * g; + m2_data[x] = beta2 * m2_data[x] + (1 - beta2) * g * g; + p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon)); + } + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.cc b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.cc new file mode 100644 index 0000000000..010658b528 --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h" + +#include +namespace paddle { +namespace operators { + +class LargeScaleFuseSGDOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of LargeScaleFuseSGDOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("LearningRate"), + "Input(LearningRate) of LargeScaleFuseSGDOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + + PADDLE_ENFORCE_NE(framework::product(lr_dims), 0, + "Maybe the Input variable LearningRate has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function."); + + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 element"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Grad"); + return framework::OpKernelType(data_type, ctx.device_context()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const { + if (var_name == "LearningRate") { + return framework::OpKernelType(tensor.type(), tensor.place(), + tensor.layout()); + } + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } +}; + +class LargeScaleFuseSGDOpInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto in_var_type = ctx->GetInputType("Grad"); + PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS || + in_var_type == framework::proto::VarType::LOD_TENSOR, + true, platform::errors::InvalidArgument( + "The input Var's type should be LoDtensor or " + "SelectedRows, but the received type is %s", + in_var_type)); + } +}; + +class LargeScaleFuseSGDOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Grad", + "(SelectedRows) Ids's type should be SelectedRows" + "THe ids to be looked up in W."); + AddInput("LearningRate", "(Tensor) Learning rate of SGD"); + AddAttr("is_entry", + "(bool)" + "sparse table need entry"); + + AddAttr("tablename", + "(string)" + "sparse table name"); + + AddAttr>("value_names", + "(strings)" + "sparse table name"); + + AddComment(R"DOC( + +LargeScaleFuseSGD operator + +This operator implements one step of the stochastic gradient descent algorithm. + +$$param\_out = param - learning\_rate * grad$$ + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + lookup_sparse_table_fuse_sgd, ops::LargeScaleFuseSGDOp, + ops::LargeScaleFuseSGDOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::LargeScaleFuseSGDOpInferVarType); + +REGISTER_OP_CPU_KERNEL( + lookup_sparse_table_fuse_sgd, + ops::LargeScaleFuseSGDOpKernel); diff --git a/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h new file mode 100644 index 0000000000..5d4bf1015f --- /dev/null +++ b/paddle/fluid/operators/distributed_ops/lookup_sparse_table_fuse_sgd_op.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/distributed/large_scale_kv.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace paddle { +namespace operators { + +template +class LargeScaleFuseSGDOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override; +}; + +template +class LargeScaleFuseSGDOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *learning_rate = ctx.Input("LearningRate"); + + const auto *grad_var = ctx.InputVar("Grad"); + + PADDLE_ENFORCE( + grad_var->IsType(), + platform::errors::InvalidArgument( + "in large scale optimize, gradient should only be SelectedRows")); + + const auto &grad = grad_var->Get(); + + // for distributed training, a sparse var may be empty, + // just skip updating. + if (grad.rows().size() == 0) { + return; + } + + framework::SelectedRows tmp_grad_merge; + const framework::SelectedRows *grad_merge_ptr; + math::scatter::MergeAdd merge_func; + merge_func(ctx.template device_context(), grad, + &tmp_grad_merge, true); + grad_merge_ptr = &tmp_grad_merge; + + std::vector in_rows; + in_rows.reserve(grad_merge_ptr->rows().size()); + std::copy(grad_merge_ptr->rows().begin(), grad_merge_ptr->rows().end(), + std::back_inserter(in_rows)); + + const auto *lr = learning_rate->data(); + auto grad_v = grad_merge_ptr->value(); + auto grad_width = grad_v.dims()[1]; + + // auto is_entry = context.Attr("is_entry"); + auto tablename = ctx.Attr("tablename"); + auto value_names = ctx.Attr>("value_names"); + + std::vector *>> values; + std::vector dims; + + auto *ins = distributed::LargeScaleKV::GetInstance(); + auto *table = ins->Get(tablename); + table->Get(in_rows, value_names, &values); + table->Dims({"Param"}, &dims); + + PADDLE_ENFORCE_EQ(dims[0], grad_width, + platform::errors::InvalidArgument( + "param_row should have the same size with grad_row")); + + auto blas = math::GetBlas(ctx); + + std::vector grads; + framework::TensorToVector(grad_v, ctx.device_context(), &grads); + + blas.SCAL(grads.size(), lr[0], grads.data()); + + for (int x = 0; x < static_cast(in_rows.size()); ++x) { + auto ¶ms = values[x][0]; + blas.VSUB(grad_width, params->data(), grads.data() + grad_width * x, + params->data()); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed_ops/recv_op.cc b/paddle/fluid/operators/distributed_ops/recv_op.cc index 15b36baead..2547ba3acb 100644 --- a/paddle/fluid/operators/distributed_ops/recv_op.cc +++ b/paddle/fluid/operators/distributed_ops/recv_op.cc @@ -37,12 +37,6 @@ class RecvOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &place) const override { - int do_not_run = Attr("do_not_run"); - if (do_not_run) { - VLOG(3) << "recv do not run!"; - return; - } - std::vector epmap = Attr>("epmap"); std::vector varnames = Attr>("varnames"); @@ -63,11 +57,10 @@ class RecvOp : public framework::OperatorBase { if (recv_varnames.size() > 0) { auto *communicator = distributed::Communicator::GetInstance(); - if (communicator == nullptr) { + if (communicator != nullptr) { PADDLE_THROW(platform::errors::InvalidArgument( - "need run fleet.init_worker first")); + "execute startup program must before fleet.init_worker")); } - communicator->RecvNoBarrier(); } else { std::vector rets; if (with_barrier) { diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index ae5c53b8a3..6dd4661f00 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -220,12 +220,12 @@ class ParameterServerRuntime(RuntimeBase): else: model_dirname = None - if self.role_maker._is_heter_worker(): - self._init_worker() - executor = self._get_executor() executor.run(fluid.default_startup_program()) + if self.role_maker._is_heter_worker(): + self._init_worker() + if self.role_maker._is_heter_worker(): return diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 236cb458be..e556a98ed7 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -191,12 +191,14 @@ class FleetTranspiler(Fleet): self._communicator = Communicator( trainer_config.mode, kwargs, trainer_config.get_communicator_flags()) + self._communicator.init_with_ctx(send_ctx, recv_ctx) if not self._communicator.is_running(): self._communicator.start() else: - warnings.warn("communicator has been initialized, skip") + raise ValueError( + "Communicator can only be inited once, please check") def init_worker(self): """ diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py index 05deff10a2..a60c4e149f 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/pserver_pass.py @@ -624,6 +624,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): value_dims = [] grad = None opt_idx = -1 + fuse = False for op in block.ops: opt_idx += 1 @@ -631,6 +632,9 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): if op.type not in opt_value_map.keys(): continue + if op.type in ["sgd", "adam"]: + fuse = True + grad = main_program.global_block().vars[op.input("Grad")[0]] for value in opt_value_map[op.type]: @@ -644,7 +648,67 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): if value_names: break - return grad, opt_idx, value_names, value_dims, acture_names + return grad, opt_idx, value_names, value_dims, acture_names, fuse + + def add_fuse_large_scale_op(block, global_block, table_name, value_names, + acture_names, grad, is_entry, opt_idx): + + op = block.ops[opt_idx] + + if op.type == "sgd": + grad = main_program.global_block().vars[op.input("Grad")[0]] + lr = main_program.global_block().vars[op.input("LearningRate")[0]] + + block._insert_op( + opt_idx, + type="lookup_sparse_table_fuse_sgd", + inputs={"Grad": grad, + "LearningRate": lr}, + attrs={ + "is_entry": is_entry, + "tablename": table_name, + "value_names": value_names + }) + + elif op.type == "adam": + grad = main_program.global_block().vars[op.input("Grad")[0]] + lr = main_program.global_block().vars[op.input("LearningRate")[0]] + beta1_pow = main_program.global_block().vars[op.input("Beta1Pow")[ + 0]] + beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[ + 0]] + beta1_pow_o = main_program.global_block().vars[op.output( + "Beta1PowOut")[0]] + beta2_pow_o = main_program.global_block().vars[op.output( + "Beta2PowOut")[0]] + + beta1 = op.attr('beta1') + beta2 = op.attr('beta2') + epsilon = op.attr('epsilon') + + block._insert_op( + opt_idx, + type="lookup_sparse_table_fuse_adam", + inputs={ + "Grad": grad, + "LearningRate": lr, + "Beta1Pow": beta1_pow, + "Beta2Pow": beta2_pow + }, + outputs={ + "Beta1PowOut": beta1_pow_o, + "Beta2PowOut": beta2_pow_o + }, + attrs={ + "beta1": beta1, + "beta2": beta2, + "epsilon": epsilon, + "is_entry": is_entry, + "tablename": table_name, + "value_names": value_names + }) + else: + raise ValueError("only support sgd/adam optimizer now") def add_large_scale_op(block, global_block, table_name, value_names, acture_names, grad, is_entry, opt_idx): @@ -711,24 +775,35 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False): for param, blockid in param_blockid_map.items(): opt_block = program.block(blockid) - grad, opt_idx, value_names, value_dims, acture_names = \ + grad, opt_idx, value_names, value_dims, acture_names, fuse = \ get_optimizer_values(opt_block) entry_attr = get_entry_attr(param) is_entry = False if entry_attr == "none" else True - add_large_scale_op(opt_block, - program.global_block(), param, value_names, - acture_names, grad, is_entry, opt_idx) + if fuse: + add_fuse_large_scale_op(opt_block, + program.global_block(), param, + value_names, acture_names, grad, + is_entry, opt_idx) + else: + add_large_scale_op(opt_block, + program.global_block(), param, value_names, + acture_names, grad, is_entry, opt_idx) else: large_scale_kv_metas = [] for param, blockid in param_blockid_map.items(): opt_block = main_program.block(blockid) - grad, _, value_names, value_dims, acture_names = \ + + grad, opt_idx, value_names, value_dims, acture_names, fuse = \ get_optimizer_values(opt_block) entry_attr = get_entry_attr(param) + if fuse: + # remove origin optimzier op + opt_block._remove_op(opt_idx) + # training/infer mode = "0" names_str = ",".join(value_names) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 4543af9820..3f826da3ae 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -227,22 +227,6 @@ def init_from_server_pass(program, config): fetch_barrier_out = program.global_block().create_var( name=framework.generate_control_dev_var_name()) - recv_ctx = config.get_communicator_recv_context(recv_type=1) - recv_varnames = [] - - for name, ctxs in recv_ctx.items(): - recv_varnames.extend(ctxs.origin_varnames()) - - program.global_block().append_op( - type="recv", - inputs={"X": []}, - outputs={"Out": []}, - attrs={ - "recv_varnames": recv_varnames, - "trainer_id": config.get_role_id(), - RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE - }) - program.global_block().append_op( type="fetch_barrier", inputs={}, diff --git a/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py b/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py index 60378aa982..06a90b78fd 100644 --- a/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py +++ b/python/paddle/fluid/incubate/fleet/tests/fleet_deep_ctr.py @@ -164,8 +164,8 @@ def train(args): elif fleet.is_worker(): logger.info("run trainer") - fleet.init_worker() exe.run(fleet.startup_program) + fleet.init_worker() thread_num = 2 filelist = [] diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 8277499fcc..5721445c41 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -163,8 +163,10 @@ class TestDistCTR2x2(FleetDistRunnerBase): """ exe = fluid.Executor(fluid.CPUPlace()) - fleet.init_worker() + exe.run(fluid.default_startup_program()) + fleet.init_worker() + batch_size = 4 train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) self.reader.decorate_sample_list_generator(train_reader) @@ -202,8 +204,8 @@ class TestDistCTR2x2(FleetDistRunnerBase): exe = fluid.Executor(fluid.CPUPlace()) - fleet.init_worker() exe.run(fluid.default_startup_program()) + fleet.init_worker() thread_num = 2 batch_size = 128 diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py index 0e3c809927..3852b22523 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr_ps_gpu.py @@ -60,8 +60,9 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2): device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace(device_id) exe = fluid.Executor(place) - fleet.init_worker() + exe.run(fleet.startup_program) + fleet.init_worker() batch_size = 4 train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) @@ -104,8 +105,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2): place = fluid.CUDAPlace(device_id) exe = fluid.Executor(place) - fleet.init_worker() exe.run(fleet.startup_program) + fleet.init_worker() thread_num = 2 batch_size = 128 diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py index 2f938a813d..470fb98d79 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py @@ -152,8 +152,9 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): """ exe = fluid.Executor(fluid.CPUPlace()) - fleet.init_worker() exe.run(fluid.default_startup_program()) + fleet.init_worker() + batch_size = 4 train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) self.reader.decorate_sample_list_generator(train_reader) @@ -176,8 +177,8 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): exe = fluid.Executor(fluid.CPUPlace()) - fleet.init_worker() exe.run(fluid.default_startup_program()) + fleet.init_worker() thread_num = int(os.getenv("CPU_NUM", 2)) batch_size = 128 diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_simnet_bow.py b/python/paddle/fluid/tests/unittests/dist_fleet_simnet_bow.py index 2ea69e1b67..ff84848873 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_simnet_bow.py @@ -222,8 +222,8 @@ class TestDistSimnetBow2x2(FleetDistRunnerBase): """ exe = fluid.Executor(fluid.CPUPlace()) - fleet.init_worker() exe.run(fluid.default_startup_program()) + fleet.init_worker() batch_size = 4 # reader train_reader = paddle.batch(fake_simnet_reader(), batch_size=batch_size) diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py index 77697896b4..81530573a6 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_sparse_embedding_ctr.py @@ -151,8 +151,9 @@ class TestDistCTR2x2(FleetDistRunnerBase): """ exe = fluid.Executor(fluid.CPUPlace()) - fleet.init_worker() + exe.run(fluid.default_startup_program()) + fleet.init_worker() batch_size = 4 diff --git a/python/paddle/fluid/tests/unittests/test_communicator_async.py b/python/paddle/fluid/tests/unittests/test_communicator_async.py index d032d6d75b..a86b80b2cf 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_async.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_async.py @@ -30,11 +30,10 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distribu class TestCommunicator(unittest.TestCase): def net(self): - x = fluid.layers.data(name='x', shape=[13], dtype='float32') - y_predict = fluid.layers.fc(input=x, size=1, act=None) + x = fluid.layers.data(name='x', shape=[1], dtype='float32') y = fluid.layers.data(name='y', shape=[1], dtype='float32') - cost = fluid.layers.square_error_cost(input=y_predict, label=y) + cost = fluid.layers.square_error_cost(input=x, label=y) avg_cost = fluid.layers.mean(cost) return avg_cost diff --git a/python/paddle/fluid/tests/unittests/test_communicator_geo.py b/python/paddle/fluid/tests/unittests/test_communicator_geo.py index d9fc9262b3..5916000fba 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_geo.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_geo.py @@ -83,8 +83,8 @@ class TestCommunicatorGeoEnd2End(unittest.TestCase): optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer.minimize(avg_cost) - fleet.init_worker() exe.run(fluid.default_startup_program()) + fleet.init_worker() train_reader = paddle.batch(self.fake_reader(), batch_size=24) feeder = fluid.DataFeeder(place=place, feed_list=[x, z, y]) diff --git a/python/paddle/fluid/tests/unittests/test_communicator_half_async.py b/python/paddle/fluid/tests/unittests/test_communicator_half_async.py index 391588780f..b0f55f2939 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_half_async.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_half_async.py @@ -71,8 +71,8 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase): optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer.minimize(avg_cost) - fleet.init_worker() exe.run(fleet.startup_program) + fleet.init_worker() train_reader = paddle.batch(self.fake_reader(), batch_size=24) feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) diff --git a/python/paddle/fluid/tests/unittests/test_communicator_sync.py b/python/paddle/fluid/tests/unittests/test_communicator_sync.py index c0044d9d62..95b209b146 100644 --- a/python/paddle/fluid/tests/unittests/test_communicator_sync.py +++ b/python/paddle/fluid/tests/unittests/test_communicator_sync.py @@ -27,11 +27,9 @@ import paddle.distributed.fleet as fleet class TestCommunicator(unittest.TestCase): def net(self): - x = fluid.layers.data(name='x', shape=[13], dtype='float32') - y_predict = fluid.layers.fc(input=x, size=1, act=None) + x = fluid.layers.data(name='x', shape=[1], dtype='float32') y = fluid.layers.data(name='y', shape=[1], dtype='float32') - - cost = fluid.layers.square_error_cost(input=y_predict, label=y) + cost = fluid.layers.square_error_cost(input=x, label=y) avg_cost = fluid.layers.mean(cost) return avg_cost diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_async.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_async.py index a82612b0ed..7f55e956a9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_async.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_async.py @@ -44,16 +44,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): paddle.fluid.framework.switch_startup_program(startup_program) fleet.init(role_maker.PaddleCloudRoleMaker()) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) + x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32') + y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = paddle.fluid.layers.square_error_cost(input=x, label=y) + avg_cost = paddle.fluid.layers.mean(cost) strategy = paddle.distributed.fleet.DistributedStrategy() strategy.a_sync = True @@ -71,7 +66,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): sends += 1 if op.type == "sgd": sgds += 1 - self.assertEqual(sends, 7) + self.assertEqual(sends, 1) self.assertEqual(sgds, 0) fleet.init_worker() @@ -89,16 +84,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): paddle.fluid.framework.switch_startup_program(startup_program) fleet.init(role_maker.PaddleCloudRoleMaker()) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - - fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) + + x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32') + y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = paddle.fluid.layers.square_error_cost(input=x, label=y) + avg_cost = paddle.fluid.layers.mean(cost) strategy = paddle.distributed.fleet.DistributedStrategy() strategy.a_sync = True diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py index b05a53c88b..db3f2afb36 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_a_sync_optimizer_sync.py @@ -36,16 +36,11 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): def test_gradient_merge_optimizer(self): fleet.init(role_maker.PaddleCloudRoleMaker()) - input_x = paddle.fluid.layers.data( - name="x", shape=[32], dtype='float32') - input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') - fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') - fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') - prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') - cost = paddle.fluid.layers.cross_entropy( - input=prediction, label=input_y) - avg_cost = paddle.fluid.layers.mean(x=cost) + x = paddle.fluid.layers.data(name='x', shape=[1], dtype='float32') + y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = paddle.fluid.layers.square_error_cost(input=x, label=y) + avg_cost = paddle.fluid.layers.mean(cost) strategy = paddle.distributed.fleet.DistributedStrategy() strategy.a_sync = False @@ -63,7 +58,7 @@ class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): sends += 1 if op.type == "sgd": sgds += 1 - self.assertEqual(sends, 6) + self.assertEqual(sends, 0) self.assertEqual(sgds, 0) fleet.init_worker() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps4.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps4.py index 379bcaf684..6fe52ba9fe 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps4.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps4.py @@ -70,15 +70,13 @@ class TestPSPassWithBow(unittest.TestCase): q = fluid.layers.data( name="query_ids", shape=[1], dtype="int64", lod_level=1) # embedding - q_emb = fluid.layers.embedding( + q_emb = fluid.contrib.layers.sparse_embedding( input=q, - is_distributed=is_distributed, size=[dict_dim, emb_dim], param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(value=0.01), name="__emb__", - learning_rate=emb_lr), - is_sparse=is_sparse) + learning_rate=emb_lr)) q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim]) # vsum q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') @@ -97,15 +95,13 @@ class TestPSPassWithBow(unittest.TestCase): pt = fluid.layers.data( name="pos_title_ids", shape=[1], dtype="int64", lod_level=1) # embedding - pt_emb = fluid.layers.embedding( + pt_emb = fluid.contrib.layers.sparse_embedding( input=pt, - is_distributed=is_distributed, size=[dict_dim, emb_dim], param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(value=0.01), name="__emb__", - learning_rate=emb_lr), - is_sparse=is_sparse) + learning_rate=emb_lr)) pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim]) # vsum pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum') @@ -123,15 +119,13 @@ class TestPSPassWithBow(unittest.TestCase): nt = fluid.layers.data( name="neg_title_ids", shape=[1], dtype="int64", lod_level=1) # embedding - nt_emb = fluid.layers.embedding( + nt_emb = fluid.contrib.layers.sparse_embedding( input=nt, - is_distributed=is_distributed, size=[dict_dim, emb_dim], param_attr=fluid.ParamAttr( initializer=fluid.initializer.Constant(value=0.01), name="__emb__", - learning_rate=emb_lr), - is_sparse=is_sparse) + learning_rate=emb_lr)) nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim]) # vsum nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum') @@ -167,7 +161,7 @@ class TestPSPassWithBow(unittest.TestCase): fleet.init(role) loss, acc, _ = self.net() - optimizer = fluid.optimizer.SGD(base_lr) + optimizer = fluid.optimizer.Adam(base_lr) strategy = StrategyFactory.create_async_strategy() optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer.minimize(loss) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps5.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps5.py index fd06979347..c570c4d8cd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps5.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps5.py @@ -168,12 +168,13 @@ class TestPSPassWithBow(unittest.TestCase): fleet.init(role) loss, acc, _ = self.net() - optimizer = fluid.optimizer.SGD( + optimizer = fluid.optimizer.Adagrad( learning_rate=fluid.layers.exponential_decay( learning_rate=base_lr, decay_steps=500, decay_rate=0.969, staircase=True)) + strategy = StrategyFactory.create_async_strategy() optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer.minimize(loss) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps6.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps6.py new file mode 100644 index 0000000000..d5b1284e3c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps6.py @@ -0,0 +1,168 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid as fluid +import paddle.fluid.incubate.fleet.base.role_maker as role_maker +from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet +from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory + +# For Net +base_lr = 0.2 +emb_lr = base_lr * 3 +dict_dim = 1500 +emb_dim = 128 +hid_dim = 128 +margin = 0.1 +sample_rate = 1 +batch_size = 4 + + +class TestPSPassWithBow(unittest.TestCase): + def net(self): + def get_acc(cos_q_nt, cos_q_pt, batch_size): + cond = fluid.layers.less_than(cos_q_nt, cos_q_pt) + cond = fluid.layers.cast(cond, dtype='float64') + cond_3 = fluid.layers.reduce_sum(cond) + acc = fluid.layers.elementwise_div( + cond_3, + fluid.layers.fill_constant( + shape=[1], value=batch_size * 1.0, dtype='float64'), + name="simnet_acc") + return acc + + def get_loss(cos_q_pt, cos_q_nt): + loss_op1 = fluid.layers.elementwise_sub( + fluid.layers.fill_constant_batch_size_like( + input=cos_q_pt, + shape=[-1, 1], + value=margin, + dtype='float32'), + cos_q_pt) + loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt) + loss_op3 = fluid.layers.elementwise_max( + fluid.layers.fill_constant_batch_size_like( + input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'), + loss_op2) + avg_cost = fluid.layers.mean(loss_op3) + return avg_cost + + is_distributed = False + is_sparse = True + + # query + q = fluid.layers.data( + name="query_ids", shape=[1], dtype="int64", lod_level=1) + # embedding + q_emb = fluid.contrib.layers.sparse_embedding( + input=q, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr)) + q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim]) + # vsum + q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') + q_ss = fluid.layers.softsign(q_sum) + # fc layer after conv + q_fc = fluid.layers.fc( + input=q_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__q_fc__", + learning_rate=base_lr)) + # label data + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + # pt + pt = fluid.layers.data( + name="pos_title_ids", shape=[1], dtype="int64", lod_level=1) + # embedding + pt_emb = fluid.contrib.layers.sparse_embedding( + input=pt, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr)) + pt_emb = fluid.layers.reshape(pt_emb, [-1, emb_dim]) + # vsum + pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum') + pt_ss = fluid.layers.softsign(pt_sum) + # fc layer + pt_fc = fluid.layers.fc( + input=pt_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__fc__", + learning_rate=base_lr), + bias_attr=fluid.ParamAttr(name="__fc_b__")) + # nt + nt = fluid.layers.data( + name="neg_title_ids", shape=[1], dtype="int64", lod_level=1) + # embedding + nt_emb = fluid.contrib.layers.sparse_embedding( + input=nt, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__emb__", + learning_rate=emb_lr)) + nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim]) + # vsum + nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum') + nt_ss = fluid.layers.softsign(nt_sum) + # fc layer + nt_fc = fluid.layers.fc( + input=nt_ss, + size=hid_dim, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01), + name="__fc__", + learning_rate=base_lr), + bias_attr=fluid.ParamAttr(name="__fc_b__")) + cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc) + cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc) + # loss + avg_cost = get_loss(cos_q_pt, cos_q_nt) + # acc + acc = get_acc(cos_q_nt, cos_q_pt, batch_size) + return [avg_cost, acc, cos_q_pt] + + def test(self): + endpoints = [ + "127.0.0.1:36004", "127.0.0.1:36005", "127.0.0.1:36006", + "127.0.0.1:36007" + ] + + role = role_maker.UserDefinedRoleMaker( + current_id=0, + role=role_maker.Role.SERVER, + worker_num=2, + server_endpoints=endpoints) + + fleet.init(role) + loss, acc, _ = self.net() + optimizer = fluid.optimizer.Adagrad(base_lr) + strategy = StrategyFactory.create_async_strategy() + optimizer = fleet.distributed_optimizer(optimizer, strategy) + optimizer.minimize(loss) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_lookup_sparse_table_fuse_ops.py b/python/paddle/fluid/tests/unittests/test_dist_lookup_sparse_table_fuse_ops.py new file mode 100644 index 0000000000..bca91c536b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_lookup_sparse_table_fuse_ops.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle.fluid as fluid +import paddle.fluid.core as core + + +class TestLookupTableFuseOp(unittest.TestCase): + def test_fuse(self): + places = [core.CPUPlace()] + # currently only support CPU + for place in places: + self.check_with_place(place) + + def check_with_place(self, place): + scope = fluid.global_scope() + scope.var("LearningRate").get_tensor().set([0.01], place) + scope.var("Ids").get_tensor().set([i for i in range(100)], place) + + init_program = fluid.Program() + + lr = init_program.global_block().create_var( + name="LearningRate", + persistable=True, + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + dtype="float32") + + ids = init_program.global_block().create_var( + name="Ids", + persistable=True, + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + shape=[100], + dtype="int64") + + output = init_program.global_block().create_var( + name="output", + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + shape=[100, 8], + dtype="float32") + + metas = [] + metas.append( + "embedding_1.block0:Param,Moment1,Moment2:8,8,8:0:embedding_1@GRAD.block0:embedding_1.block0,embedding_1_moment1_0,embedding_1_moment2_0,kSparseIDs@embedding_1.block0:uniform_random&0&-0.5&0.5,fill_constant&0.0,fill_constant&0.0:none" + ) + metas.append( + "embedding_2.block0:Param:8:0:embedding_2@GRAD.block0:embedding_2.block0,kSparseIDs@embedding_2.block0:uniform_random&0&-0.5&0.5:none" + ) + + init_program.global_block().append_op( + type="lookup_sparse_table_init", + inputs=None, + outputs=None, + attrs={"large_scale_metas": metas}) + + init_program.global_block().append_op( + type="lookup_sparse_table_read", + inputs={"Ids": ids}, + outputs={"Out": output}, + attrs={ + "tablename": "embedding_1.block0", + "init": True, + "value_names": ["Param"], + }) + + init_program.global_block().append_op( + type="lookup_sparse_table_read", + inputs={"Ids": ids}, + outputs={"Out": output}, + attrs={ + "tablename": "embedding_2.block0", + "init": True, + "value_names": ["Param"], + }) + + executor = fluid.Executor(place) + executor.run(init_program) + + training_program = fluid.Program() + + scope.var('Beta1Pow').get_tensor().set( + np.array([0]).astype("float32"), place) + scope.var('Beta2Pow').get_tensor().set( + np.array([0]).astype("float32"), place) + + rows = [0, 1, 2, 3, 4, 5, 6] + row_numel = 8 + w_selected_rows = scope.var('Grad').get_selected_rows() + w_selected_rows.set_height(len(rows)) + w_selected_rows.set_rows(rows) + w_array = np.ones((len(rows), row_numel)).astype("float32") + for i in range(len(rows)): + w_array[i] *= i + w_tensor = w_selected_rows.get_tensor() + w_tensor.set(w_array, place) + + lr = training_program.global_block().create_var( + name="LearningRate", + persistable=True, + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + dtype="float32") + + grads = training_program.global_block().create_var( + name="Grad", + persistable=True, + type=fluid.core.VarDesc.VarType.SELECTED_ROWS, + shape=[100, 8], + dtype="float32") + + beta1 = training_program.global_block().create_var( + name="Beta1Pow", + persistable=True, + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + dtype="float32") + + beta2 = training_program.global_block().create_var( + name="Beta2Pow", + persistable=True, + type=fluid.core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + dtype="float32") + + training_program.global_block().append_op( + type="lookup_sparse_table_fuse_adam", + inputs={ + "Grad": grads, + "LearningRate": lr, + "Beta1Pow": beta1, + "Beta2Pow": beta2, + }, + outputs={"Beta1PowOut": beta1, + "Beta2PowOut": beta2}, + attrs={ + "is_entry": False, + "tablename": "embedding_1.block0", + "value_names": ["Param", "Moment1", "Moment2"], + }) + + training_program.global_block().append_op( + type="lookup_sparse_table_fuse_sgd", + inputs={"Grad": grads, + "LearningRate": lr}, + attrs={ + "is_entry": False, + "tablename": "embedding_2.block0", + "value_names": ["Param"], + }) + + executor.run(training_program) + + +if __name__ == "__main__": + unittest.main() -- GitLab