Unverified commit 3a45767d authored by xujiaqi01, committed by GitHub

add fleet pslib pull and push sparse op and push dense op (#23139)

* add fleet pslib pull and push sparse op and push dense op
* test=develop
Parent 0536b526
......@@ -66,6 +66,8 @@ class PullDenseWorker {
void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec);
void PullDense(bool force_update = false);
int GetThreadIdByScope(const Scope* scope);
void SetThreadIdByScope(const Scope* scope, int tid);
static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker());
......@@ -73,13 +75,14 @@ class PullDenseWorker {
return s_instance_;
}
static std::shared_ptr<PullDenseWorker> s_instance_;
private:
PullDenseWorker() : root_scope_(NULL) {}
void Run();
bool CheckUpdateParam(uint64_t table_id);
private:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_;
DownpourWorkerParameter dwp_param_;
......@@ -105,6 +108,7 @@ class PullDenseWorker {
float squared_sum_epsilon_ = 1e-4;
std::mutex mutex_for_mean_scale_;
float total_batch_num_ = 0;
std::unordered_map<const Scope*, int> scope_to_thread_id_;
};
// should incorporate different types of devices
......
......@@ -124,6 +124,22 @@ void DistMultiTrainer::FinalizeDumpEnv() {
queue_.reset();
}
void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program,
const platform::Place &place) {
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetPlace(place);
workers_[i]->SetReaderPlace(place);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->CreateDeviceResource(main_program); // Program
workers_[i]->BindingDataFeedMemory();
}
// map Scope* to thread id; it will be used by the push_dense op
for (int i = 0; i < thread_num_; ++i) {
Scope *thread_scope = workers_[i]->GetThreadScope();
pull_dense_worker_->SetThreadIdByScope(thread_scope, i);
}
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) {
if (need_dump_field_) {
InitDumpEnv();
......
......@@ -358,6 +358,66 @@ void FleetWrapper::PullSparseVarsSync(
#endif
}
void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id,
platform::Place place,
std::vector<const LoDTensor*>* inputs,
std::vector<LoDTensor*>* outputs) {
#ifdef PADDLE_WITH_PSLIB
std::vector<uint64_t> fea_keys;
std::vector<float*> pull_result_ptr;
fea_keys.reserve(MAX_FEASIGN_NUM / 100);
pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100);
std::vector<float> init_value(fea_dim, 0);
framework::LoDTensor* output = nullptr;
float* output_data = nullptr;
size_t output_index = -1;  // wraps to 0 on the first increment below
size_t output_len = 0;
for (size_t index = 0; index < inputs->size(); ++index) {
const framework::LoDTensor* tensor = inputs->at(index);
const int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
if (!output || output_len == size_t(output->numel())) {
++output_index;
CHECK(output_index < outputs->size()); // NOLINT
output = outputs->at(output_index);
output_data = output->mutable_data<float>(place);
output_len = 0;
CHECK(output->numel() % fea_dim == 0); // NOLINT
CHECK(output_data != nullptr); // NOLINT
}
uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) {
memcpy(output_data + output_len, init_value.data(),
sizeof(float) * fea_dim);
continue;
}
fea_keys.push_back(real_id);
pull_result_ptr.push_back(output_data + output_len);
}
}
auto status = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
status.wait();
auto ret = status.get();
if (ret != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]";
sleep(sleep_seconds_before_fail_exit_);
}
#else
for (size_t index = 0; index < inputs->size(); ++index) {
auto* tensor = inputs->at(index);
size_t len = tensor->numel();
std::vector<float> init_data(fea_dim, 0);
for (size_t i = 0; i < len; ++i) {
memcpy(outputs->at(index)->mutable_data<float>(place) + i * fea_dim,
       init_data.data(), sizeof(float) * fea_dim);
}
}
#endif
}
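The flattening above is the heart of this function: ids from all input tensors are packed into one key list, padding ids short-circuit to a zero vector, and each non-padding id records the output slot its pulled row will land in, so a single pull_sparse RPC can serve every input. A hedged Python sketch of that bookkeeping (numpy and the function name are illustration choices, not part of the change):
import numpy as np

def flatten_ids_for_pull(inputs, outputs, fea_dim, padding_id):
    """inputs: 1-D int64 id arrays; outputs: 1-D float32 buffers whose
    total size is (number of ids) * fea_dim."""
    fea_keys, result_slots = [], []   # keys to pull, and where rows land
    out_idx, out_off, output = -1, 0, None
    for ids in inputs:
        for i in range(len(ids)):
            if output is None or out_off == output.size:
                out_idx += 1          # advance to the next output buffer
                output, out_off = outputs[out_idx], 0
            if ids[i] == padding_id:  # padding ids become zero vectors
                output[out_off:out_off + fea_dim] = 0.0
            else:
                fea_keys.append(int(ids[i]))
                result_slots.append((out_idx, out_off))
            out_off += fea_dim
    return fea_keys, result_slots

# demo: ids {3, 0 (padding), 7} and {5}, fea_dim 4, one 16-float buffer
out = [np.empty(16, dtype=np.float32)]
keys, slots = flatten_ids_for_pull([np.array([3, 0, 7]), np.array([5])],
                                   out, fea_dim=4, padding_id=0)
# keys == [3, 7, 5]; in the C++ a single pull_sparse call fills the slots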
void FleetWrapper::PullDenseVarsAsync(
const Scope& scope, const uint64_t tid,
const std::vector<std::string>& var_names,
......@@ -454,9 +514,12 @@ void FleetWrapper::PushDenseVarsAsync(
paddle::ps::Region reg(g, count);
regions.emplace_back(std::move(reg));
}
auto status = pslib_ptr_->_worker_ptr->push_dense(regions.data(),
regions.size(), table_id);
if (push_sparse_status) {
  push_sparse_status->push_back(std::move(status));
}
#endif
}
......@@ -598,6 +661,142 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
void FleetWrapper::PushSparseFromTensorWithLabelAsync(
const Scope& scope, const uint64_t table_id, int fea_dim,
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
const std::string& click_name, platform::Place place,
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs,
std::vector<const LoDTensor*>* outputs) {
#ifdef PADDLE_WITH_PSLIB
int show_index = 0;
int click_index = 1;
// these default values cannot be used as-is; they are set per accessor below
bool dump_slot = false;
int slot_offset = 0;
int grad_dim = 0;
// users do not need to care about these flags; they are derived from the accessor
if (accesor == "DownpourCtrAccessor") {
dump_slot = true;
slot_offset = 1;
grad_dim = fea_dim - 2;
show_index = 1;
click_index = 2;
} else if (accesor == "DownpourFeatureValueAccessor") {
dump_slot = false;
slot_offset = 0;
grad_dim = fea_dim - 2;
} else if (accesor == "DownpourSparseValueAccessor") {
dump_slot = false;
slot_offset = 0;
grad_dim = fea_dim;
}
CHECK(grad_dim >= 0); // NOLINT
int batch_size = -1;
for (auto* input : *inputs) {
int cur_batch_size =
input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
CHECK(batch_size == cur_batch_size); // NOLINT
}
}
CHECK(batch_size > 0); // NOLINT
std::vector<float> g;
for (const framework::LoDTensor* g_tensor : *outputs) {
size_t origin = g.size();
size_t add = g_tensor->numel();
g.resize(origin + add);
memcpy(g.data() + origin, g_tensor->data<float>(), sizeof(float) * add);
}
if (scale_sparse && grad_dim > 0) {
size_t dim = static_cast<size_t>(grad_dim);
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g.data(), g.size() / dim, dim);
g_mat.rightCols(grad_dim) *= batch_size;
}
std::vector<float> fea_labels;
fea_labels.reserve(MAX_FEASIGN_NUM / 100);
framework::Variable* var = scope.FindVar(click_name);
size_t global_idx = 0;
if (click_name != "") {
CHECK(var != nullptr); // NOLINT
framework::LoDTensor* label_tensor =
var->GetMutable<framework::LoDTensor>();
CHECK(label_tensor != nullptr); // NOLINT
int64_t* label_ptr = label_tensor->data<int64_t>();
for (auto* tensor : *inputs) {
const int64_t* ids = tensor->data<int64_t>();
size_t fea_idx = 0;
for (size_t lod_idx = 1; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
size_t cur =
GetAbsoluteSum(tensor->lod()[0][lod_idx - 1],
tensor->lod()[0][lod_idx], 0, tensor->lod());
for (size_t i = 0; i < cur; ++i, ++fea_idx) {
if (static_cast<uint64_t>(ids[fea_idx]) == padding_id) {
continue;
}
fea_labels.push_back(static_cast<float>(label_ptr[lod_idx - 1]));
++global_idx;
}
}
}
}
std::vector<uint64_t> push_keys;
push_keys.reserve(MAX_FEASIGN_NUM / 100);
std::vector<std::vector<float>> push_values;
push_values.reserve(MAX_FEASIGN_NUM / 100);
size_t output_len = 0;
size_t input_idx = 0;
for (size_t index = 0; index < inputs->size(); ++index) {
const framework::LoDTensor* tensor = inputs->at(index);
const int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
if (static_cast<uint64_t>(ids[i]) == padding_id) {
continue;
}
push_keys.emplace_back(ids[i]);
push_values.emplace_back(fea_dim + slot_offset);
float* data = push_values.back().data();
if (!var) {
memcpy(data + slot_offset, g.data() + output_len,
sizeof(float) * fea_dim);
} else {
memcpy(data + slot_offset, g.data() + output_len,
sizeof(float) * grad_dim);
data[show_index] = 1.0f;
data[click_index] = static_cast<float>(fea_labels.at(input_idx));
}
if (dump_slot) {
int slot = boost::lexical_cast<int>(input_names[index]);
data[0] = static_cast<float>(slot);
}
++input_idx;
}
}
CHECK(output_len == g.size()); // NOLINT
if (click_name != "") {
CHECK(input_idx == global_idx); // NOLINT
}
std::vector<float*> push_g_vec(input_idx, nullptr);
for (auto i = 0u; i < push_keys.size(); ++i) {
push_g_vec[i] = push_values.at(i).data();
}
auto status = pslib_ptr_->_worker_ptr->push_sparse(
table_id, push_keys.data(), (const float**)push_g_vec.data(),
push_keys.size());
#endif
}
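The per-key value layout depends on the accessor. The hedged Python sketch below mirrors the branch structure and write order above, including the detail that show/click are written after the gradient copy, so for DownpourCtrAccessor they overwrite the first two copied values, exactly as the C++ memcpy order does; the function itself is illustrative:
def build_push_row(accessor, fea_dim, grad, slot, click=None):
    """Illustrative mirror of how one push row is assembled above."""
    show_idx, click_idx, dump_slot, slot_offset = 0, 1, False, 0
    if accessor == "DownpourCtrAccessor":
        dump_slot, slot_offset, show_idx, click_idx = True, 1, 1, 2
        grad_dim = fea_dim - 2
    elif accessor == "DownpourFeatureValueAccessor":
        grad_dim = fea_dim - 2
    else:  # "DownpourSparseValueAccessor"
        grad_dim = fea_dim
    row = [0.0] * (fea_dim + slot_offset)
    if click is None:                  # no click tensor found in the scope
        row[slot_offset:slot_offset + fea_dim] = grad[:fea_dim]
    else:                              # label present: stamp show/click
        row[slot_offset:slot_offset + grad_dim] = grad[:grad_dim]
        row[show_idx] = 1.0            # overwrites part of the copied grad,
        row[click_idx] = float(click)  # matching the C++ write order
    if dump_slot:
        row[0] = float(slot)           # slot id parsed from the input name
    return row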
void FleetWrapper::LoadFromPaddleModel(Scope& scope, const uint64_t table_id,
std::vector<std::string> var_list,
std::string model_path,
......@@ -955,5 +1154,19 @@ int32_t FleetWrapper::CopyTableByFeasign(
#endif
}
size_t FleetWrapper::GetAbsoluteSum(size_t start, size_t end, size_t level,
const framework::LoD& lod) {
if (level >= lod.size() - 1) {
return end - start;
}
size_t ret = 0;
for (size_t i = start; i < end - 1; ++i) {
size_t pos1 = lod[level][i];
size_t pos2 = lod[level][i + 1];
ret += GetAbsoluteSum(pos1, pos2, level + 1, lod);
}
return ret;
}
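GetAbsoluteSum recursively expands a [start, end) offset range through the LoD levels and counts basic elements at the lowest level. A hedged Python transcription, keeping the same loop bound as the C++:
def get_absolute_sum(start, end, level, lod):
    """Python transcription of FleetWrapper::GetAbsoluteSum; `lod` is a
    list of offset lists, one per level."""
    if level >= len(lod) - 1:
        return end - start
    total = 0
    for i in range(start, end - 1):  # mirrors the C++ bound `i < end - 1`
        total += get_absolute_sum(lod[level][i], lod[level][i + 1],
                                  level + 1, lod)
    return total

# with a single-level LoD the recursion never fires:
# get_absolute_sum(lod[0][k - 1], lod[0][k], 0, lod) is simply the
# number of ids in the k-th sequence.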
} // end namespace framework
} // end namespace paddle
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
......@@ -78,8 +79,9 @@ class FleetWrapper {
void SetPullLocalThreadNum(int thread_num) {
pull_local_thread_num_ = thread_num;
}
// Pull sparse variables from server in sync mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
......@@ -87,12 +89,26 @@ class FleetWrapper {
std::vector<std::vector<float>>* fea_values,
int fea_dim,
const std::vector<std::string>& var_emb_names);
// Pull sparse variables from server in async mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values; returns a std::future for the request status
std::future<int32_t> PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_dim);
// Pull sparse variables from server in sync mode,
// writing the results directly into the output tensors
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<LoDTensor*>* outputs); // NOLINT
// pull dense variables from server in sync mode
// Param<in>: scope, table_id, var_names
// Param<out>: void
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
......@@ -134,6 +150,7 @@ class FleetWrapper {
GetLocalTable() {
return local_tables_;
}
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
// sparse_grad_names, batch_size, use_cvm, dump_slot
......@@ -149,6 +166,15 @@ class FleetWrapper {
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in async mode
void PushSparseFromTensorWithLabelAsync(
const Scope& scope, const uint64_t table_id, int fea_dim,
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
const std::string& click_name, platform::Place place,
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<const LoDTensor*>* outputs); // NOLINT
// Push sparse variables to server in async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
......@@ -255,6 +281,9 @@ class FleetWrapper {
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
#endif
size_t GetAbsoluteSum(size_t start, size_t end, size_t level,
const framework::LoD& lod);
protected:
static bool is_initialized_;
bool scale_sparse_gradient_with_batch_size_;
......
......@@ -140,5 +140,16 @@ void PullDenseWorker::ResetThreadVersion(uint64_t table_id) {
last_versions_[table_id] = current_version_[table_id];
}
int PullDenseWorker::GetThreadIdByScope(const Scope* scope) {
if (scope_to_thread_id_.find(scope) != scope_to_thread_id_.end()) {
return scope_to_thread_id_[scope];
}
return -1;
}
void PullDenseWorker::SetThreadIdByScope(const Scope* scope, int tid) {
scope_to_thread_id_[scope] = tid;
}
} // namespace framework
} // namespace paddle
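The two new methods amount to a lookup table with a -1 miss value, which the push_dense op later uses (via GetThreadIdByScope) to recover the worker thread that owns a scope. An illustrative Python equivalent:
# Illustrative equivalent of the new PullDenseWorker methods.
scope_to_thread_id = {}

def set_thread_id_by_scope(scope, tid):
    scope_to_thread_id[scope] = tid           # SetThreadIdByScope

def get_thread_id_by_scope(scope):
    return scope_to_thread_id.get(scope, -1)  # GetThreadIdByScope, -1 on miss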
......@@ -100,6 +100,8 @@ class DistMultiTrainer : public MultiTrainer {
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Run();
virtual void Finalize();
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_sparse_op.h"
#include <string>
namespace paddle {
namespace operators {
class PullSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
                  platform::errors::InvalidArgument(
                      "Input(Ids) of PullSparseOp cannot be null."));
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
                  platform::errors::InvalidArgument(
                      "Output(Out) of PullSparseOp cannot be null."));
auto hidden_size =
static_cast<uint32_t>(ctx->Attrs().Get<int>("EmbeddingDim"));
auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size();
std::vector<framework::DDim> outs_dims;
outs_dims.resize(n_ids);
for (size_t i = 0; i < n_ids; ++i) {
const auto ids_dims = all_ids_dim[i];
int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
platform::errors::InvalidArgument(
    "Shape error in input %lu: the last dimension of "
    "the 'Ids' tensor must be 1.",
    i));
auto out_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_dim.push_back(hidden_size);
outs_dims[i] = framework::make_ddim(out_dim);
}
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < n_ids; ++i) {
ctx->ShareLoD("Ids", "Out", i, i);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
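InferShape replaces the trailing 1 of every Ids tensor with EmbeddingDim. A hedged Python sketch of the resulting shapes (the helper name is made up):
def infer_pull_sparse_out_dims(all_ids_dims, embedding_dim):
    """Sketch of PullSparseOp::InferShape: the trailing 1 of each Ids
    shape is replaced by EmbeddingDim."""
    outs = []
    for ids_dims in all_ids_dims:
        assert ids_dims[-1] == 1, "last dimension of Ids must be 1"
        outs.append(list(ids_dims[:-1]) + [embedding_dim])
    return outs

print(infer_pull_sparse_out_dims([[32, 1]], 11))  # [[32, 11]]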
class PullSparseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"Input tensors with type int64 contains "
"the ids to be looked up in PSLib. "
"The last dimension size must be 1.")
.AsDuplicable();
AddInput("W", "The lookup table tensors.").AsDuplicable();
AddOutput("Out", "The lookup results tensors.").AsDuplicable();
AddAttr<int>("EmbeddingDim", "(int) the embedding hidden size")
    .SetDefault(11);
AddAttr<int>("TableId", "(int) the table id of this embedding")
    .SetDefault(0);
AddAttr<std::string>("AccessorClass", "(string) the class name of the accessor")
    .SetDefault("");
AddAttr<std::string>("CtrLabelName", "(string) the ctr label name")
    .SetDefault("");
AddAttr<int>("PaddingId", "(int) the padding id of this embedding")
    .SetDefault(0);
AddAttr<bool>("ScaleSparseGrad",
              "(bool) whether to scale sparse gradients by batch size")
    .SetDefault(true);
AddAttr<std::vector<std::string>>("InputNames", "(vector) slot names")
    .SetDefault(std::vector<std::string>());
AddAttr<bool>("is_distributed", "(bool) it must be true").SetDefault(true);
AddComment(R"DOC(
Pull Sparse Operator.
This operator is used to perform lookups on the PSLib parameter server;
the results are then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
};
template <typename T>
class PushSparseOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("push_sparse");
retv->SetInput("Ids", this->Input("Ids"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput("W", this->Input("W"));
retv->SetOutput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
}
};
class PushSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(pull_sparse, ops::PullSparseOp, ops::PullSparseOpMaker,
ops::PushSparseOpMaker<paddle::framework::OpDesc>,
ops::PushSparseOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(push_sparse, ops::PushSparseOp);
REGISTER_OP_CPU_KERNEL(pull_sparse, ops::PullSparseCPUKernel<float>)
REGISTER_OP_CPU_KERNEL(push_sparse, ops::PushSparseCPUKernel<float>)
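For usage, the Python wrapper _pull_sparse added later in this change emits this op; its docstring example (reproduced here) shows the typical call:
import paddle.fluid as fluid

data = fluid.layers.data(name='sequence', shape=[1], dtype='int64',
                         lod_level=1)
emb = fluid.layers.nn._pull_sparse(
    input=data, size=11, table_id=0,
    accessor_class="DownpourCtrAccessor")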
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
template <typename T>
void PullSparseFunctor(const framework::ExecutionContext& ctx) {
auto inputs = ctx.MultiInput<framework::LoDTensor>("Ids");
auto outputs = ctx.MultiOutput<framework::LoDTensor>("Out");
uint32_t fea_dim = static_cast<uint32_t>(ctx.Attr<int>("EmbeddingDim"));
uint64_t padding_id = static_cast<uint64_t>(ctx.Attr<int>("PaddingId"));
auto table_id = static_cast<uint32_t>(ctx.Attr<int>("TableId"));
// note: GetInstance() is not thread-safe
// we assume FleetWrapper has already been initialized
auto fleet_ptr = framework::FleetWrapper::GetInstance();
fleet_ptr->PullSparseToTensorSync(table_id, fea_dim, padding_id,
ctx.GetPlace(), &inputs, &outputs);
}
template <typename T>
void PushSparseFunctor(const framework::ExecutionContext& ctx) {
auto inputs = ctx.MultiInput<framework::LoDTensor>("Ids");
auto grads =
ctx.MultiInput<framework::LoDTensor>(framework::GradVarName("Out"));
uint32_t fea_dim = static_cast<uint32_t>(ctx.Attr<int>("EmbeddingDim"));
std::string accesor = ctx.Attr<std::string>("AccessorClass");
bool scale_sparse = ctx.Attr<bool>("ScaleSparseGrad");
uint64_t padding_id = static_cast<uint64_t>(ctx.Attr<int>("PaddingId"));
const std::string& label_name = ctx.Attr<std::string>("CtrLabelName");
const framework::Scope& scope = ctx.scope();
auto input_names = ctx.Attr<std::vector<std::string>>("InputNames");
auto table_id = static_cast<uint32_t>(ctx.Attr<int>("TableId"));
// note: GetInstance() is not thread-safe
// we assume FleetWrapper has already been initialized
auto fleet_ptr = framework::FleetWrapper::GetInstance();
fleet_ptr->PushSparseFromTensorWithLabelAsync(
scope, table_id, fea_dim, padding_id, scale_sparse, accesor, label_name,
ctx.GetPlace(), input_names, &inputs, &grads);
}
template <typename T>
class PullSparseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PullSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushSparseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PushSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_sparse_v2_op.h"
#include <string>
namespace paddle {
namespace operators {
class PullSparseV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
                  platform::errors::InvalidArgument(
                      "Input(Ids) of PullSparseV2Op cannot be null."));
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
                  platform::errors::InvalidArgument(
                      "Output(Out) of PullSparseV2Op cannot be null."));
auto hidden_size =
static_cast<uint32_t>(ctx->Attrs().Get<int>("EmbeddingDim"));
auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size();
std::vector<framework::DDim> outs_dims;
outs_dims.resize(n_ids);
for (size_t i = 0; i < n_ids; ++i) {
const auto ids_dims = all_ids_dim[i];
auto out_dim = framework::vectorize(ids_dims);
out_dim.push_back(hidden_size);
outs_dims[i] = framework::make_ddim(out_dim);
}
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < n_ids; ++i) {
ctx->ShareLoD("Ids", "Out", i, i);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
class PullSparseV2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"Input tensors with type int64 contains "
"the ids to be looked up in PSLib. ")
.AsDuplicable();
AddInput("W", "The lookup table tensors.").AsDuplicable();
AddOutput("Out", "The lookup results tensors.").AsDuplicable();
AddAttr<int>("EmbeddingDim", "(int) the embedding hidden size")
    .SetDefault(11);
AddAttr<int>("TableId", "(int) the table id of this embedding")
    .SetDefault(0);
AddAttr<std::string>("AccessorClass", "(string) the class name of the accessor")
    .SetDefault("");
AddAttr<std::string>("CtrLabelName", "(string) the ctr label name")
    .SetDefault("");
AddAttr<int>("PaddingId", "(int) the padding id of this embedding")
    .SetDefault(0);
AddAttr<bool>("ScaleSparseGrad",
              "(bool) whether to scale sparse gradients by batch size")
    .SetDefault(true);
AddAttr<std::vector<std::string>>("InputNames", "(vector) slot names")
    .SetDefault(std::vector<std::string>());
AddAttr<bool>("is_distributed", "(bool) it must be true").SetDefault(true);
AddComment(R"DOC(
Pull Sparse V2 Operator.
This operator is used to perform lookups on the PSLib parameter server;
the results are then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
};
template <typename T>
class PushSparseV2OpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("push_sparse_v2");
retv->SetInput("Ids", this->Input("Ids"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput("W", this->Input("W"));
retv->SetOutput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
}
};
class PushSparseV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(pull_sparse_v2, ops::PullSparseV2Op, ops::PullSparseV2OpMaker,
ops::PushSparseV2OpMaker<paddle::framework::OpDesc>,
ops::PushSparseV2OpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(push_sparse_v2, ops::PushSparseV2Op);
REGISTER_OP_CPU_KERNEL(pull_sparse_v2, ops::PullSparseV2CPUKernel<float>)
REGISTER_OP_CPU_KERNEL(push_sparse_v2, ops::PushSparseV2CPUKernel<float>)
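The only substantive difference from pull_sparse is InferShape: v2 appends EmbeddingDim to the full Ids shape instead of replacing a trailing 1. A hedged comparison (helper name made up):
def infer_pull_sparse_v2_out_dims(all_ids_dims, embedding_dim):
    """Sketch of PullSparseV2Op::InferShape: EmbeddingDim is appended."""
    return [list(ids_dims) + [embedding_dim] for ids_dims in all_ids_dims]

# for Ids of shape [batch, 1] and EmbeddingDim = 11:
#   pull_sparse    : [batch, 1] -> [batch, 11]     (trailing 1 replaced)
#   pull_sparse_v2 : [batch, 1] -> [batch, 1, 11]  (dimension appended)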
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/pull_sparse_op.h"
namespace paddle {
namespace operators {
template <typename T>
class PullSparseV2CPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PullSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushSparseV2CPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PushSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/push_dense_op.h"
#include <string>
namespace paddle {
namespace operators {
class PushDenseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
                  platform::errors::InvalidArgument(
                      "Input(Ids) of PushDenseOp cannot be null."));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
class PushDenseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "the tensor used to obtain the batch size").AsDuplicable();
AddAttr<int>("TableId", "(int) the table id of this embedding")
    .SetDefault(-1);
AddAttr<float>("ScaleDataNorm", "(float) the scale of the data-norm gradient")
    .SetDefault(-1.0f);
AddAttr<std::vector<std::string>>("InputNames", "(vector) slot names")
    .SetDefault(std::vector<std::string>());
AddComment(R"DOC(
Push Dense Operator.
This operator pushes dense gradients to PSLib's parameter server.
The inputs are all the dense gradient tensors of a table.
)DOC");
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInference, "Ids");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
push_dense, ops::PushDenseOp, ops::PushDenseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::PushDenseNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(push_dense, ops::PushDenseCPUKernel<float>)
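For usage, the optimizer factory later in this change appends the op to the main program; a condensed excerpt of that call (the surrounding variables cur_prog, one_slot, grads, dense_table_index, and strategy come from that code, so this is a pointer rather than a standalone snippet):
cur_prog.global_block().append_op(
    type="push_dense",
    inputs={"Ids": one_slot},  # any slot tensor; only its batch size is read
    attrs={
        "InputNames": [g.name for g in grads],  # dense gradient names
        "TableId": dense_table_index,
        "ScaleDataNorm": strategy.get("scale_datanorm", -1)
    })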
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
template <typename T>
void PushDenseFunctor(const framework::ExecutionContext& ctx) {
#ifdef PADDLE_WITH_PSLIB
const auto& input_names = ctx.Attr<std::vector<std::string>>("InputNames");
auto table_id = static_cast<uint32_t>(ctx.Attr<int>("TableId"));
PADDLE_ENFORCE_GT(table_id, 0,
                  platform::errors::InvalidArgument(
                      "table id should be greater than 0, but is ", table_id));
float scale_datanorm = ctx.Attr<float>("ScaleDataNorm");
const auto& ids = ctx.MultiInput<framework::LoDTensor>("Ids");
int batch_size =
ids[0]->lod().size() ? ids[0]->lod()[0].size() - 1 : ids[0]->dims()[0];
PADDLE_ENFORCE_GT(batch_size, 0,
                  platform::errors::InvalidArgument(
                      "batch size should be greater than 0, but is ", batch_size));
auto fleet_ptr = framework::FleetWrapper::GetInstance();
fleet_ptr->PushDenseVarsAsync(ctx.scope(), table_id, input_names, nullptr,
scale_datanorm, batch_size);
// note: GetInstance() is not thread-safe
// we assume PullDenseWorker has already been initialized in DistMultiTrainer
auto pull_dense_worker = framework::PullDenseWorker::GetInstance();
PADDLE_ENFORCE_NE(pull_dense_worker, nullptr,
platform::errors::PreconditionNotMet(
"pull_dense_worker should not be null"));
int thread_id = pull_dense_worker->GetThreadIdByScope(&ctx.scope());
pull_dense_worker->IncreaseThreadVersion(thread_id, table_id);
#endif
}
template <typename T>
class PushDenseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PushDenseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
......@@ -91,6 +91,83 @@ class Hogwild(DeviceWorker):
# just ignore feed op for inference model
trainer_desc.hogwild_param.skip_ops.extend(["feed"])
dense_table_set = set()
program_id = str(id(self._program))
if self._program is None:
print("program of current device worker is not configured")
exit(-1)
opt_info = self._program._fleet_opt
if opt_info is None:
return
program_configs = opt_info["program_configs"]
downpour = trainer_desc.downpour_param
for pid in program_configs:
if pid == program_id:
pc = downpour.program_config.add()
pc.program_id = program_id
for i in program_configs[program_id]["push_sparse"]:
pc.push_sparse_table_id.extend([i])
for i in program_configs[program_id]["push_dense"]:
pc.push_dense_table_id.extend([i])
dense_table_set.add(i)
for i in program_configs[program_id]["pull_sparse"]:
pc.pull_sparse_table_id.extend([i])
for i in program_configs[program_id]["pull_dense"]:
pc.pull_dense_table_id.extend([i])
dense_table_set.add(i)
break
trainer_desc.device_worker_name = "HogwildWorker"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
if opt_info.get("program_id_to_worker") is None:
raise ValueError("opt_info must have program_id_to_worker")
prog_id_to_worker = opt_info["program_id_to_worker"]
if prog_id_to_worker.get(program_id) is None:
raise ValueError("%s not found in program_id_to_worker" %
program_id)
worker = opt_info["program_id_to_worker"][program_id]
for i in worker.get_desc().dense_table:
if i.table_id in dense_table_set:
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.table_id = \
i.table_id
sparse_len = len(worker.get_desc().sparse_table)
for i in range(sparse_len):
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = worker.get_desc().sparse_table[i].table_id
sparse_table.sparse_key_name.extend(worker.get_desc().sparse_table[
i].slot_key)
sparse_table.sparse_value_name.extend(worker.get_desc()
.sparse_table[i].slot_value)
sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
i].slot_gradient)
sparse_table.fea_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
i].accessor.fea_dim
# emb_dim is not used
sparse_table.emb_dim = -1
# a hard-coded click variable is not used
sparse_table.label_var_name = ""
if opt_info["stat_var_names"]:
for i in opt_info["stat_var_names"]:
downpour.stat_var_names.extend([i])
for i in worker.get_desc().dense_table:
if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add()
dense_table.table_id = i.table_id
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name)
downpour.skip_ops.extend(worker.get_desc().skip_op)
if self._infer:
downpour.push_dense = False
downpour.push_sparse = False
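For orientation, opt_info["program_configs"], consumed at the top of this method, maps str(id(program)) to the table ids that program pulls and pushes; an illustrative example (the ids are made up):
program_configs = {
    "140234018233616": {   # str(id(self._program))
        "pull_sparse": [0, 1],
        "push_sparse": [0, 1],
        "pull_dense": [2],
        "push_dense": [2],
    }
}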
class DownpourSGD(DeviceWorker):
"""
......
......@@ -577,6 +577,193 @@ class PSLib(Fleet):
fleet = PSLib()
def _prepare_params(input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
"""
Preprocess params; this interface is not for users.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable
size(list of int): the embedding dim
is_sparse(bool): whether input is sparse ids
is_distributed(bool): whether in distributed mode
padding_idx(int): padding idx of input
param_attr(ParamAttr): To specify the weight parameter property
dtype(str): data type of output
"""
if param_attr is None:
raise ValueError("param_attr must be set")
name = param_attr.name
if name is None:
raise ValueError("embedding name must be set")
if not isinstance(size, list) and not isinstance(size, tuple):
raise ValueError("embedding size must be list or tuple")
size = size[-1]
global FLEET_GLOBAL_DICT
FLEET_GLOBAL_DICT["enable"] = True
d_table = FLEET_GLOBAL_DICT["emb_to_table"]
d_accessor = FLEET_GLOBAL_DICT["emb_to_accessor"]
d_size = FLEET_GLOBAL_DICT["emb_to_size"]
# check embedding size
if d_size.get(name) is None:
d_size[name] = size
elif d_size[name] != size:
raise ValueError("embedding size error: %s vs %s" %
(size, d_size[name]))
# check embedding accessor
accessor = FLEET_GLOBAL_DICT["cur_accessor"]
if d_accessor.get(name) is None:
d_accessor[name] = accessor
elif d_accessor[name] != accessor:
raise ValueError("embedding accessor error: %s vs %s" %
                 (d_accessor[name], accessor))
# check embedding table id
if d_table.get(name) is None:
d_table[name] = FLEET_GLOBAL_DICT["cur_sparse_id"]
FLEET_GLOBAL_DICT["cur_sparse_id"] += 1
# check other params
if not is_sparse:
raise ValueError("is_sparse must be True")
elif not is_distributed:
raise ValueError("is_distributed must be True")
elif dtype != "float32":
raise ValueError("dtype must be float32")
def _fleet_embedding(input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
"""
Add fleet embedding; this interface is not for users.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable
size(list of int): the embedding dim
is_sparse(bool): whether input is sparse ids
is_distributed(bool): whether in distributed mode
padding_idx(int): padding idx of input
param_attr(ParamAttr): To specify the weight parameter property
dtype(str): data type of output
"""
# check and set params
_prepare_params(input, size, is_sparse, is_distributed, padding_idx,
param_attr, dtype)
name = param_attr.name
size = size[-1]
if padding_idx is None:
padding_idx = 0
global FLEET_GLOBAL_DICT
return fluid.layers.nn._pull_sparse(
input=input,
size=size,
table_id=FLEET_GLOBAL_DICT["emb_to_table"][name],
accessor_class=FLEET_GLOBAL_DICT["emb_to_accessor"][name],
name=name,
ctr_label_name=FLEET_GLOBAL_DICT["click_name"],
padding_id=padding_idx,
dtype=dtype,
scale_sparse_grad=FLEET_GLOBAL_DICT["scale_sparse_grad"])
def _fleet_embedding_v2(input,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32'):
"""
Add fleet embedding v2; this interface is not for users.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable
size(list of int): the embedding dim
is_sparse(bool): whether input is sparse ids
is_distributed(bool): whether in distributed mode
padding_idx(int): padding idx of input
param_attr(ParamAttr): To specify the weight parameter property
dtype(str): data type of output
"""
# check and set params
_prepare_params(input, size, is_sparse, is_distributed, padding_idx,
param_attr, dtype)
name = param_attr.name
size = size[-1]
if padding_idx is None:
padding_idx = 0
return fluid.layers.nn._pull_sparse_v2(
input=input,
size=size,
table_id=FLEET_GLOBAL_DICT["emb_to_table"][name],
accessor_class=FLEET_GLOBAL_DICT["emb_to_accessor"][name],
name=name,
ctr_label_name=FLEET_GLOBAL_DICT["click_name"],
padding_id=padding_idx,
dtype=dtype,
scale_sparse_grad=FLEET_GLOBAL_DICT["scale_sparse_grad"])
class fleet_embedding(object):
"""
fleet embedding class; used as a context-manager wrapper
Example:
.. code-block:: python
with fleet_embedding(click_name=label.name):
emb = fluid.layers.embedding(
input=var,
size=[-1, 11],
is_sparse=True,
is_distributed=True,
param_attr=fluid.ParamAttr(name="embedding"))
"""
def __init__(self, click_name, scale_sparse_grad=True):
"""Init."""
self.origin_emb = fluid.layers.embedding
self.origin_emb_v2 = fluid.embedding
# if user uses cvm layer after embedding, click_name can be None
self.click_name = "" if click_name is None else click_name
self.scale_sparse_grad = scale_sparse_grad
# it's default value, will be modified in minimize
self.accessor = "DownpourCtrAccessor"
def __enter__(self):
"""Enter."""
fluid.layers.embedding = _fleet_embedding
fluid.embedding = _fleet_embedding_v2
FLEET_GLOBAL_DICT["cur_accessor"] = self.accessor
FLEET_GLOBAL_DICT["click_name"] = self.click_name
FLEET_GLOBAL_DICT["scale_sparse_grad"] = self.scale_sparse_grad
def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit."""
fluid.layers.embedding = self.origin_emb
fluid.embedding = self.origin_emb_v2
FLEET_GLOBAL_DICT["cur_accessor"] = ""
FLEET_GLOBAL_DICT["click_name"] = ""
FLEET_GLOBAL_DICT["scale_sparse_grad"] = None
class DownpourOptimizer(DistributedOptimizer):
"""
DistributedOptimizer is a wrapper for paddle.fluid.optimizer
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Optimizer Factory."""
__all__ = ["DistributedAdam", "FLEET_GLOBAL_DICT"]
import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table_inputs
......@@ -23,6 +23,20 @@ from collections import OrderedDict
from .node import DownpourWorker, DownpourServer
from . import ps_pb2 as pslib
# this dict stores info about pull/push sparse ops.
FLEET_GLOBAL_DICT = {
# global settings
"enable": False,
"emb_to_table": {},
"emb_to_accessor": {},
"emb_to_size": {},
# current embedding settings
"cur_sparse_id": 0,
"cur_accessor": "",
"click_name": "",
"scale_sparse_grad": None,
}
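As an illustration of how this dict gets filled: after a `with fleet_embedding(click_name=click.name)` block (defined in pslib.py above) has registered one embedding named "embedding" with size [-1, 11] under the default DownpourCtrAccessor, the state would look roughly like this (values are illustrative):
example_state = {
    "enable": True,
    "emb_to_table": {"embedding": 0},
    "emb_to_accessor": {"embedding": "DownpourCtrAccessor"},
    "emb_to_size": {"embedding": 11},
    "cur_sparse_id": 1,  # next sparse table id to hand out
    "cur_accessor": "DownpourCtrAccessor",
    "click_name": "click",
    "scale_sparse_grad": True,
}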
class DistributedOptimizerImplBase(object):
"""
......@@ -67,6 +81,12 @@ class DistributedAdam(DistributedOptimizerImplBase):
".batch_size", ".batch_square_sum", ".batch_sum",
".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD"
]
self.supported_embedding_types = [
"lookup_table", "pull_sparse", "pull_sparse_v2"
]
self.supported_embedding_grad_types = [
"lookup_table_grad", "push_sparse", "push_sparse_v2"
]
def _find_distributed_lookup_table_inputs(self, program, table_names):
"""
......@@ -84,7 +104,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
inputs_dict[table_name] = []
for op in program.global_block().ops:
if op.type in self.supported_embedding_types:
if op.input("W")[0] in table_names:
inputs_dict[op.input("W")[0]].extend(
[local_vars[name] for name in op.input("Ids")])
......@@ -106,7 +126,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
outputs_dict[table_name] = []
for op in program.global_block().ops:
if op.type in self.supported_embedding_types:
if op.input("W")[0] in table_names:
outputs_dict[op.input("W")[0]].extend(
[local_vars[name] for name in op.output("Out")])
......@@ -119,10 +139,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
grads_dict[table_name] = []
for op in program.global_block().ops:
if op.type in self.supported_embedding_grad_types:
    if op.input("W")[0] in table_names:
        grads_dict[op.input("W")[0]].extend(
            [local_vars[name] for name in op.input("Out@GRAD")])
return grads_dict
def _find_multi_distributed_lookup_table(self, losses):
......@@ -135,7 +155,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
ret_list = []
for loss in losses:
for op in loss.block.program.global_block().ops:
if op.type in self.supported_embedding_types:
if op.attr('is_distributed') is True:
table_name = op.input("W")[0]
if table_name not in table_names:
......@@ -251,6 +271,71 @@ class DistributedAdam(DistributedOptimizerImplBase):
ps_param.trainer_param[idx])
idx += 1
# check config in op definition and fleet config
if FLEET_GLOBAL_DICT["enable"]:
one_slot = None
strategy["device_worker"] = "Hogwild"
emb_to_table = FLEET_GLOBAL_DICT["emb_to_table"]
emb_to_accessor = FLEET_GLOBAL_DICT["emb_to_accessor"]
emb_to_size = FLEET_GLOBAL_DICT["emb_to_size"]
if len(sparse_table_to_index) != len(emb_to_table):
raise ValueError(
"sparse tables from program != sparse tables from op: %s "
"vs %s" % (len(sparse_table_to_index), len(emb_to_table)))
for key in sparse_table_to_index:
if key not in emb_to_table or \
sparse_table_to_index[key] != emb_to_table[key]:
print("sparse_table_to_index ", sparse_table_to_index)
print("emb_to_table ", emb_to_table)
raise ValueError("key error: %s" % key)
if strategy.get(key) is None:
strategy[key] = dict()
st = strategy[key]
accessor = None
if st.get("sparse_accessor_class") is not None:
accessor = st["sparse_accessor_class"]
tables = \
server.get_desc().downpour_server_param.downpour_table_param
for table in tables:
if table.table_id == sparse_table_to_index[key]:
accessor = table.accessor.accessor_class
break
for loss in losses:
for op in loss.block.program.global_block().ops:
if op.type in self.supported_embedding_types:
if accessor is not None \
and op.has_attr("AccessorClass"):
op._set_attr("AccessorClass", accessor)
if one_slot is None:
one_slot = loss.block.program.\
global_block().var(op.input("Ids")[0])
# if accessor is None, use default accessor in op definition
if accessor is None:
accessor = emb_to_accessor[key]
# set sparse_embedx_dim in strategy;
# users do not have to set it in config_fleet
if accessor == "DownpourFeatureValueAccessor" \
or accessor == "DownpourCtrAccessor" \
or accessor == "DownpourUnitAccessor":
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[key] - 3:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size - 3 = %s" %
(st["sparse_embedx_dim"],
emb_to_size[key] - 3))
st["sparse_embedx_dim"] = emb_to_size[key] - 3
elif accessor == "DownpourSparseValueAccessor":
if st.get("sparse_embedx_dim") is not None \
and st["sparse_embedx_dim"] != emb_to_size[key]:
raise ValueError("fleet config sparse_embedx_dim=%s not"
" equal to embedding size = %s" %
(st["sparse_embedx_dim"],
emb_to_size[key]))
st["sparse_embedx_dim"] = emb_to_size[key]
# ServerParameter add all sparse tables
for tn in sparse_table_to_index:
sparse_table_index = sparse_table_to_index[tn]
......@@ -328,6 +413,19 @@ class DistributedAdam(DistributedOptimizerImplBase):
worker.add_dense_table(
dense_table_index, self._learning_rate, params, grads,
dense_start_table_id, sparse_table_names)
if FLEET_GLOBAL_DICT["enable"]:
cur_prog = losses[loss_index].block.program
cur_prog.global_block().append_op(
type="push_dense",
inputs={"Ids": one_slot},
attrs={
"InputNames": [i.name for i in grads],
"TableId": dense_table_index,
"ScaleDataNorm":
strategy.get("scale_datanorm", -1)
})
if "pull_dense" in program_configs[
program_id] and "push_dense" in program_configs[
program_id] and len(program_configs[program_id][
......@@ -358,6 +456,20 @@ class DistributedAdam(DistributedOptimizerImplBase):
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
dense_start_table_id, sparse_table_names)
if FLEET_GLOBAL_DICT["enable"]:
cur_prog = losses[loss_index].block.program
cur_prog.global_block().append_op(
type="push_dense",
inputs={"Ids": one_slot},
attrs={
"InputNames":
[i.name for i in data_norm_grads],
"TableId": dense_table_index,
"ScaleDataNorm":
strategy.get("scale_datanorm", -1)
})
program_configs[program_id]["pull_dense"].extend(
[dense_table_index])
program_configs[program_id]["push_dense"].extend(
......
......@@ -497,6 +497,148 @@ def embedding(input,
return tmp
def _pull_sparse(input,
size,
table_id,
accessor_class,
name="embedding",
ctr_label_name="",
padding_id=0,
dtype='float32',
scale_sparse_grad=True):
"""
**Pull Fleet Sparse Layer**
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
Fleet lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable, which
contains the IDs information.
size(int): The embedding size parameter, which indicates the size of
    each embedding vector.
table_id(int): the fleet table id of this embedding.
accessor_class(str): the pslib accessor of the table, default is DownpourCtrAccessor.
ctr_label_name(str): the layer name of click.
padding_id(int): the padding id during lookup, default is 0.
dtype(str): The dtype refers to the data type of output tensor. Only supports
float32 now.
scale_sparse_grad(bool): whether to scale sparse gradient with batch size. default
is True.
Returns:
Variable|list of Variable: The tensor variable storing the embeddings of the \
supplied inputs.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.nn._pull_sparse(
input=data, size=11, table_id=0, accessor_class="DownpourCtrAccessor")
"""
helper = LayerHelper(name, **locals())
inputs = helper.multiple_input()
outs = [helper.create_variable_for_type_inference(dtype)]
input_names = [i.name for i in inputs]
attrs = {
'EmbeddingDim': size,
'TableId': table_id,
'AccessorClass': accessor_class,
'CtrLabelName': ctr_label_name,
'PaddingId': padding_id,
'ScaleSparseGrad': scale_sparse_grad,
'InputNames': input_names,
# this is only for compatibility with the embedding op
'is_distributed': True
}
# this is only for compatibility with the embedding op
w, _ = helper.create_or_get_global_variable(
name=name, shape=[size], dtype=dtype, is_bias=False, persistable=True)
helper.append_op(
type='pull_sparse',
inputs={'Ids': inputs,
'W': w},
outputs={'Out': outs},
attrs=attrs)
if len(outs) == 1:
return outs[0]
return outs
def _pull_sparse_v2(input,
size,
table_id,
accessor_class,
name="embedding",
ctr_label_name="",
padding_id=0,
dtype='float32',
scale_sparse_grad=True):
"""
**Pull Fleet Sparse Layer**
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
Fleet lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable, which
contains the IDs information.
size(int): The embedding size parameter, which indicates the size of
    each embedding vector.
table_id(int): the pslib table id of this embedding.
accessor_class(str): the fleet accessor of the table, default is DownpourCtrAccessor.
ctr_label_name(str): the layer name of click.
padding_id(int): the padding id during lookup, default is 0.
dtype(str): The dtype refers to the data type of output tensor. Only supports
float32 now.
scale_sparse_grad(bool): whether to scale sparse gradient with batch size. default
is True.
Returns:
Variable|list of Variable: The tensor variable storing the embeddings of the \
supplied inputs.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1)
emb = fluid.layers.nn._pull_sparse_v2(
input=data, size=11, table_id=0, accessor_class="DownpourCtrAccessor")
"""
helper = LayerHelper(name, **locals())
inputs = helper.multiple_input()
outs = [helper.create_variable_for_type_inference(dtype)]
input_names = [i.name for i in inputs]
attrs = {
'EmbeddingDim': size,
'TableId': table_id,
'AccessorClass': accessor_class,
'CtrLabelName': ctr_label_name,
'PaddingId': padding_id,
'ScaleSparseGrad': scale_sparse_grad,
'InputNames': input_names,
# this is only for compatibility with the embedding op
'is_distributed': True
}
# this is only for compatibility with the embedding op
w, _ = helper.create_or_get_global_variable(
name=name, shape=[size], dtype=dtype, is_bias=False, persistable=True)
helper.append_op(
type='pull_sparse_v2',
inputs={'Ids': inputs,
'W': w},
outputs={'Out': outs},
attrs=attrs)
if len(outs) == 1:
return outs[0]
return outs
def _pull_box_sparse(input, size, dtype='float32'):
"""
**Pull Box Sparse Layer**
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test fleet."""
from __future__ import print_function
import os
import unittest
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
class TestFleet2(unittest.TestCase):
"""Test cases for fleet ops."""
def setUp(self):
"""Set up, set envs."""
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ[
"PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001"
def test_pslib_1(self):
"""Test cases for pslib."""
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.parameter_server.pslib import \
fleet_embedding, _prepare_params, _fleet_embedding, \
_fleet_embedding_v2, FLEET_GLOBAL_DICT
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
try:
import netifaces
except:
print("warning: no netifaces, skip test_pslib_1")
return
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36002"
os.environ["PADDLE_TRAINER_ID"] = "0"
role_maker = GeneralRoleMaker()
role_maker.generate_role()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
fleet.init(role_maker)
train_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
global FLEET_GLOBAL_DICT
with fluid.program_guard(train_program, startup_program):
show = fluid.layers.data(name="show", shape=[-1, 1], \
dtype="int64", lod_level=1, append_batch_size=False)
click = fluid.layers.data(name="click", shape=[-1, 1], \
dtype="int64", lod_level=1, append_batch_size=False)
with fleet_embedding(click_name=click.name):
emb = fluid.layers.embedding(input=show, size=[1, 1], \
is_sparse=True, is_distributed=True, \
param_attr=fluid.ParamAttr(name="embedding"))
emb = fluid.layers.data_norm(
input=emb,
name="a",
epsilon=1e-4,
param_attr={
"batch_size": 1e4,
"batch_sum_default": 0.0,
"batch_square": 1e4
})
fc = fluid.layers.fc(input=emb, size=1, act=None)
label = fluid.layers.data(name="click", shape=[-1, 1], \
dtype="int64", lod_level=1, append_batch_size=False)
label_cast = fluid.layers.cast(label, dtype='float32')
cost = fluid.layers.log_loss(fc, label_cast)
try:
adam = fluid.optimizer.Adam(learning_rate=0.000005)
adam = fleet.distributed_optimizer(
adam,
strategy={
"embedding": {
"sparse_accessor_class": "DownpourSparseValueAccessor"
}
})
adam.minimize([cost], [scope])
except:
print("do not support pslib test, skip")
return
FLEET_GLOBAL_DICT["cur_accessor"] = "DownpourCtrAccessor"
try:
_prepare_params(input=show, size=[1, 1])
except:
print("catch expected exception of param_attr=None")
try:
_prepare_params(
input=show, size=[1, 1], param_attr=fluid.ParamAttr())
except:
print("catch expected exception of name=None")
try:
tmp = fluid.ParamAttr(name="embedding")
_prepare_params(input=show, size=1, param_attr=tmp)
except:
print("catch expected exception of size not list")
try:
tmp = fluid.ParamAttr(name="embedding")
_prepare_params(input=show, size=[-1, 12], param_attr=tmp)
except:
print("catch expected exception of size not equal")
try:
tmp = fluid.ParamAttr(name="embedding")
_prepare_params(
input=show, size=[-1, 1], param_attr=tmp, is_sparse=False)
except:
print("catch expected exception of is_sparse=False")
try:
tmp = fluid.ParamAttr(name="embedding")
_prepare_params(input=show, size=[-1, 1], param_attr=tmp, \
is_sparse=True, is_distributed=False)
except:
print("catch expected exception of is_distributed=False")
try:
_prepare_params(input=show, size=[-1, 1], \
param_attr=fluid.ParamAttr(name="embedding"), \
is_sparse=True, is_distributed=True, dtype="abc")
except:
print("catch expected exception of unknown dtype")
try:
FLEET_GLOBAL_DICT["emb_to_accessor"]["embedding"] = "unknown"
tmp = fluid.ParamAttr(name="embedding")
_prepare_params(input=show, size=[-1, 1], param_attr=tmp)
except:
print("catch expected exception of unknown accessor")
FLEET_GLOBAL_DICT["cur_accessor"] = "DownpourCtrAccessor"
try:
_fleet_embedding(input=show, size=[-1, 1], is_sparse=True, \
is_distributed=True, dtype="float32", \
param_attr=fluid.ParamAttr(name="embedding"))
except:
print("catch expected exception of unknown accessor")
try:
_fleet_embedding_v2(input=show, size=[-1, 1], is_sparse=True, \
is_distributed=True, dtype="float32", \
param_attr=fluid.ParamAttr(name="embedding"))
except:
print("catch expected exception of unknown accessor")
adam1 = fluid.optimizer.Adam(learning_rate=0.000005)
adam1 = fleet.distributed_optimizer(
adam1,
strategy={
"embedding": {
"sparse_accessor_class": "DownpourSparseValueAccessor"
}
})
try:
pre = FLEET_GLOBAL_DICT["emb_to_table"]
FLEET_GLOBAL_DICT["emb_to_table"] = {}
adam1.minimize([cost], [scope])
except:
FLEET_GLOBAL_DICT["emb_to_table"] = pre
print("catch expected exception of empty emb_to_table")
try:
pre = FLEET_GLOBAL_DICT["emb_to_table"]
FLEET_GLOBAL_DICT["emb_to_table"] = {}
FLEET_GLOBAL_DICT["emb_to_table"]["emb1"] = 0
adam1.minimize([cost], [scope])
except:
FLEET_GLOBAL_DICT["emb_to_table"] = pre
print("catch expected exception of error emb_to_table")
try:
adam2 = fluid.optimizer.Adam(learning_rate=0.000005)
adam2 = fleet.distributed_optimizer(adam2)
adam2.supported_embedding_types = []
adam2.minimize([cost], [scope])
except:
print("catch expected exception of embedding_types")
try:
adam3 = fluid.optimizer.Adam(learning_rate=0.000005)
adam3 = fleet.distributed_optimizer(
adam3,
strategy={
"embedding": {
"sparse_accessor_class": "DownpourSparseValueAccessor",
"sparse_embedx_dim": 999
}
})
adam3.minimize([cost], [scope])
except:
print("catch expected exception of embedx_dim error")
try:
adam4 = fluid.optimizer.Adam(learning_rate=0.000005)
adam4 = fleet.distributed_optimizer(
adam4,
strategy={
"embedding": {
"sparse_accessor_class": "DownpourCtrAccessor",
"sparse_embedx_dim": 999
}
})
adam4.minimize([cost], [scope])
except:
print("catch expected exception of embedx_dim error")
train_program1 = fluid.Program()
startup_program1 = fluid.Program()
FLEET_GLOBAL_DICT["emb_to_accessor"] = {}
with fluid.program_guard(train_program1, startup_program1):
show = fluid.layers.data(name="show", shape=[-1, 1], \
dtype="int64", lod_level=1, append_batch_size=False)
with fleet_embedding(click_name=click.name):
emb = fluid.layers.embedding(input=show, size=[1, 1], \
is_sparse=True, is_distributed=True, \
param_attr=fluid.ParamAttr(name="embedding"))
with fleet_embedding(click_name=click.name):
emb1 = fluid.embedding(input=show, size=[1, 1], \
is_sparse=True, is_distributed=True, \
param_attr=fluid.ParamAttr(name="embedding"))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test fleet."""
from __future__ import print_function
import os
import paddle.fluid as fluid
import unittest
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.parameter_server.pslib import \
fleet_embedding, _prepare_params, _fleet_embedding, \
_fleet_embedding_v2, FLEET_GLOBAL_DICT
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
class TestFleet2(unittest.TestCase):
"""Test cases for fleet ops."""
def test_in_memory_dataset_run_fleet(self):
"""
Testcase for InMemoryDataset from create to run.
"""
with open("test_in_memory_dataset_run_fleet_a.txt", "w") as f:
data = "1 1 1 2 2 3 3 4 5 5 5 5 1 1\n"
data += "1 0 1 3 2 3 4 4 6 6 6 6 1 2\n"
data += "1 1 1 4 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_in_memory_dataset_run_fleet_b.txt", "w") as f:
data = "1 0 1 5 2 3 3 4 5 5 5 5 1 4\n"
data += "1 1 1 6 2 3 4 4 6 6 6 6 1 5\n"
data += "1 0 1 7 2 3 5 4 7 7 7 7 1 6\n"
data += "1 1 1 8 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots = ["click", "slot1", "slot2", "slot3", "slot4"]
slots_vars = []
for slot in slots:
var = fluid.layers.data(
name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var)
click = slots_vars[0]
embs = []
for slot in slots_vars[1:3]:
with fleet_embedding(click_name=click.name):
emb = fluid.layers.embedding(input=slot, size=[-1, 11], \
is_sparse=True, is_distributed=True, \
param_attr=fluid.ParamAttr(name="embedding"))
embs.append(emb)
for slot in slots_vars[3:5]:
with fleet_embedding(click_name=click.name):
emb = fluid.embedding(input=slot, size=[-1, 11], \
is_sparse=True, is_distributed=True, \
param_attr=fluid.ParamAttr(name="embedding"))
emb = fluid.layers.reshape(emb, [-1, 11])
embs.append(emb)
concat = fluid.layers.concat([embs[0], embs[3]], axis=1)
fc = fluid.layers.fc(input=concat, size=1, act=None)
label_cast = fluid.layers.cast(slots_vars[1], dtype='float32')
cost = fluid.layers.log_loss(fc, label_cast)
cost = fluid.layers.mean(cost)
try:
fleet.init()
adam = fluid.optimizer.Adam(learning_rate=0.000005)
adam = fleet.distributed_optimizer(adam)
scope = fluid.Scope()
adam.minimize([cost], [scope])
except:
print("do not support pslib test, skip")
return
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(1)
dataset.set_thread(2)
dataset.set_filelist([
"test_in_memory_dataset_run_fleet_a.txt",
"test_in_memory_dataset_run_fleet_b.txt"
])
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.train_from_dataset(fluid.default_main_program(), dataset)
fleet._opt_info["stat_var_names"] = ["233"]
exe.infer_from_dataset(fluid.default_main_program(), dataset)
fleet._opt_info = None
fleet._fleet_ptr = None
os.remove("./test_in_memory_dataset_run_fleet_a.txt")
os.remove("./test_in_memory_dataset_run_fleet_b.txt")
if __name__ == "__main__":
unittest.main()
......@@ -230,6 +230,7 @@ class MultiTrainer(TrainerDesc):
super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self._device_worker._set_infer(self._infer)
self._device_worker._set_program(self._program)
self._device_worker._gen_worker_desc(self.proto_desc)
......