Unverified  Commit a881b4d5 authored by tangwei12, committed by GitHub

Struct SparseValue && Bug Fix (#31721)

* add PullSparseValue for pull sparse

* fix bug for PullSparseValue

* add test mode in lookuptable

* revert API change

* add comment for is_training
Parent b8b82b72
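
For orientation, here is a minimal, self-contained sketch of the calling convention this commit moves to: every `pull_sparse` call now carries an `is_training` flag, and the key/value buffers must stay alive until the returned future completes. `MockPsClient`, the table id, and the embedding width are hypothetical; only the parameter list mirrors the interface changed in the diff below.

```cpp
#include <cstdint>
#include <future>
#include <vector>

// Hypothetical stand-in for PSClient; only the pull_sparse signature follows
// the interface updated in this commit.
struct MockPsClient {
  std::future<int32_t> pull_sparse(float **select_values, size_t table_id,
                                   const uint64_t *keys, size_t num,
                                   bool is_training) {
    // A real client serializes the request and sends it to the pservers;
    // here we resolve immediately to illustrate the call pattern only.
    (void)select_values; (void)table_id; (void)keys; (void)num; (void)is_training;
    std::promise<int32_t> done;
    done.set_value(0);
    return done.get_future();
  }
};

int main() {
  const int fea_dim = 8;  // assumed embedding width
  std::vector<uint64_t> fea_keys = {3, 5, 7};
  std::vector<std::vector<float>> fea_values(fea_keys.size(),
                                             std::vector<float>(fea_dim));
  std::vector<float *> pull_result_ptr;
  for (auto &row : fea_values) pull_result_ptr.push_back(row.data());

  MockPsClient client;
  // is_training = true lets the pserver create/admit features; false marks an
  // inference pull, which should leave the table untouched.
  auto status = client.pull_sparse(pull_result_ptr.data(), /*table_id=*/0,
                                   fea_keys.data(), fea_keys.size(),
                                   /*is_training=*/true);
  status.wait();  // keys and value buffers must stay alive until this returns
  return status.get();
}
```
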
......@@ -146,41 +146,6 @@ void FleetWrapper::CreateClient2ClientConnection() {
client2client_max_retry_);
}
std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
return pserver_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
......@@ -224,8 +189,10 @@ void FleetWrapper::PullSparseVarsSync(
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
bool training = true;
auto status = pserver_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(),
training);
pull_sparse_status.push_back(std::move(status));
for (auto& t : pull_sparse_status) {
t.wait();
......@@ -238,9 +205,13 @@ void FleetWrapper::PullSparseVarsSync(
}
}
// is_training: true for a training pull, false for inference; the pserver
// behaves differently in the two cases
void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id,
platform::Place place,
bool is_training,
std::vector<const LoDTensor*>* inputs,
std::vector<LoDTensor*>* outputs) {
std::vector<uint64_t> fea_keys;
......@@ -279,7 +250,8 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
}
auto* communicator = Communicator::GetInstance();
auto status = communicator->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size(),
is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
......
......@@ -84,19 +84,14 @@ class FleetWrapper {
int fea_dim,
const std::vector<std::string>& var_emb_names);
// Pull sparse variables from server in async mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values std::future
std::future<int32_t> PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_dim);
// Pull sparse variables from server in sync mode
// pull immediately to tensors
// is_training: true for a training pull, false for inference; the pserver
// behaves differently in the two cases
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
bool is_training,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<LoDTensor*>* outputs); // NOLINT
......
......@@ -768,8 +768,8 @@ std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) {
const uint64_t *keys, size_t num,
bool is_training) {
size_t request_call_num = _server_channels.size();
auto shard_sorted_kvs = std::make_shared<
......@@ -837,16 +837,27 @@ std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
uint32_t kv_request_count = 0;
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();
request_buffer.append((void *)&is_training, sizeof(bool));
std::vector<uint32_t> keys_counter;
keys_counter.reserve(sorted_kv_size);
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
++kv_request_count;
uint32_t keys = 1;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append((void *)&last_key, sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
++keys;
}
keys_counter.push_back(keys);
}
request_buffer.append((void *)keys_counter.data(),
sizeof(uint32_t) * keys_counter.size());
if (kv_request_count == 0) {
closure->Run();
} else {
......@@ -956,7 +967,7 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
}
auto status = pull_sparse((float **)save_vec.data(), table_id,
save_key.data(), save_key.size());
save_key.data(), save_key.size(), true);
status.wait();
// create lod tensor
......
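
The request attachment built in `BrpcPsClient::pull_sparse` above now starts with the `is_training` byte, followed by each distinct sorted key and then a per-key duplicate counter. A standalone sketch of that layout, with `std::string` standing in for brpc's `butil::IOBuf` and made-up keys:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Build one shard's attachment: |bool is_training|8B * num keys|4B * num counters|
std::string BuildPullSparseAttachment(std::vector<uint64_t> keys,
                                      bool is_training,
                                      uint32_t *kv_request_count) {
  std::sort(keys.begin(), keys.end());
  std::string buffer;  // stand-in for the brpc request attachment
  buffer.append(reinterpret_cast<const char *>(&is_training), sizeof(bool));

  std::vector<uint32_t> keys_counter;
  *kv_request_count = 0;
  for (size_t i = 0; i < keys.size();) {
    uint64_t last_key = keys[i];
    uint32_t count = 0;
    while (i < keys.size() && keys[i] == last_key) {  // merge duplicate keys
      ++count;
      ++i;
    }
    buffer.append(reinterpret_cast<const char *>(&last_key), sizeof(uint64_t));
    keys_counter.push_back(count);
    ++(*kv_request_count);
  }
  buffer.append(reinterpret_cast<const char *>(keys_counter.data()),
                sizeof(uint32_t) * keys_counter.size());
  return buffer;
}

int main() {
  uint32_t num = 0;
  auto buf = BuildPullSparseAttachment({42, 7, 42, 7, 11},
                                       /*is_training=*/true, &num);
  // 3 distinct keys -> 1 + 3*8 + 3*4 = 37 bytes
  std::cout << "distinct keys: " << num << ", bytes: " << buf.size() << "\n";
  return 0;
}
```
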
......@@ -148,7 +148,8 @@ class BrpcPsClient : public PSClient {
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num);
const uint64_t *keys, size_t num,
bool is_training);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -337,33 +338,39 @@ int32_t BrpcPsService::pull_sparse(Table *table,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_sparse");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer;
auto &req_io_buffer = cntl->request_attachment();
auto req_buffer_size = req_io_buffer.size();
if (req_buffer_size < 1) {
set_response_code(response, -1, "req attachment is empty");
return 0;
}
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
push_sparse_request_buffer.resize(0);
push_sparse_request_buffer.reserve(req_buffer_size);
const char *data = (const char *)cntl->request_attachment().fetch(
const_cast<char *>(push_sparse_request_buffer.data()), req_buffer_size);
/*
Attachment Content:
|---keysData---|
|---8*{num}B---|
*/
const uint64_t *keys = (const uint64_t *)data;
auto dim = table->value_accesor()->select_dim();
thread_local std::string req_buffer;
req_buffer.reserve(req_buffer_size);
const void *data = cntl->request_attachment().fetch(
const_cast<char *>(req_buffer.data()), req_buffer_size);
auto value = PullSparseValue(num, dim);
value.DeserializeFromBytes(const_cast<void *>(data));
std::vector<float> res_data;
res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_sparse(res_data.data(), keys, num);
res_data.resize(num * dim);
table->pull_sparse(res_data.data(), value);
cntl->response_attachment().append((char *)res_data.data(),
res_data.size() * sizeof(float));
return 0;
......
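
The mirror image on the service side: recover `is_training`, the keys, and the frequencies from the attachment, then size the response as `num * dim` floats. In this standalone sketch the table lookup is a trivial stand-in, `dim` and the key values are assumptions, and `memcpy` is used only to keep the example free of unaligned casts.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Fabricate a request attachment in the documented layout:
  // |---isTraining(1B)---|---8*{num}B keys---|---4*{num}B frequencies---|
  bool is_training = true;
  std::vector<uint64_t> keys = {7, 11, 42};
  std::vector<uint32_t> freqs = {2, 1, 2};
  std::string attachment;
  attachment.append(reinterpret_cast<const char *>(&is_training), sizeof(bool));
  attachment.append(reinterpret_cast<const char *>(keys.data()),
                    sizeof(uint64_t) * keys.size());
  attachment.append(reinterpret_cast<const char *>(freqs.data()),
                    sizeof(uint32_t) * freqs.size());

  // Service side: num comes from request.params(0), dim from the accessor's
  // select_dim(); both are assumed here.
  const uint32_t num = static_cast<uint32_t>(keys.size());
  const int dim = 8;

  const char *data = attachment.data();
  bool parsed_training = false;
  std::vector<uint64_t> parsed_keys(num);
  std::vector<uint32_t> parsed_freqs(num);
  std::memcpy(&parsed_training, data, sizeof(bool));
  std::memcpy(parsed_keys.data(), data + sizeof(bool), sizeof(uint64_t) * num);
  std::memcpy(parsed_freqs.data(),
              data + sizeof(bool) + sizeof(uint64_t) * num,
              sizeof(uint32_t) * num);

  // Response attachment: num * dim floats, one row per requested key.
  std::vector<float> res_data(num * dim, 0.0f);
  for (uint32_t i = 0; i < num; ++i) {
    // Stand-in for table->pull_sparse(): fill the row with the key value.
    std::fill_n(res_data.data() + i * dim, dim,
                static_cast<float>(parsed_keys[i]));
  }

  std::cout << "is_training=" << parsed_training
            << " first_freq=" << parsed_freqs[0]
            << " response_bytes=" << res_data.size() * sizeof(float) << "\n";
  return 0;
}
```
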
......@@ -320,9 +320,11 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
push_g_vec.push_back(tensor->data<float>() + i * dim);
}
bool training = true;
auto status = _worker_ptr->pull_sparse(
(float **)push_g_vec.data(), table_id, // NOLINT
sparse_push_keys.data(), sparse_push_keys.size());
sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
return;
}
......
......@@ -112,10 +112,11 @@ class PSClient {
// The keys and values buffers must not be reused before the returned future finishes.
// Keys requested from multiple threads are merged, then grouped and scattered to the servers.
// Once results return, the buffer is traversed and the values are filled in.
// is_training distinguishes training pulls from inference pulls; the server handles
// features and admission differently for each.
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys,
size_t num) = 0;
const uint64_t *keys, size_t num,
bool is_training) = 0;
virtual std::future<int32_t> print_table_stat(uint32_t table_id) = 0;
......
......@@ -103,13 +103,16 @@ class GraphTable : public SparseTable {
Node *find_node(uint64_t id);
virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {
virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) {
return 0;
}
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) {
return 0;
}
virtual void clear() {}
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string &param) { return 0; }
......@@ -140,5 +143,5 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
};
}
};
} // namespace distributed
}; // namespace paddle
......@@ -254,7 +254,6 @@ int32_t CommonSparseTable::initialize_value() {
}
auto accessor = _config.accessor();
std::vector<uint64_t> feasigns;
for (size_t x = 0; x < accessor.fea_dim(); ++x) {
......@@ -271,9 +270,14 @@ int32_t CommonSparseTable::initialize_value() {
std::vector<uint64_t> ids(bucket_feasigns);
std::copy(feasigns.begin() + buckets[x], feasigns.begin() + buckets[x + 1],
ids.begin());
std::vector<uint32_t> fres;
fres.resize(ids.size(), 1);
auto pull_value = PullSparseValue(ids, fres, param_dim_);
std::vector<float> pulls;
pulls.resize(bucket_feasigns * param_dim_);
pull_sparse(pulls.data(), ids.data(), bucket_feasigns);
pull_sparse(pulls.data(), pull_value);
}
return 0;
......@@ -399,32 +403,36 @@ int32_t CommonSparseTable::pour() {
return 0;
}
int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
size_t num) {
int32_t CommonSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
rwlock_->RDLock();
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
for (int x = 0; x < num; ++x) {
auto y = keys[x] % task_pool_size_;
offset_bucket[y].push_back(x);
}
std::vector<std::future<int>> tasks(task_pool_size_);
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &offset_bucket, &pull_values]() -> int {
[this, shard_id, shard_num, &pull_value, &pull_values]() -> int {
auto& block = shard_values_[shard_id];
auto& offsets = offset_bucket[shard_id];
for (int i = 0; i < offsets.size(); ++i) {
auto offset = offsets[i];
auto id = keys[offset];
auto* value = block->Init(id);
std::copy_n(value + param_offset_, param_dim_,
pull_values + param_dim_ * offset);
std::vector<int> offsets;
pull_value.Fission(shard_id, shard_num, &offsets);
if (pull_value.is_training_) {
for (auto& offset : offsets) {
auto feasign = pull_value.feasigns_[offset];
auto frequency = pull_value.frequencies_[offset];
auto* value = block->Init(feasign, true, frequency);
std::copy_n(value + param_offset_, param_dim_,
pull_values + param_dim_ * offset);
}
} else {
for (auto& offset : offsets) {
auto feasign = pull_value.feasigns_[offset];
auto* value = block->Init(feasign, false);
std::copy_n(value + param_offset_, param_dim_,
pull_values + param_dim_ * offset);
}
}
return 0;
......
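
A toy, single-threaded sketch of the dispatch above: offsets are split per shard by `feasign % shard_num` (what `Fission` does), the training path accumulates the pulled frequency into the entry, and the inference path skips that update. The `Shard`/`Entry` types and the `0.1f` initializer are made up; in `CommonSparseTable` each shard runs on its own thread pool.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Entry {
  int count = 0;
  std::vector<float> data;
};

// Toy stand-in for one ValueBlock shard; the value layout is made up.
struct Shard {
  explicit Shard(int d) : dim(d) {}

  // Training pull: create the entry if missing and accumulate its frequency.
  const float *InitTrain(uint64_t key, uint32_t frequency) {
    auto &e = values[key];
    if (e.data.empty()) e.data.assign(dim, 0.1f);  // fake initializer
    e.count += static_cast<int>(frequency);
    return e.data.data();
  }

  // Inference pull: create if missing, but skip the frequency/entry update.
  const float *InitInfer(uint64_t key) {
    auto &e = values[key];
    if (e.data.empty()) e.data.assign(dim, 0.1f);
    return e.data.data();
  }

  std::unordered_map<uint64_t, Entry> values;
  int dim;
};

int main() {
  const int shard_num = 2, dim = 4;
  const bool is_training = true;
  std::vector<uint64_t> feasigns = {3, 8, 5, 8};
  std::vector<uint32_t> frequencies = {1, 2, 1, 1};

  std::vector<Shard> shards(shard_num, Shard(dim));
  std::vector<float> pull_values(feasigns.size() * dim);

  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
    // "Fission": each shard only handles offsets whose key hashes to it.
    for (size_t offset = 0; offset < feasigns.size(); ++offset) {
      if (feasigns[offset] % shard_num != static_cast<uint64_t>(shard_id)) continue;
      const float *value = is_training
          ? shards[shard_id].InitTrain(feasigns[offset], frequencies[offset])
          : shards[shard_id].InitInfer(feasigns[offset]);
      std::copy_n(value, dim, pull_values.data() + offset * dim);
    }
  }
  std::cout << "pulled " << pull_values.size() << " floats\n";
  return 0;
}
```
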
......@@ -61,8 +61,7 @@ class CommonSparseTable : public SparseTable {
int32_t save(const std::string& path, const std::string& param);
virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* pull_values, const uint64_t* keys,
size_t num);
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num);
......
......@@ -98,8 +98,8 @@ class DenseTable : public Table {
virtual ~DenseTable() {}
virtual void *get_shard(size_t shard_idx) { return 0; }
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
......@@ -123,8 +123,8 @@ class BarrierTable : public Table {
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
......
......@@ -155,7 +155,8 @@ class ValueBlock {
}
// pull
float *Init(const uint64_t &id, const bool with_update = true) {
float *Init(const uint64_t &id, const bool with_update = true,
const int counter = 1) {
if (!Has(id)) {
values_[id] = std::make_shared<VALUE>(value_length_);
}
......@@ -163,16 +164,16 @@ class ValueBlock {
auto &value = values_.at(id);
if (with_update) {
AttrUpdate(value);
AttrUpdate(value, counter);
}
return value->data_.data();
}
void AttrUpdate(std::shared_ptr<VALUE> value) {
void AttrUpdate(std::shared_ptr<VALUE> value, const int counter) {
// update state
value->unseen_days_ = 0;
++value->count_;
value->count_ += counter;
if (!value->is_entry_) {
value->is_entry_ = entry_func_(value);
......
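
What the new `counter` parameter changes in `ValueBlock::AttrUpdate`, as a standalone sketch: the seen-count grows by the batch frequency rather than by one, and that count feeds the admission check. The threshold rule below is an assumption; the real `entry_func_` is configured per table.

```cpp
#include <functional>
#include <iostream>
#include <memory>

struct Value {
  int unseen_days_ = 0;
  int count_ = 0;
  bool is_entry_ = false;
};

int main() {
  // Hypothetical admission rule: a feature becomes an "entry" after it has
  // been seen at least 5 times (the real entry_func_ is table-configurable).
  std::function<bool(const std::shared_ptr<Value> &)> entry_func =
      [](const std::shared_ptr<Value> &v) { return v->count_ >= 5; };

  auto value = std::make_shared<Value>();

  auto attr_update = [&](const std::shared_ptr<Value> &v, int counter) {
    v->unseen_days_ = 0;   // reset staleness on every training pull
    v->count_ += counter;  // accumulate the batch frequency, not just +1
    if (!v->is_entry_) v->is_entry_ = entry_func(v);
  };

  attr_update(value, /*counter=*/3);  // key appeared 3 times in this batch
  attr_update(value, /*counter=*/2);  // and twice in the next one
  std::cout << "count=" << value->count_
            << " is_entry=" << std::boolalpha << value->is_entry_ << "\n";
  return 0;
}
```
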
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace distributed {
struct PullSparseValue {
explicit PullSparseValue(int numel, int dim)
: numel_(numel),
dim_(dim),
is_training_(true),
feasigns_(nullptr),
frequencies_(nullptr) {}
explicit PullSparseValue(std::vector<uint64_t> feasigns,
std::vector<uint32_t> frequencies, int dim) {
numel_ = feasigns.size();
dim_ = dim;
is_training_ = true;
feasigns_ = feasigns.data();
frequencies_ = frequencies.data();
}
void DeserializeFromBytes(void* bytes) {
/*
|---isTraining--------------|
|---8*{num}B(keysData)------|
|---4*{num}B(Frequencies)---|
*/
auto* begin = reinterpret_cast<char*>(bytes);
is_training_ = reinterpret_cast<bool*>(begin)[0];
feasigns_ = reinterpret_cast<uint64_t*>(begin + sizeof(bool));
frequencies_ = reinterpret_cast<uint32_t*>(begin + sizeof(bool) +
sizeof(uint64_t) * numel_);
}
void Fission(const int shard_id, const int shard_num,
std::vector<int>* offset_shard) const {
offset_shard->reserve(numel_ / shard_num + 1);
for (int x = 0; x < numel_; ++x) {
if (feasigns_[x] % shard_num == shard_id) {
offset_shard->push_back(x);
}
}
}
int numel_;
int dim_;
bool is_training_;
uint64_t* feasigns_;
uint32_t* frequencies_;
};
} // namespace distributed
} // namespace paddle
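
A usage sketch for the struct above, assuming this header is on the include path. It follows the pattern `SparseGeoTable` uses below: construct with `numel`/`dim`, point `feasigns_`/`frequencies_` at caller-owned vectors (which must outlive the value, since only raw pointers are stored), then split offsets per shard with `Fission`.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

#include "paddle/fluid/distributed/table/depends/sparse_utils.h"

int main() {
  using paddle::distributed::PullSparseValue;

  std::vector<uint64_t> ids = {3, 8, 5, 8};
  std::vector<uint32_t> frequencies(ids.size(), 1);
  const int dim = 8;

  PullSparseValue pull_value(static_cast<int>(ids.size()), dim);
  pull_value.is_training_ = false;            // e.g. an inference pull
  pull_value.feasigns_ = ids.data();          // caller keeps ownership
  pull_value.frequencies_ = frequencies.data();

  const int shard_num = 2;
  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
    std::vector<int> offsets;
    pull_value.Fission(shard_id, shard_num, &offsets);
    std::cout << "shard " << shard_id << " handles " << offsets.size()
              << " of " << pull_value.numel_ << " keys\n";
  }
  return 0;
}
```
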
......@@ -22,8 +22,17 @@ int32_t SparseGeoTable::pull_geo_param(const uint32_t trainer_id,
std::vector<uint64_t>* ids) {
geo_recorder->GetAndClear(trainer_id, ids);
auto dim = _config.common().dims()[0];
std::vector<uint32_t> frequencies;
frequencies.resize(ids->size(), 1);
auto pull_value = PullSparseValue(ids->size(), dim);
pull_value.is_training_ = true;
pull_value.feasigns_ = ids->data();
pull_value.frequencies_ = frequencies.data();
values->resize(ids->size() * dim);
CommonSparseTable::pull_sparse(values->data(), ids->data(), ids->size());
CommonSparseTable::pull_sparse(values->data(), pull_value);
return 0;
}
......
......@@ -21,6 +21,7 @@
#include <string>
#include <utility>
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/graph_node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -47,8 +48,8 @@ class Table {
return 0;
}
virtual int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) = 0;
virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) = 0;
virtual int32_t push_sparse(const uint64_t *keys, const float *values,
size_t num) = 0;
virtual int32_t push_sparse_param(const uint64_t *keys, const float *values,
......
......@@ -52,8 +52,8 @@ class TensorTable : public Table {
int32_t push_dense(const float *values, size_t num) override { return 0; }
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
......@@ -102,8 +102,8 @@ class DenseTensorTable : public TensorTable {
DenseTensorTable() {}
virtual ~DenseTensorTable() {}
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
......@@ -158,8 +158,8 @@ class GlobalStepTable : public DenseTensorTable {
GlobalStepTable() {}
virtual ~GlobalStepTable() {}
int32_t pull_sparse(float *values, const uint64_t *keys,
size_t num) override {
int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) override {
return 0;
}
int32_t push_sparse(const uint64_t *keys, const float *values,
......
......@@ -212,8 +212,8 @@ void RunBrpcPushSparse() {
/*-----------------------Test Server Init----------------------------------*/
LOG(INFO) << "Run pull_sparse_param";
auto pull_status = worker_ptr_->pull_sparse(fea_value_ptr.data(), 0,
fea_keys.data(), fea_keys.size());
auto pull_status = worker_ptr_->pull_sparse(
fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
fea_values.data()[idx] *= 2.0;
......@@ -241,7 +241,7 @@ void RunBrpcPushSparse() {
push_status.wait();
auto pull_param_status = worker_ptr_->pull_sparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_param_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
......@@ -275,7 +275,7 @@ void RunBrpcPushSparse() {
push_grad_status.wait();
auto pull_update_status = worker_ptr_->pull_sparse(
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size());
fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
pull_update_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) {
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/common_dense_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#include "paddle/fluid/distributed/table/table.h"
......@@ -53,14 +54,18 @@ TEST(SparseGeoTable, SSUM) {
// test push_sparse_param, and create params
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<uint32_t> init_fres = {1, 1, 1, 1, 1};
std::vector<float> init_values;
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
init_values.push_back(0.0);
}
table->push_sparse_param(init_keys.data(), init_values.data(),
init_keys.size());
std::vector<float> pull_values(init_values.size());
table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size());
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(pull_values.data(), value);
for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5);
}
......
......@@ -55,9 +55,14 @@ TEST(CommonSparseTable, SGD) {
// pull parameters for create and check
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<uint32_t> init_fres = {1, 1, 1, 1, 1};
std::vector<float> init_values;
init_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(init_values.data(), init_keys.data(), init_keys.size());
std::vector<float> pull_values(init_values.size());
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(init_values.data(), value);
// for check
std::vector<float> total_gradients;
......@@ -100,7 +105,8 @@ TEST(CommonSparseTable, SGD) {
std::vector<float> pull_values;
pull_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size());
table->pull_sparse(init_values.data(), value);
for (size_t i = 0; i < init_values.size(); ++i) {
auto update_val = init_values[i] - 1.0 * total_gradients[i];
ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-5);
......@@ -148,9 +154,13 @@ TEST(CommonSparseTable, Adam) {
// pull parameters for create and check
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<uint32_t> init_fres = {1, 1, 1, 1, 1};
std::vector<float> init_values;
init_values.resize(init_keys.size() * emb_dim);
table->pull_sparse(init_values.data(), init_keys.data(), init_keys.size());
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(init_values.data(), value);
// push gradient
std::vector<std::vector<uint64_t>> trainer_keys;
......
......@@ -119,6 +119,11 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"Output data type")
.SetDefault(framework::proto::VarType::FP32);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training.")
.SetDefault(false);
AddComment(R"DOC(
Lookup Table Prefetch Operator.
This operator is used to perform lookup on parameter W,
......
......@@ -30,6 +30,7 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
auto padding_idx = context.Attr<int64_t>("padding_idx");
auto table_id = context.Attr<int>("table_id");
bool is_test = context.Attr<bool>("is_test");
auto embedding_name = context.InputNames("W").front();
int64_t emb_dim = 0;
......@@ -55,7 +56,8 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
if (platform::is_cpu_place(context.GetPlace())) {
fleet->PullSparseToTensorSync(static_cast<uint64_t>(table_id), emb_dim,
static_cast<uint64_t>(padding_idx),
context.GetPlace(), &inputs, &outputs);
context.GetPlace(), !is_test, &inputs,
&outputs);
} else {
auto inputs_variable = context.MultiInputVar("Ids");
auto outputs_variable = context.MultiOutputVar("Outputs");
......@@ -93,7 +95,8 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
// use fleet->PullSparse
fleet->PullSparseToTensorSync(static_cast<uint64_t>(table_id), emb_dim,
static_cast<uint64_t>(padding_idx),
cpu_place, &tmp_input_vec, &tmp_output_vec);
cpu_place, !is_test, &tmp_input_vec,
&tmp_output_vec);
// cp temp to origin
for (size_t idx = 0; idx < output_var_size; ++idx) {
......
......@@ -16,6 +16,7 @@
import numpy as np
import os
import paddle
import warnings
class DistributedInfer:
......@@ -104,8 +105,6 @@ class DistributedInfer:
vars=need_load_vars)
def get_dist_infer_program(self):
import paddle.distributed.fleet as fleet
varname2tables = self._get_sparse_table_map()
convert_program = self._convert_program(self.origin_main_program,
varname2tables)
......@@ -185,6 +184,7 @@ class DistributedInfer:
"is_distributed": is_distributed,
"padding_idx": padding_idx,
"table_id": table_id,
"is_test": True,
"lookup_table_version": op_type
})
else:
......@@ -193,6 +193,9 @@ class DistributedInfer:
)
pull_sparse_ops = _get_pull_sparse_ops(program)
warnings.warn(
"lookup_table will be forced to test mode when use DistributedInfer"
)
_pull_sparse_fuse(program, pull_sparse_ops)
return program
......