Commit 6ebf5b97 authored by Y yangqingyou

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add-crypto-api

#### Required (multiple choices, two at most)
- **PR type is ( ):**
A. New features------------------------ D. Performance optimization
B. Bug fixes--------------------------- E. Breaking changes
C. Function optimization--------------- F. Others
- **PR changes is ( ):**
A. OPs(operators)---------------------- C. Docs
B. APIs-------------------------------- D. Others
- **Use one sentence to describe what this PR does.**
-----------------------
#### Optional (if none, please delete it)
- **Describe what this PR does in detail. If this PR fixes an issue, please give the issue id.**
<!-- DESCRIBE THE BUG OR REQUIREMENT HERE. e.g. #2020 (format: #IssueNumber) -->
- **If you modified docs, please make sure that both the Chinese and English docs were modified, and provide a preview screenshot. (required for doc changes)**
<!-- ADD SCREENSHOT HERE IF APPLICABLE. -->
- **Please write down any other information you want to tell reviewers.**
<!-- Demo: PR types: Bug fixes, Function optimization -->
<!-- One of [ New features | Bug fixes | Function optimization | Performance optimization | Breaking changes | Others ] -->
PR types:
<!-- Demo: PR changes: OPs -->
<!-- One of [ OPs | APIs | Docs | Others ] -->
PR changes:
<!-- Describe what this PR does -->
Describe:
......@@ -41,44 +41,44 @@ namespace paddle {
namespace framework {
void RecordCandidateList::ReSize(size_t length) {
_mutex.lock();
_capacity = length;
CHECK(_capacity > 0); // NOLINT
_candidate_list.clear();
_candidate_list.resize(_capacity);
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
mutex_.lock();
capacity_ = length;
CHECK(capacity_ > 0); // NOLINT
candidate_list_.clear();
candidate_list_.resize(capacity_);
full_ = false;
cur_size_ = 0;
total_size_ = 0;
mutex_.unlock();
}
void RecordCandidateList::ReInit() {
_mutex.lock();
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
mutex_.lock();
full_ = false;
cur_size_ = 0;
total_size_ = 0;
mutex_.unlock();
}
void RecordCandidateList::AddAndGet(const Record& record,
RecordCandidate* result) {
_mutex.lock();
mutex_.lock();
size_t index = 0;
++_total_size;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!_full) {
_candidate_list[_cur_size++] = record;
_full = (_cur_size == _capacity);
if (!full_) {
candidate_list_[cur_size_++] = record;
full_ = (cur_size_ == capacity_);
} else {
CHECK(_cur_size == _capacity);
index = fleet_ptr->LocalRandomEngine()() % _total_size;
if (index < _capacity) {
_candidate_list[index] = record;
CHECK(cur_size_ == capacity_);
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_[index] = record;
}
}
index = fleet_ptr->LocalRandomEngine()() % _cur_size;
*result = _candidate_list[index];
_mutex.unlock();
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
*result = candidate_list_[index];
mutex_.unlock();
}
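The sampling above is a textbook reservoir sample: the first `capacity_` records fill the pool, after which the n-th record replaces a uniformly chosen slot with probability `capacity_ / n`, so every record in the stream is retained with equal probability. A minimal standalone sketch of the same scheme (hypothetical `Reservoir` class, not part of this commit):

```cpp
#include <cstdlib>
#include <vector>

// Uniform sample of up to `capacity` items from a stream (sketch).
template <typename T>
class Reservoir {
 public:
  explicit Reservoir(size_t capacity) : capacity_(capacity) {}
  void Add(const T& item) {
    ++seen_;
    if (pool_.size() < capacity_) {
      pool_.push_back(item);  // still filling the pool
    } else {
      size_t j = std::rand() % seen_;      // uniform in [0, seen_)
      if (j < capacity_) pool_[j] = item;  // keep with prob capacity_/seen_
    }
  }
  const std::vector<T>& Pool() const { return pool_; }

 private:
  size_t capacity_;
  size_t seen_ = 0;
  std::vector<T> pool_;
};
```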
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
......@@ -1452,7 +1452,11 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector<PvInstance>& pv_vec) {
int PaddleBoxDataFeed::GetCurrentPhase() {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
return box_ptr->PassFlag(); // join: 1, update: 0
if (box_ptr->Mode() == 1) { // For AucRunner
return 1;
} else {
return box_ptr->Phase();
}
#else
LOG(WARNING) << "It should be complied with BOX_PS...";
return current_phase_;
......
......@@ -27,6 +27,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -34,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
......@@ -484,13 +486,25 @@ paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
struct RecordCandidate {
std::string ins_id_;
std::unordered_multimap<uint16_t, FeatureKey> feas;
std::unordered_multimap<uint16_t, FeatureKey> feas_;
  size_t shadow_index_ = -1;  // Optimization for reservoir sampling
RecordCandidate() {}
RecordCandidate(const Record& rec,
const std::unordered_set<uint16_t>& slot_index_to_replace) {
for (const auto& fea : rec.uint64_feasigns_) {
if (slot_index_to_replace.find(fea.slot()) !=
slot_index_to_replace.end()) {
feas_.insert({fea.slot(), fea.sign()});
}
}
}
RecordCandidate& operator=(const Record& rec) {
feas.clear();
feas_.clear();
ins_id_ = rec.ins_id_;
for (auto& fea : rec.uint64_feasigns_) {
feas.insert({fea.slot(), fea.sign()});
feas_.insert({fea.slot(), fea.sign()});
}
return *this;
}
......@@ -499,22 +513,67 @@ struct RecordCandidate {
class RecordCandidateList {
public:
RecordCandidateList() = default;
RecordCandidateList(const RecordCandidateList&) = delete;
RecordCandidateList& operator=(const RecordCandidateList&) = delete;
RecordCandidateList(const RecordCandidateList&) {}
size_t Size() { return cur_size_; }
void ReSize(size_t length);
void ReInit();
void ReInitPass() {
for (size_t i = 0; i < cur_size_; ++i) {
if (candidate_list_[i].shadow_index_ != i) {
candidate_list_[i].ins_id_ =
candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
candidate_list_[i].feas_.swap(
candidate_list_[candidate_list_[i].shadow_index_].feas_);
candidate_list_[i].shadow_index_ = i;
}
}
candidate_list_.resize(cur_size_);
}
void AddAndGet(const Record& record, RecordCandidate* result);
void AddAndGet(const Record& record, size_t& index_result) { // NOLINT
// std::unique_lock<std::mutex> lock(mutex_);
size_t index = 0;
++total_size_;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!full_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_.back().shadow_index_ = cur_size_;
++cur_size_;
full_ = (cur_size_ == capacity_);
} else {
index = fleet_ptr->LocalRandomEngine()() % total_size_;
if (index < capacity_) {
candidate_list_.emplace_back(record, slot_index_to_replace_);
candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
}
}
index = fleet_ptr->LocalRandomEngine()() % cur_size_;
index_result = candidate_list_[index].shadow_index_;
}
const RecordCandidate& Get(size_t index) const {
PADDLE_ENFORCE_LT(
index, candidate_list_.size(),
platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
"elements in candidate_list[%lu].",
index, candidate_list_.size()));
return candidate_list_[index];
}
void SetSlotIndexToReplace(
const std::unordered_set<uint16_t>& slot_index_to_replace) {
slot_index_to_replace_ = slot_index_to_replace;
}
private:
size_t _capacity = 0;
std::mutex _mutex;
bool _full = false;
size_t _cur_size = 0;
size_t _total_size = 0;
std::vector<RecordCandidate> _candidate_list;
size_t capacity_ = 0;
std::mutex mutex_;
bool full_ = false;
size_t cur_size_ = 0;
size_t total_size_ = 0;
std::vector<RecordCandidate> candidate_list_;
std::unordered_set<uint16_t> slot_index_to_replace_;
};
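The copy-constructor change and the new `AddAndGet(const Record&, size_t&)` overload serve the single-writer-per-pool AucRunner path: instead of overwriting a slot that a concurrent `Get()` may still read, the replacement candidate is appended to the vector and the displaced slot's `shadow_index_` redirects readers to it; `ReInitPass()` then folds every shadow back into its home slot and truncates before the next pass. A small sketch of the invariant compaction restores (assumes the `RecordCandidate` defined above):

```cpp
#include <cassert>
#include <vector>

// After ReInitPass(), each surviving slot is its own shadow again and the
// appended shadow entries past cur_size_ have been dropped.
void CheckCompacted(const std::vector<RecordCandidate>& pool) {
  for (size_t i = 0; i < pool.size(); ++i) {
    assert(pool[i].shadow_index_ == i);  // slot redirects to itself
  }
}
```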
template <class AR>
......
......@@ -1141,13 +1141,15 @@ void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
void MultiSlotDataset::GetRandomData(
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
slots_shuffle_rclist_.ReInit();
for (const auto& rec : slots_shuffle_original_data_) {
const auto& slots_shuffle_original_data = GetSlotsOriginalData();
for (const auto& rec : slots_shuffle_original_data) {
RecordCandidate rand_rec;
Record new_rec = rec;
slots_shuffle_rclist_.AddAndGet(rec, &rand_rec);
......@@ -1161,7 +1163,7 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas.equal_range(slot);
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
......@@ -1173,9 +1175,9 @@ void MultiSlotDataset::GetRandomData(const std::set<uint16_t>& slots_to_replace,
<< " repush feasign num: " << debug_push_cnt;
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
void MultiSlotDataset::PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slots) { // NOLINT
int out_channel_size = 0;
if (cur_channel_ == 0) {
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
......@@ -1189,20 +1191,14 @@ void MultiSlotDataset::SlotsShuffle(
VLOG(2) << "DatasetImpl<T>::SlotsShuffle() begin with input channel size: "
<< input_channel_->Size()
<< " output channel size: " << out_channel_size;
if (!slots_shuffle_fea_eval_) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end,"
"fea eval mode off, need to set on for slots shuffle";
return;
}
if ((!input_channel_ || input_channel_->Size() == 0) &&
slots_shuffle_original_data_.size() == 0 && out_channel_size == 0) {
VLOG(3) << "DatasetImpl<T>::SlotsShuffle() end, no data to slots shuffle";
return;
}
platform::Timer timeline;
timeline.Start();
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::set<uint16_t> index_slots;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
std::string cur_slot = multi_slot_desc.slots(i).name();
if (slots_to_replace.find(cur_slot) != slots_to_replace.end()) {
......@@ -1287,6 +1283,19 @@ void MultiSlotDataset::SlotsShuffle(
}
CHECK(input_channel_->Size() == 0)
<< "input channel should be empty before slots shuffle";
}
// slots shuffle to input_channel_ with needed-shuffle slots
void MultiSlotDataset::SlotsShuffle(
const std::set<std::string>& slots_to_replace) {
PADDLE_ENFORCE_EQ(slots_shuffle_fea_eval_, true,
platform::errors::PreconditionNotMet(
"fea eval mode off, need to set on for slots shuffle"));
platform::Timer timeline;
timeline.Start();
std::unordered_set<uint16_t> index_slots;
PreprocessChannel(slots_to_replace, index_slots);
std::vector<Record> random_data;
random_data.clear();
// get slots shuffled random_data
......
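For reference, the per-record replacement step inside `GetRandomData` can be written as a pure function: drop the record's own feasigns for the evaluated slots, then splice in the sampled candidate's feasigns for those same slots. A hedged sketch (hypothetical `ReplaceSlots` helper; `Record` and `RecordCandidate` are the types from data_feed.h above):

```cpp
#include <algorithm>
#include <cstdint>
#include <unordered_set>

Record ReplaceSlots(const Record& rec, const RecordCandidate& rand_rec,
                    const std::unordered_set<uint16_t>& slots) {
  Record out = rec;
  // Drop the record's own feasigns for the evaluated slots ...
  out.uint64_feasigns_.erase(
      std::remove_if(out.uint64_feasigns_.begin(), out.uint64_feasigns_.end(),
                     [&](auto& fea) { return slots.count(fea.slot()) > 0; }),
      out.uint64_feasigns_.end());
  // ... then splice in the sampled candidate's feasigns for those slots.
  for (uint16_t slot : slots) {
    auto range = rand_rec.feas_.equal_range(slot);
    for (auto it = range.first; it != range.second; ++it) {
      out.uint64_feasigns_.push_back({it->second, it->first});
    }
  }
  return out;
}
```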
......@@ -67,6 +67,7 @@ class Dataset {
virtual void SetParseContent(bool parse_content) = 0;
virtual void SetParseLogKey(bool parse_logkey) = 0;
virtual void SetEnablePvMerge(bool enable_pv_merge) = 0;
virtual bool EnablePvMerge() = 0;
virtual void SetMergeBySid(bool is_merge) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
......@@ -108,10 +109,7 @@ class Dataset {
virtual void LocalShuffle() = 0;
// global shuffle data
virtual void GlobalShuffle(int thread_num = -1) = 0;
// for slots shuffle
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) = 0;
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) = 0;
// create readers
virtual void CreateReaders() = 0;
// destroy readers
......@@ -183,6 +181,9 @@ class DatasetImpl : public Dataset {
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual Channel<T> GetInputChannel() { return input_channel_; }
virtual void SetInputChannel(const Channel<T>& input_channel) {
input_channel_ = input_channel;
}
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
......@@ -192,6 +193,7 @@ class DatasetImpl : public Dataset {
return data_feed_desc_;
}
virtual int GetChannelNum() { return channel_num_; }
virtual bool EnablePvMerge() { return enable_pv_merge_; }
virtual std::vector<paddle::framework::DataFeed*> GetReaders();
virtual void CreateChannel();
virtual void RegisterClientToClientMsgHandler();
......@@ -202,8 +204,9 @@ class DatasetImpl : public Dataset {
virtual void LocalShuffle();
virtual void GlobalShuffle(int thread_num = -1);
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace) {}
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {}
virtual const std::vector<T>& GetSlotsOriginalData() {
return slots_shuffle_original_data_;
}
virtual void CreateReaders();
virtual void DestroyReaders();
virtual int64_t GetMemoryDataSize();
......@@ -293,9 +296,13 @@ class MultiSlotDataset : public DatasetImpl<Record> {
}
std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
}
virtual void PreprocessChannel(
const std::set<std::string>& slots_to_replace,
std::unordered_set<uint16_t>& index_slot); // NOLINT
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual void GetRandomData(
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
virtual ~MultiSlotDataset() {}
};
......
......@@ -28,6 +28,8 @@ std::shared_ptr<BoxWrapper> BoxWrapper::s_instance_ = nullptr;
cudaStream_t BoxWrapper::stream_list_[8];
std::shared_ptr<boxps::BoxPSBase> BoxWrapper::boxps_ptr_ = nullptr;
AfsManager* BoxWrapper::afs_manager = nullptr;
int BoxWrapper::embedx_dim_ = 8;
int BoxWrapper::expand_embed_dim_ = 0;
void BasicAucCalculator::compute() {
double* table[2] = {&_table[0][0], &_table[1][0]};
......@@ -57,6 +59,94 @@ void BasicAucCalculator::compute() {
_size = fp + tp;
}
void BoxWrapper::CheckEmbedSizeIsValid(int embedx_dim, int expand_embed_dim) {
PADDLE_ENFORCE_EQ(
embedx_dim_, embedx_dim,
platform::errors::InvalidArgument("SetInstance(): invalid embedx_dim. "
"When embedx_dim = %d, but got %d.",
embedx_dim_, embedx_dim));
PADDLE_ENFORCE_EQ(expand_embed_dim_, expand_embed_dim,
platform::errors::InvalidArgument(
"SetInstance(): invalid expand_embed_dim. When "
"expand_embed_dim = %d, but got %d.",
expand_embed_dim_, expand_embed_dim));
}
void BoxWrapper::PullSparse(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int expand_embed_dim) {
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define PULLSPARSE_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PullSparseCase<EmbedxDim, ExpandDim>(place, keys, values, slot_lengths, \
hidden_size, expand_embed_dim); \
} break
CheckEmbedSizeIsValid(hidden_size - 3, expand_embed_dim);
switch (hidden_size - 3) {
EMBEDX_CASE(8, PULLSPARSE_CASE(0); PULLSPARSE_CASE(8);
PULLSPARSE_CASE(64););
EMBEDX_CASE(16, PULLSPARSE_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
}
#undef PULLSPARSE_CASE
#undef EMBEDX_CASE
}
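The `EMBEDX_CASE`/`PULLSPARSE_CASE` macros exist to turn the runtime pair `(hidden_size - 3, expand_embed_dim)` into compile-time template arguments, so each supported layout gets a fully specialized `PullSparseCase` (and a fixed-size GPU value struct) instead of a dynamically sized one. A stripped-down illustration of the pattern (hypothetical `RunCase`/`Dispatch`, not the commit's code):

```cpp
#include <cstddef>
#include <stdexcept>

template <size_t EmbedxDim, size_t ExpandDim>
void RunCase() { /* work specialized for one fixed value layout */ }

void Dispatch(int embedx_dim, int expand_dim) {
#define EMBEDX_CASE(i, ...)                                    \
  case i: {                                                    \
    constexpr size_t EmbedxDim = i;                            \
    switch (expand_dim) {                                      \
      __VA_ARGS__                                              \
      default:                                                 \
        throw std::invalid_argument("unsupported expand dim"); \
    }                                                          \
  } break
#define EXPAND_CASE(i) \
  case i:              \
    RunCase<EmbedxDim, i>(); break
  switch (embedx_dim) {
    EMBEDX_CASE(8, EXPAND_CASE(0); EXPAND_CASE(8); EXPAND_CASE(64););
    EMBEDX_CASE(16, EXPAND_CASE(0););
    default:
      throw std::invalid_argument("unsupported embedx dim");
  }
#undef EXPAND_CASE
#undef EMBEDX_CASE
}
```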
void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size,
const int expand_embed_dim,
const int batch_size) {
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define PUSHSPARSE_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PushSparseGradCase<EmbedxDim, ExpandDim>(place, keys, grad_values, \
slot_lengths, hidden_size, \
expand_embed_dim, batch_size); \
} break
CheckEmbedSizeIsValid(hidden_size - 3, expand_embed_dim);
switch (hidden_size - 3) {
EMBEDX_CASE(8, PUSHSPARSE_CASE(0); PUSHSPARSE_CASE(8);
PUSHSPARSE_CASE(64););
EMBEDX_CASE(16, PUSHSPARSE_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
}
#undef PUSHSPARSE_CASE
#undef EMBEDX_CASE
}
void BasicAucCalculator::calculate_bucket_error() {
double last_ctr = -1;
double impression_sum = 0;
......@@ -128,133 +218,112 @@ void BoxWrapper::EndPass(bool need_save_delta) const {
ret, 0, platform::errors::PreconditionNotMet("EndPass failed in BoxPS."));
}
void BoxWrapper::PullSparse(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size) {
VLOG(3) << "Begin PullSparse";
platform::Timer all_timer;
platform::Timer pull_boxps_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
auto buf =
memory::AllocShared(place, total_length * sizeof(boxps::FeatureValueGpu));
boxps::FeatureValueGpu* total_values_gpu =
reinterpret_cast<boxps::FeatureValueGpu*>(buf->ptr());
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in PaddleBox now."));
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
LoDTensor& total_keys_tensor = keys_tensor[device_id];
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
// construct slot_level lod info
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
auto buf_length =
memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
static_cast<int>(slot_lengths.size()),
static_cast<int>(total_length));
VLOG(3) << "Begin call PullSparseGPU in BoxPS";
pull_boxps_timer.Start();
int ret =
boxps_ptr_->PullSparseGPU(total_keys, total_values_gpu,
static_cast<int>(total_length), device_id);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PullSparseGPU failed in BoxPS."));
pull_boxps_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
this->CopyForPull(place, gpu_keys, values, total_values_gpu, gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
total_length);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddleBox: PullSparse Only Support CPUPlace or CUDAPlace Now."));
void BoxWrapper::GetRandomReplace(const std::vector<Record>& pass_data) {
VLOG(0) << "Begin GetRandomReplace";
size_t ins_num = pass_data.size();
replace_idx_.resize(ins_num);
for (auto& cand_list : random_ins_pool_list) {
cand_list.ReInitPass();
}
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, ins_num]() {
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomReplace begin for thread[" << tid
<< "], and process [" << start << ", " << end
<< "), total ins: " << ins_num;
auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
random_pool.AddAndGet(ins, replace_idx_[i]);
}
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
all_timer.Pause();
VLOG(1) << "PullSparse total costs: " << all_timer.ElapsedSec()
<< " s, of which BoxPS costs: " << pull_boxps_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PullSparse";
pass_done_semi_->Put(1);
VLOG(0) << "End GetRandomReplace";
}
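`GetRandomReplace` splits the pass over `auc_runner_thread_num_` workers using the usual contiguous partition `[tid*n/T, (tid+1)*n/T)`, which covers `[0, n)` exactly and balances shard sizes to within one element. A generic sketch of that helper (assumed name `ParallelFor`, not from this commit):

```cpp
#include <cstddef>
#include <thread>
#include <vector>

// Run fn(start, end) over `nthreads` contiguous shards of [0, n).
template <typename Fn>
void ParallelFor(size_t n, int nthreads, Fn fn) {
  std::vector<std::thread> workers;
  for (int tid = 0; tid < nthreads; ++tid) {
    size_t start = tid * n / nthreads;      // shard boundaries as in
    size_t end = (tid + 1) * n / nthreads;  // GetRandomReplace above
    workers.emplace_back([=] { fn(start, end); });
  }
  for (auto& t : workers) t.join();
}
```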
void BoxWrapper::PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int batch_size) {
VLOG(3) << "Begin PushSparseGrad";
platform::Timer all_timer;
platform::Timer push_boxps_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
auto buf = memory::AllocShared(
place, total_length * sizeof(boxps::FeaturePushValueGpu));
boxps::FeaturePushValueGpu* total_grad_values_gpu =
reinterpret_cast<boxps::FeaturePushValueGpu*>(buf->ptr());
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in PaddleBox now."));
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
LoDTensor& cached_total_keys_tensor = keys_tensor[device_id];
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
VLOG(3) << "Begin copy grad tensor to boxps struct";
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, total_length, batch_size);
void BoxWrapper::GetRandomData(
const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result) {
VLOG(0) << "Begin GetRandomData";
std::vector<std::thread> threads;
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads.push_back(std::thread([this, &pass_data, tid, &slots_to_replace,
result]() {
int debug_erase_cnt = 0;
int debug_push_cnt = 0;
size_t ins_num = pass_data.size();
int start = tid * ins_num / auc_runner_thread_num_;
int end = (tid + 1) * ins_num / auc_runner_thread_num_;
VLOG(3) << "GetRandomData begin for thread[" << tid << "], and process ["
<< start << ", " << end << "), total ins: " << ins_num;
const auto& random_pool = random_ins_pool_list[tid];
for (int i = start; i < end; ++i) {
const auto& ins = pass_data[i];
const RecordCandidate& rand_rec = random_pool.Get(replace_idx_[i]);
Record new_rec = ins;
for (auto it = new_rec.uint64_feasigns_.begin();
it != new_rec.uint64_feasigns_.end();) {
if (slots_to_replace.find(it->slot()) != slots_to_replace.end()) {
it = new_rec.uint64_feasigns_.erase(it);
debug_erase_cnt += 1;
} else {
++it;
}
}
for (auto slot : slots_to_replace) {
auto range = rand_rec.feas_.equal_range(slot);
for (auto it = range.first; it != range.second; ++it) {
new_rec.uint64_feasigns_.push_back({it->second, it->first});
debug_push_cnt += 1;
}
}
(*result)[i] = std::move(new_rec);
}
VLOG(3) << "thread[" << tid << "]: erase feasign num: " << debug_erase_cnt
<< " repush feasign num: " << debug_push_cnt;
}));
}
for (int tid = 0; tid < auc_runner_thread_num_; ++tid) {
threads[tid].join();
}
VLOG(0) << "End GetRandomData";
}
VLOG(3) << "Begin call PushSparseGPU in BoxPS";
push_boxps_timer.Start();
int ret = boxps_ptr_->PushSparseGPU(
total_keys, total_grad_values_gpu, static_cast<int>(total_length),
BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId());
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PushSparseGPU failed in BoxPS."));
push_boxps_timer.Pause();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddleBox: PushSparseGrad Only Support CPUPlace or CUDAPlace Now."));
void BoxWrapper::AddReplaceFeasign(boxps::PSAgentBase* p_agent,
int feed_pass_thread_num) {
VLOG(0) << "Enter AddReplaceFeasign Function";
int semi;
pass_done_semi_->Get(semi);
VLOG(0) << "Last Pass had updated random pool done. Begin AddReplaceFeasign";
std::vector<std::thread> threads;
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads.push_back(std::thread([this, tid, p_agent, feed_pass_thread_num]() {
VLOG(3) << "AddReplaceFeasign begin for thread[" << tid << "]";
for (size_t pool_id = tid; pool_id < random_ins_pool_list.size();
pool_id += feed_pass_thread_num) {
auto& random_pool = random_ins_pool_list[pool_id];
for (size_t i = 0; i < random_pool.Size(); ++i) {
auto& ins_candidate = random_pool.Get(i);
for (const auto& pair : ins_candidate.feas_) {
p_agent->AddKey(pair.second.uint64_feasign_, tid);
}
}
}
}));
}
all_timer.Pause();
VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
<< " s, of which BoxPS cost: " << push_boxps_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PushSparseGrad";
for (int tid = 0; tid < feed_pass_thread_num; ++tid) {
threads[tid].join();
}
VLOG(0) << "End AddReplaceFeasign";
}
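`pass_done_semi_` is a `Channel<int>` used as a one-slot semaphore: `GetRandomReplace` signals completion with `Put(1)`, and the next pass's `AddReplaceFeasign` blocks on `Get(semi)` until the random pool has been updated. A minimal equivalent built on standard primitives (a sketch; the commit itself uses `paddle::framework::Channel`):

```cpp
#include <condition_variable>
#include <mutex>

// Binary semaphore matching the Put/Get pairing above (sketch).
class BinarySemaphore {
 public:
  void Put() {
    std::lock_guard<std::mutex> lk(m_);
    ready_ = true;
    cv_.notify_one();
  }
  void Get() {  // blocks until Put() has run
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return ready_; });
    ready_ = false;
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  bool ready_ = false;
};
```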
} // end namespace framework
} // end namespace paddle
#endif
......@@ -27,9 +27,12 @@ namespace framework {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
__global__ void PullCopy(float** dest, const boxps::FeatureValueGpu* src,
const int64_t* len, int hidden, int slot_num,
int total_len, uint64_t** keys) {
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PullCopy(
float** dest,
const boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* src,
const int64_t* len, int hidden, int expand_dim, int slot_num, int total_len,
uint64_t** keys) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
......@@ -52,15 +55,28 @@ __global__ void PullCopy(float** dest, const boxps::FeatureValueGpu* src,
*(dest[x] + y * hidden + 2) = (src + i)->embed_w;
}
if ((src + i)->embedding_size == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < 8; j++) {
for (int j = 0; j < hidden - 3; j++) {
*(dest[x] + y * hidden + 3 + j) = 0;
}
} else {
for (int j = 0; j < 8; j++) {
for (int j = 0; j < hidden - 3; j++) {
*(dest[x] + y * hidden + 3 + j) = (src + i)->embedx[1 + j];
}
}
}
// process embed_expand
if (expand_dim > 0) {
int z = x + slot_num;
if ((src + i)->embed_expand_size[0] == 0 || *(keys[x] + y) == 0) {
for (int j = 0; j < expand_dim; j++) {
*(dest[z] + y * expand_dim + j) = 0;
}
} else {
for (int j = 0; j < expand_dim; j++) {
*(dest[z] + y * expand_dim + j) = (src + i)->embed_expand[1 + j];
}
}
}
} // end kernel loop
}
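Both copy kernels map a flat key index `i` to a `(slot, offset)` pair by binary-searching the prefix-summed slot lengths (`len` is the lod built on the host). A host-side sketch of that mapping (hypothetical `Locate` helper):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// len_lod[x] holds the inclusive prefix sum of slot lengths; the slot of
// global index i is the first x with i < len_lod[x].
std::pair<int, int64_t> Locate(const std::vector<int64_t>& len_lod,
                               int64_t i) {
  int low = 0, high = static_cast<int>(len_lod.size()) - 1;
  while (low < high) {
    int mid = (low + high) / 2;
    if (i < len_lod[mid]) high = mid;
    else low = mid + 1;
  }
  int64_t offset = (low == 0) ? i : i - len_lod[low - 1];
  return {low, offset};  // (slot index, offset within slot)
}
```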
__global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
......@@ -82,9 +98,11 @@ __global__ void CopyKeysKernel(uint64_t** src_keys, uint64_t* dest_total_keys,
}
}
__global__ void PushCopy(boxps::FeaturePushValueGpu* dest, float** src,
int64_t* len, int hidden, int slot_num, int total_len,
int bs, int* slot_vector) {
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
__global__ void PushCopy(
boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* dest, float** src,
int64_t* len, int hidden, int expand_dim, int slot_num, int total_len,
int bs, int* slot_vector) {
CUDA_KERNEL_LOOP(i, total_len) {
int low = 0;
int high = slot_num - 1;
......@@ -101,18 +119,25 @@ __global__ void PushCopy(boxps::FeaturePushValueGpu* dest, float** src,
(dest + i)->show = *(src[x] + y * hidden);
(dest + i)->clk = *(src[x] + y * hidden + 1);
(dest + i)->embed_g = *(src[x] + y * hidden + 2) * -1. * bs;
for (int j = 0; j < 8; j++) {
for (int j = 0; j < hidden - 3; j++) {
(dest + i)->embedx_g[j] = *(src[x] + y * hidden + 3 + j) * -1. * bs;
}
if (expand_dim > 0) {
int z = x + slot_num;
for (int j = 0; j < expand_dim; j++) {
(dest + i)->embed_expand_g[j] =
*(src[z] + y * expand_dim + j) * -1. * bs;
}
}
}
}
void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
uint64_t** gpu_keys,
const std::vector<float*>& values,
const boxps::FeatureValueGpu* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size,
void* total_values_gpu, const int64_t* gpu_len,
const int slot_num, const int hidden_size,
const int expand_embed_dim,
const int64_t total_length) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
......@@ -122,11 +147,40 @@ void BoxWrapper::CopyForPull(const paddle::platform::Place& place,
float** gpu_values = reinterpret_cast<float**>(buf_value->ptr());
cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
cudaMemcpyHostToDevice);
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define EXPAND_EMBED_PULL_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PullCopy<EmbedxDim, \
ExpandDim><<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \
gpu_values, \
reinterpret_cast<boxps::FeatureValueGpu<EmbedxDim, ExpandDim>*>( \
total_values_gpu), \
gpu_len, hidden_size, expand_embed_dim, slot_num, total_length, \
gpu_keys); \
} break
PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num,
total_length, gpu_keys);
switch (hidden_size - 3) {
EMBEDX_CASE(8, EXPAND_EMBED_PULL_CASE(0); EXPAND_EMBED_PULL_CASE(8);
EXPAND_EMBED_PULL_CASE(64););
EMBEDX_CASE(16, EXPAND_EMBED_PULL_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
}
cudaStreamSynchronize(stream);
#undef EXPAND_EMBED_PULL_CASE
#undef EMBEDX_CASE
}
void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
......@@ -143,10 +197,10 @@ void BoxWrapper::CopyKeys(const paddle::platform::Place& place,
void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
boxps::FeaturePushValueGpu* total_grad_values_gpu,
void* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size) {
const int hidden_size, const int expand_embed_dim,
const int64_t total_length, const int batch_size) {
auto stream = dynamic_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(
BOOST_GET_CONST(platform::CUDAPlace, place)))
......@@ -173,11 +227,42 @@ void BoxWrapper::CopyForPush(const paddle::platform::Place& place,
cudaMemcpy(d_slot_vector, slot_vector_.data(),
slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
total_grad_values_gpu, gpu_values, gpu_len, hidden_size,
slot_lengths.size(), total_length, batch_size, d_slot_vector);
#define EMBEDX_CASE(i, ...) \
case i: { \
constexpr size_t EmbedxDim = i; \
switch (expand_embed_dim) { \
__VA_ARGS__ \
default: \
PADDLE_THROW(platform::errors::InvalidArgument( \
"Unsupport this expand embedding size [%d]", expand_embed_dim)); \
} \
} break
#define EXPAND_EMBED_PUSH_CASE(i, ...) \
case i: { \
constexpr size_t ExpandDim = i; \
PushCopy<EmbedxDim, \
ExpandDim><<<(total_length + 512 - 1) / 512, 512, 0, stream>>>( \
reinterpret_cast<boxps::FeaturePushValueGpu<EmbedxDim, ExpandDim>*>( \
total_grad_values_gpu), \
gpu_values, gpu_len, hidden_size, expand_embed_dim, \
slot_lengths.size(), total_length, batch_size, d_slot_vector); \
} break
switch (hidden_size - 3) {
EMBEDX_CASE(8, EXPAND_EMBED_PUSH_CASE(0); EXPAND_EMBED_PUSH_CASE(8);
EXPAND_EMBED_PUSH_CASE(64););
EMBEDX_CASE(16, EXPAND_EMBED_PUSH_CASE(0););
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupport this embedding size [%d]", hidden_size - 3));
}
cudaStreamSynchronize(stream);
#undef EXPAND_EMBED_PUSH_CASE
#undef EMBEDX_CASE
}
} // end namespace framework
} // end namespace paddle
#endif
......@@ -31,10 +31,12 @@ limitations under the License. */
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -339,30 +341,54 @@ class BoxWrapper {
void BeginPass() const;
void EndPass(bool need_save_delta) const;
void SetTestMode(bool is_test) const;
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM = 0>
void PullSparseCase(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int expand_embed_dim);
void PullSparse(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size);
const int hidden_size, const int expand_embed_dim);
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM = 0>
void PushSparseGradCase(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int expand_embed_dim,
const int batch_size);
void PushSparseGrad(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int batch_size);
const int hidden_size, const int expand_embed_dim,
const int batch_size);
void CopyForPull(const paddle::platform::Place& place, uint64_t** gpu_keys,
const std::vector<float*>& values,
const boxps::FeatureValueGpu* total_values_gpu,
const std::vector<float*>& values, void* total_values_gpu,
const int64_t* gpu_len, const int slot_num,
const int hidden_size, const int64_t total_length);
const int hidden_size, const int expand_embed_dim,
const int64_t total_length);
void CopyForPush(const paddle::platform::Place& place,
const std::vector<const float*>& grad_values,
boxps::FeaturePushValueGpu* total_grad_values_gpu,
void* total_grad_values_gpu,
const std::vector<int64_t>& slot_lengths,
const int hidden_size, const int64_t total_length,
const int batch_size);
const int hidden_size, const int expand_embed_dim,
const int64_t total_length, const int batch_size);
void CopyKeys(const paddle::platform::Place& place, uint64_t** origin_keys,
uint64_t* total_keys, const int64_t* gpu_len, int slot_num,
int total_len);
void CheckEmbedSizeIsValid(int embedx_dim, int expand_embed_dim);
boxps::PSAgentBase* GetAgent() { return p_agent_; }
void InitializeGPUAndLoadModel(
const char* conf_file, const std::vector<int>& slot_vector,
......@@ -440,6 +466,15 @@ class BoxWrapper {
}
static std::shared_ptr<BoxWrapper> GetInstance() {
PADDLE_ENFORCE_EQ(
s_instance_ == nullptr, false,
platform::errors::PreconditionNotMet(
"GetInstance failed in BoxPs, you should use SetInstance firstly"));
return s_instance_;
}
static std::shared_ptr<BoxWrapper> SetInstance(int embedx_dim = 8,
int expand_embed_dim = 0) {
if (nullptr == s_instance_) {
// If main thread is guaranteed to init this, this lock can be removed
static std::mutex mutex;
......@@ -447,8 +482,13 @@ class BoxWrapper {
if (nullptr == s_instance_) {
VLOG(3) << "s_instance_ is null";
s_instance_.reset(new paddle::framework::BoxWrapper());
s_instance_->boxps_ptr_.reset(boxps::BoxPSBase::GetIns());
s_instance_->boxps_ptr_.reset(
boxps::BoxPSBase::GetIns(embedx_dim, expand_embed_dim));
embedx_dim_ = embedx_dim;
expand_embed_dim_ = expand_embed_dim;
}
} else {
LOG(WARNING) << "You have already used SetInstance() before";
}
return s_instance_;
}
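`SetInstance` now owns configuration (double-checked locking plus the new dimension arguments), while `GetInstance` merely enforces that configuration already happened. The intended call order, as a sketch (assumes the surrounding header is included):

```cpp
// Configure once, early, on the main thread:
auto box = paddle::framework::BoxWrapper::SetInstance(/*embedx_dim=*/8,
                                                      /*expand_embed_dim=*/64);
// Later, from any thread, read the configured singleton:
auto same = paddle::framework::BoxWrapper::GetInstance();
same->CheckEmbedSizeIsValid(8, 64);  // passes: dims match SetInstance
```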
......@@ -469,16 +509,16 @@ class BoxWrapper {
public:
MetricMsg() {}
MetricMsg(const std::string& label_varname, const std::string& pred_varname,
int is_join, int bucket_size = 1000000)
int metric_phase, int bucket_size = 1000000)
: label_varname_(label_varname),
pred_varname_(pred_varname),
is_join_(is_join) {
metric_phase_(metric_phase) {
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
virtual ~MetricMsg() {}
int IsJoin() const { return is_join_; }
int MetricPhase() const { return metric_phase_; }
BasicAucCalculator* GetCalculator() { return calculator; }
virtual void add_data(const Scope* exe_scope) {
std::vector<int64_t> label_data;
......@@ -514,20 +554,20 @@ class BoxWrapper {
protected:
std::string label_varname_;
std::string pred_varname_;
int is_join_;
int metric_phase_;
BasicAucCalculator* calculator;
};
class MultiTaskMetricMsg : public MetricMsg {
public:
MultiTaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname_list, int is_join,
const std::string& pred_varname_list, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
int bucket_size = 1000000) {
label_varname_ = label_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
is_join_ = is_join;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
......@@ -594,14 +634,14 @@ class BoxWrapper {
class CmatchRankMetricMsg : public MetricMsg {
public:
CmatchRankMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join,
const std::string& pred_varname, int metric_phase,
const std::string& cmatch_rank_group,
const std::string& cmatch_rank_varname,
int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
cmatch_rank_varname_ = cmatch_rank_varname;
is_join_ = is_join;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
for (auto& cmatch_rank : string::split_string(cmatch_rank_group)) {
......@@ -653,12 +693,12 @@ class BoxWrapper {
class MaskMetricMsg : public MetricMsg {
public:
MaskMetricMsg(const std::string& label_varname,
const std::string& pred_varname, int is_join,
const std::string& pred_varname, int metric_phase,
const std::string& mask_varname, int bucket_size = 1000000) {
label_varname_ = label_varname;
pred_varname_ = pred_varname;
mask_varname_ = mask_varname;
is_join_ = is_join;
metric_phase_ = metric_phase;
calculator = new BasicAucCalculator();
calculator->init(bucket_size);
}
......@@ -682,36 +722,59 @@ class BoxWrapper {
protected:
std::string mask_varname_;
};
const std::vector<std::string>& GetMetricNameList() const {
return metric_name_list_;
const std::vector<std::string> GetMetricNameList(
int metric_phase = -1) const {
VLOG(0) << "Want to Get metric phase: " << metric_phase;
if (metric_phase == -1) {
return metric_name_list_;
} else {
std::vector<std::string> ret;
for (const auto& name : metric_name_list_) {
const auto iter = metric_lists_.find(name);
PADDLE_ENFORCE_NE(
iter, metric_lists_.end(),
platform::errors::InvalidArgument(
"The metric name you provided is not registered."));
if (iter->second->MetricPhase() == metric_phase) {
VLOG(0) << name << "'s phase is " << iter->second->MetricPhase()
<< ", we want";
ret.push_back(name);
} else {
VLOG(0) << name << "'s phase is " << iter->second->MetricPhase()
<< ", not we want";
}
}
return ret;
}
}
int PassFlag() const { return pass_flag_; }
void FlipPassFlag() { pass_flag_ = 1 - pass_flag_; }
int Phase() const { return phase_; }
void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }
void InitMetric(const std::string& method, const std::string& name,
const std::string& label_varname,
const std::string& pred_varname,
const std::string& cmatch_rank_varname,
const std::string& mask_varname, bool is_join,
const std::string& mask_varname, int metric_phase,
const std::string& cmatch_rank_group,
int bucket_size = 1000000) {
if (method == "AucCalculator") {
metric_lists_.emplace(name, new MetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, bucket_size));
metric_phase, bucket_size));
} else if (method == "MultiTaskAucCalculator") {
metric_lists_.emplace(
name, new MultiTaskMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "CmatchRankAucCalculator") {
metric_lists_.emplace(
name, new CmatchRankMetricMsg(label_varname, pred_varname,
is_join ? 1 : 0, cmatch_rank_group,
metric_phase, cmatch_rank_group,
cmatch_rank_varname, bucket_size));
} else if (method == "MaskAucCalculator") {
metric_lists_.emplace(
name, new MaskMetricMsg(label_varname, pred_varname, is_join ? 1 : 0,
name, new MaskMetricMsg(label_varname, pred_varname, metric_phase,
mask_varname, bucket_size));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......@@ -751,9 +814,13 @@ class BoxWrapper {
const int feedpass_thread_num_ = 30; // magic number
static std::shared_ptr<BoxWrapper> s_instance_;
std::unordered_set<std::string> slot_name_omited_in_feedpass_;
// EMBEDX_DIM and EXPAND_EMBED_DIM
static int embedx_dim_;
static int expand_embed_dim_;
// Metric Related
int pass_flag_ = 1; // join: 1, update: 0
int phase_ = 1;
int phase_num_ = 2;
std::map<std::string, MetricMsg*> metric_lists_;
std::vector<std::string> metric_name_list_;
std::vector<int> slot_vector_;
......@@ -762,6 +829,57 @@ class BoxWrapper {
public:
static AfsManager* afs_manager;
// Auc Runner
public:
void InitializeAucRunner(std::vector<std::vector<std::string>> slot_eval,
int thread_num, int pool_size,
std::vector<std::string> slot_list) {
mode_ = 1;
phase_num_ = static_cast<int>(slot_eval.size());
phase_ = phase_num_ - 1;
auc_runner_thread_num_ = thread_num;
pass_done_semi_ = paddle::framework::MakeChannel<int>();
pass_done_semi_->Put(1); // Note: At most 1 pipeline in AucRunner
random_ins_pool_list.resize(thread_num);
std::unordered_set<std::string> slot_set;
for (size_t i = 0; i < slot_eval.size(); ++i) {
for (const auto& slot : slot_eval[i]) {
slot_set.insert(slot);
}
}
for (size_t i = 0; i < slot_list.size(); ++i) {
if (slot_set.find(slot_list[i]) != slot_set.end()) {
slot_index_to_replace_.insert(static_cast<int16_t>(i));
}
}
for (int i = 0; i < auc_runner_thread_num_; ++i) {
random_ins_pool_list[i].SetSlotIndexToReplace(slot_index_to_replace_);
}
VLOG(0) << "AucRunner configuration: thread number[" << thread_num
<< "], pool size[" << pool_size << "], runner_group[" << phase_num_
<< "]";
VLOG(0) << "Slots that need to be evaluated:";
for (auto e : slot_index_to_replace_) {
VLOG(0) << e << ": " << slot_list[e];
}
}
void GetRandomReplace(const std::vector<Record>& pass_data);
void AddReplaceFeasign(boxps::PSAgentBase* p_agent, int feed_pass_thread_num);
void GetRandomData(const std::vector<Record>& pass_data,
const std::unordered_set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
int Mode() const { return mode_; }
private:
  int mode_ = 0;  // 0 means train/test, 1 means auc_runner
int auc_runner_thread_num_ = 1;
bool init_done_ = false;
paddle::framework::Channel<int> pass_done_semi_;
std::unordered_set<uint16_t> slot_index_to_replace_;
std::vector<RecordCandidateList> random_ins_pool_list;
std::vector<size_t> replace_idx_;
};
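The old binary `pass_flag_` (join = 1, update = 0) is generalized to a `phase_` counter cycling modulo `phase_num_`; with the default `phase_num_ = 2` the behaviour is unchanged, but `InitializeAucRunner` can register one phase per slot-evaluation group. A sketch of how a trainer loop interacts with it (assumed driver code, not from this commit):

```cpp
// One evaluation round over all phases (sketch).
auto box = paddle::framework::BoxWrapper::GetInstance();
for (int p = 0; p < /*phase_num*/ 2; ++p) {
  box->FlipPhase();  // phase_ = (phase_ + 1) % phase_num_
  // SectionWorker then aggregates only the metrics whose
  // MetricPhase() equals box->Phase() during this pass.
}
```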
#endif
......@@ -810,7 +928,38 @@ class BoxHelper {
VLOG(3) << "After PreLoadIntoMemory()";
}
void WaitFeedPassDone() { feed_data_thread_->join(); }
void SlotsShuffle(const std::set<std::string>& slots_to_replace) {
#ifdef PADDLE_WITH_BOX_PS
auto box_ptr = BoxWrapper::GetInstance();
PADDLE_ENFORCE_EQ(box_ptr->Mode(), 1,
platform::errors::PreconditionNotMet(
"Should call InitForAucRunner first."));
box_ptr->FlipPhase();
std::unordered_set<uint16_t> index_slots;
dynamic_cast<MultiSlotDataset*>(dataset_)->PreprocessChannel(
slots_to_replace, index_slots);
const std::vector<Record>& pass_data =
dynamic_cast<MultiSlotDataset*>(dataset_)->GetSlotsOriginalData();
if (!get_random_replace_done_) {
box_ptr->GetRandomReplace(pass_data);
get_random_replace_done_ = true;
}
std::vector<Record> random_data;
random_data.resize(pass_data.size());
box_ptr->GetRandomData(pass_data, index_slots, &random_data);
auto new_input_channel = paddle::framework::MakeChannel<Record>();
new_input_channel->Open();
new_input_channel->Write(std::move(random_data));
new_input_channel->Close();
dynamic_cast<MultiSlotDataset*>(dataset_)->SetInputChannel(
new_input_channel);
if (dataset_->EnablePvMerge()) {
dataset_->PreprocessInstance();
}
#endif
}
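Taken together, `BoxHelper::SlotsShuffle` is now the single entry point for an AucRunner pass: flip the phase, rebuild the channel data, compute the per-record replacement indices once per pass (guarded by `get_random_replace_done_`), materialize the shuffled records, and swap them in as the new input channel. A hedged usage sketch (the surrounding driver and constructor arguments are assumptions):

```cpp
// Evaluate the contribution of two slots with AucRunner (sketch).
BoxHelper helper(dataset);  // dataset is a MultiSlotDataset*
helper.SlotsShuffle({"slot_13", "slot_27"});
// The dataset's input channel now holds records whose evaluated slots
// were replaced by reservoir-sampled candidates; PreprocessInstance()
// has been rerun if PV merge is enabled.
```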
#ifdef PADDLE_WITH_BOX_PS
// notify boxps to feed this pass feasigns from SSD to memory
static void FeedPassThread(const std::deque<Record>& t, int begin_index,
......@@ -881,6 +1030,10 @@ class BoxHelper {
for (size_t i = 0; i < tnum; ++i) {
threads[i].join();
}
if (box_ptr->Mode() == 1) {
box_ptr->AddReplaceFeasign(p_agent, tnum);
}
VLOG(3) << "Begin call EndFeedPass in BoxPS";
box_ptr->EndFeedPass(p_agent);
#endif
......@@ -892,7 +1045,10 @@ class BoxHelper {
int year_;
int month_;
int day_;
bool get_random_replace_done_ = false;
};
} // end namespace framework
} // end namespace paddle
#include "paddle/fluid/framework/fleet/box_wrapper_impl.h"
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_BOX_PS
#include <vector>
namespace paddle {
namespace framework {
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
void BoxWrapper::PullSparseCase(const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<float*>& values,
const std::vector<int64_t>& slot_lengths,
const int hidden_size,
const int expand_embed_dim) {
VLOG(3) << "Begin PullSparse";
platform::Timer all_timer;
platform::Timer pull_boxps_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
auto buf = memory::AllocShared(
place, total_length *
sizeof(boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>));
boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>* total_values_gpu =
reinterpret_cast<boxps::FeatureValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>*>(
buf->ptr());
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in PaddleBox now."));
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
VLOG(3) << "Begin copy keys, key_num[" << total_length << "]";
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
LoDTensor& total_keys_tensor = keys_tensor[device_id];
uint64_t* total_keys = reinterpret_cast<uint64_t*>(
total_keys_tensor.mutable_data<int64_t>({total_length, 1}, place));
// construct slot_level lod info
auto slot_lengths_lod = slot_lengths;
for (size_t i = 1; i < slot_lengths_lod.size(); i++) {
slot_lengths_lod[i] += slot_lengths_lod[i - 1];
}
auto buf_key = memory::AllocShared(place, keys.size() * sizeof(uint64_t*));
auto buf_length =
memory::AllocShared(place, slot_lengths.size() * sizeof(int64_t));
uint64_t** gpu_keys = reinterpret_cast<uint64_t**>(buf_key->ptr());
int64_t* gpu_len = reinterpret_cast<int64_t*>(buf_length->ptr());
cudaMemcpy(gpu_keys, keys.data(), keys.size() * sizeof(uint64_t*),
cudaMemcpyHostToDevice);
cudaMemcpy(gpu_len, slot_lengths_lod.data(),
slot_lengths.size() * sizeof(int64_t), cudaMemcpyHostToDevice);
this->CopyKeys(place, gpu_keys, total_keys, gpu_len,
static_cast<int>(slot_lengths.size()),
static_cast<int>(total_length));
VLOG(3) << "Begin call PullSparseGPU in BoxPS";
pull_boxps_timer.Start();
int ret = boxps_ptr_->PullSparseGPU(
total_keys, reinterpret_cast<void*>(total_values_gpu),
static_cast<int>(total_length), device_id);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PullSparseGPU failed in BoxPS."));
pull_boxps_timer.Pause();
VLOG(3) << "Begin Copy result to tensor, total_length[" << total_length
<< "]";
this->CopyForPull(place, gpu_keys, values,
reinterpret_cast<void*>(total_values_gpu), gpu_len,
static_cast<int>(slot_lengths.size()), hidden_size,
expand_embed_dim, total_length);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddleBox: PullSparse Only Support CPUPlace or CUDAPlace Now."));
}
all_timer.Pause();
VLOG(1) << "PullSparse total costs: " << all_timer.ElapsedSec()
<< " s, of which BoxPS costs: " << pull_boxps_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PullSparse";
}
template <size_t EMBEDX_DIM, size_t EXPAND_EMBED_DIM>
void BoxWrapper::PushSparseGradCase(
const paddle::platform::Place& place,
const std::vector<const uint64_t*>& keys,
const std::vector<const float*>& grad_values,
const std::vector<int64_t>& slot_lengths, const int hidden_size,
const int expand_embed_dim, const int batch_size) {
VLOG(3) << "Begin PushSparseGrad";
platform::Timer all_timer;
platform::Timer push_boxps_timer;
all_timer.Start();
int64_t total_length =
std::accumulate(slot_lengths.begin(), slot_lengths.end(), 0UL);
auto buf = memory::AllocShared(
place,
total_length *
sizeof(boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>));
boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>*
total_grad_values_gpu = reinterpret_cast<
boxps::FeaturePushValueGpu<EMBEDX_DIM, EXPAND_EMBED_DIM>*>(
buf->ptr());
if (platform::is_cpu_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Warning:: CPUPlace is not supported in PaddleBox now."));
} else if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId();
LoDTensor& cached_total_keys_tensor = keys_tensor[device_id];
uint64_t* total_keys =
reinterpret_cast<uint64_t*>(cached_total_keys_tensor.data<int64_t>());
VLOG(3) << "Begin copy grad tensor to boxps struct";
this->CopyForPush(place, grad_values, total_grad_values_gpu, slot_lengths,
hidden_size, expand_embed_dim, total_length, batch_size);
VLOG(3) << "Begin call PushSparseGPU in BoxPS";
push_boxps_timer.Start();
int ret = boxps_ptr_->PushSparseGPU(
total_keys, reinterpret_cast<void*>(total_grad_values_gpu),
static_cast<int>(total_length),
BOOST_GET_CONST(platform::CUDAPlace, place).GetDeviceId());
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"PushSparseGPU failed in BoxPS."));
push_boxps_timer.Pause();
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Please compile WITH_GPU option, because NCCL doesn't support "
"windows."));
#endif
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddleBox: PushSparseGrad Only Support CPUPlace or CUDAPlace Now."));
}
all_timer.Pause();
VLOG(1) << "PushSparseGrad total cost: " << all_timer.ElapsedSec()
<< " s, of which BoxPS cost: " << push_boxps_timer.ElapsedSec()
<< " s";
VLOG(3) << "End PushSparseGrad";
}
} // namespace framework
} // namespace paddle
#endif
......@@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_double(eager_delete_tensor_gb);
......@@ -820,6 +821,8 @@ void ParallelExecutor::BCastParamsToDevices(
FetchResultType ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter ParallelExecutor Run";
platform::RecordEvent parallel_executor_event(
"ParallelExecutor::Run", paddle::platform::EventRole::kSpecial);
#ifdef WITH_GPERFTOOLS
if (gProfileStarted) {
ProfilerFlush();
......
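`platform::RecordEvent` is an RAII profiler marker: the event opens when the local object is constructed and closes when it goes out of scope, so the single line added above captures the entire `Run()` call. The pattern in isolation (a sketch):

```cpp
void Step() {
  platform::RecordEvent guard("ParallelExecutor::Run",
                              paddle::platform::EventRole::kSpecial);
  // ... everything here is attributed to the event ...
}  // event ends when `guard` is destroyed
```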
......@@ -211,7 +211,7 @@ void SectionWorker::TrainFiles() {
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(exe_scope);
......@@ -367,7 +367,7 @@ void SectionWorker::TrainFilesWithProfiler() {
auto& metric_list = box_ptr->GetMetricList();
for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) {
auto* metric_msg = iter->second;
if (metric_msg->IsJoin() != box_ptr->PassFlag()) {
if (box_ptr->Phase() != metric_msg->MetricPhase()) {
continue;
}
metric_msg->add_data(exe_scope);
......
......@@ -43,7 +43,8 @@ class OpVariant {
const AttrType &Attr(const std::string &name) const {
auto &attrs = Attrs();
auto it = attrs.find(name);
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
PADDLE_ENFORCE_NE(it, attrs.end(), platform::errors::NotFound(
"Cannot find attribute %s.", name));
return BOOST_GET_CONST(AttrType, it->second);
}
......
......@@ -31,9 +31,9 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
int ind = in->numel();
for (size_t i = 0; i < (unsigned)ind; i++) {
if (input_data[i] < 0) {
output_data[i] = -std::pow(2.0, dict_data[input_data[i] + 128]);
output_data[i] = -dict_data[input_data[i] + 128];
} else {
output_data[i] = std::pow(2.0, dict_data[input_data[i]]);
output_data[i] = dict_data[input_data[i]];
}
}
}
......
......@@ -26,9 +26,9 @@ __global__ void KeDequantize(const T* in, const float* dict, int num,
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
if (in[idx] < 0) {
out[idx] = -std::pow(static_cast<float>(2.0), dict[in[idx] + 128]);
out[idx] = -dict[in[idx] + 128];
} else {
out[idx] = std::pow(static_cast<float>(2.0), dict[in[idx]]);
out[idx] = dict[in[idx]];
}
}
}
......
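The dequantize fix on both the CPU and CUDA paths replaces `pow(2.0, dict[...])` with a direct table read: the dictionary now stores dequantized magnitudes directly rather than base-2 exponents, with negative codes offset by 128 into the table and negating the result. The lookup in isolation (a sketch):

```cpp
#include <cstdint>

// q in [-128, 127]; negative codes index dict[q + 128] and flip the sign.
float Dequantize(int8_t q, const float* dict) {
  return (q < 0) ? -dict[q + 128] : dict[q];
}
```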
......@@ -104,7 +104,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
int axis = ctx.Attr<int>("axis");
int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
ctx.Input<Tensor>("Y")->dims().size();
return (axis == -1) || (axis == rankdiff);
return (rankdiff == 0) || (axis == -1) || (axis == rankdiff);
};
if (platform::CanMKLDNNBeUsed(ctx) &&
......@@ -243,9 +243,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddGradBeUsed = [&]() {
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
return (dx != nullptr && dy != nullptr && dx->dims() == dy->dims());
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
};
if (platform::CanMKLDNNBeUsed(ctx) &&
......
......@@ -85,6 +85,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
in->set_format(out->format());
};
// TODO(jczaja): Double check if vcopy works for blocked data
auto blas = math::GetBlas<paddle::platform::CPUDeviceContext, T>(ctx);
if (dx) {
blas.VCOPY(dout->numel(), dout->data<T>(),
......
......@@ -257,7 +257,7 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
HierarchicalSigmoidGradOpNoNeedBufferVarInference, "Bias");
HierarchicalSigmoidGradOpNoNeedBufferVarInferer, "Bias");
} // namespace operators
} // namespace paddle
......@@ -270,7 +270,7 @@ REGISTER_OPERATOR(
ops::HierarchicalSigmoidGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp,
ops::HierarchicalSigmoidGradOpGradVarTypeInference,
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInference);
ops::HierarchicalSigmoidGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
hierarchical_sigmoid,
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -138,7 +138,7 @@ class IndexSelectGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
} // namespace paddle
......@@ -148,7 +148,7 @@ REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInference);
ops::IndexSelectGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -603,7 +603,7 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
}
};
DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInference,
DECLARE_INPLACE_OP_INFERER(InstanceNormDoubleGradOpInplaceInferer,
{"DY", "DDY"});
} // namespace operators
......@@ -618,7 +618,7 @@ REGISTER_OPERATOR(instance_norm_grad, ops::InstanceNormGradOp,
ops::InstanceNormDoubleGradMaker<paddle::framework::OpDesc>,
ops::InstanceNormDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(instance_norm_grad_grad, ops::InstanceNormDoubleGradOp,
ops::InstanceNormDoubleGradOpInplaceInference);
ops::InstanceNormDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
instance_norm,
......
......@@ -585,7 +585,7 @@ class InterpolateGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(InterpolateGradNoNeedBufferVarsInferer,
"X");
} // namespace operators
......@@ -596,22 +596,22 @@ REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(trilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(trilinear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(bicubic_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(bicubic_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......@@ -631,7 +631,7 @@ REGISTER_OPERATOR(linear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
ops::InterpolateGradMaker<paddle::framework::OpDesc>,
ops::InterpolateGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_interp_grad, ops::InterpolateOpGrad,
ops::InterpolateGradNoNeedBufferVarsInference);
ops::InterpolateGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(linear_interp, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
......
......@@ -166,7 +166,7 @@ class KLDivLossOpGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(KLDivLossGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -176,7 +176,7 @@ REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
ops::KLDivLossOpGradMaker<paddle::framework::OpDesc>,
ops::KLDivLossOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad,
ops::KLDivLossGradNoNeedBufferVarInference);
ops::KLDivLossGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -220,7 +220,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LayerNormGradNoNeedBufferVarInferer,
"Bias");
} // namespace operators
......@@ -231,7 +231,7 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::framework::OpDesc>,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInference);
ops::LayerNormGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -345,7 +345,7 @@ class LinearChainCRFGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInference,
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LinearChainCRFGradNoNeedBufferVarsInferer,
"Transition", "Emission");
} // namespace operators
......@@ -357,7 +357,7 @@ REGISTER_OPERATOR(linear_chain_crf, ops::LinearChainCRFOp,
ops::LinearChainCRFGradMaker<paddle::framework::OpDesc>,
ops::LinearChainCRFGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(linear_chain_crf_grad, ops::LinearChainCRFGradOp,
ops::LinearChainCRFGradNoNeedBufferVarsInference);
ops::LinearChainCRFGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
linear_chain_crf,
ops::LinearChainCRFOpKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -223,7 +223,7 @@ DECLARE_INPLACE_OP_INFERER(LoDResetGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LoDResetGradNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -234,7 +234,7 @@ REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker,
ops::LoDResetGradMaker<paddle::imperative::OpBase>,
ops::LoDResetOpVarTypeInference, ops::LoDResetInplaceInferer);
REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp,
ops::LoDResetGradNoNeedBufferVarInference,
ops::LoDResetGradNoNeedBufferVarInferer,
ops::LoDResetGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
......
......@@ -130,7 +130,7 @@ or not. And the output only shares the LoD information with input Ids.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableGradOpNoBuffer, "W");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableGradOpNoBufferVarsInferer, "W");
template <typename T>
class LookupTableGradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -198,7 +198,7 @@ REGISTER_OPERATOR(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
ops::LookupTableGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
ops::LookupTableGradOpNoBuffer,
ops::LookupTableGradOpNoBufferVarsInferer,
ops::LookupTableOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
......
......@@ -118,7 +118,8 @@ or not. And the output only shares the LoD information with input Ids.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableV2GradOpNoBuffer, "W");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(LookupTableV2GradOpNoBufferVarsInferer,
"W");
template <typename T>
class LookupTableV2GradOpMaker : public framework::SingleGradOpMaker<T> {
......@@ -187,7 +188,7 @@ REGISTER_OPERATOR(lookup_table_v2, ops::LookupTableV2Op,
ops::LookupTableV2GradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
ops::LookupTableV2GradOpNoBuffer,
ops::LookupTableV2GradOpNoBufferVarsInferer,
ops::LookupTableV2OpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
......
......@@ -83,7 +83,7 @@ class MeanGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(MeanGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -93,7 +93,7 @@ REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanOpInferVarType,
ops::MeanGradMaker<paddle::framework::OpDesc>,
ops::MeanGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(mean_grad, ops::MeanGradOp,
ops::MeanGradNoNeedBufferVarsInference);
ops::MeanGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
mean, ops::MeanKernel<paddle::platform::CPUDeviceContext, float>,
ops::MeanKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -62,8 +62,9 @@ class MKLDNNActivationGradKernel
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
mkldnn::algorithm algorithm) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL eletwise_forward must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto *x = ctx.Input<Tensor>("X");
......
......@@ -144,7 +144,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
const unsigned int C = scale_tz[0];
// MKLDNN requires a single piece of memory for scale and shift/bias data
......@@ -248,7 +252,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
PADDLE_ENFORCE_EQ(
scale_tz.size(), 1,
platform::errors::InvalidArgument(
"Dims of scale tensor must be 1, but received scale's size is %d",
scale_tz.size()));
const unsigned int C = scale_tz[0];
......
......@@ -134,6 +134,15 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out");
int concat_axis = ctx.Attr<int>("axis");
const int rank = multi_input[0]->dims().size();
PADDLE_ENFORCE_EQ(
concat_axis >= -rank && concat_axis < rank, true,
platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank, rank, concat_axis));
if (concat_axis < 0) {
concat_axis = concat_axis + rank;
}
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
auto place = GetCpuPlace(ctx);
......
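The concat hunk above validates the axis against [-rank, rank) and wraps a negative axis before use. A tiny self-contained sketch of that normalization (illustrative name):
#include <cassert>
int NormalizeConcatAxis(int axis, int rank) {
  assert(axis >= -rank && axis < rank);
  return axis < 0 ? axis + rank : axis;
}
int main() {
  assert(NormalizeConcatAxis(-1, 4) == 3);  // last dimension of a rank-4 tensor
  assert(NormalizeConcatAxis(2, 4) == 2);
}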
......@@ -94,8 +94,9 @@ template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
platform::errors::InvalidArgument("It must use CPUPlace."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
if (!is_INT8) {
......@@ -784,9 +785,9 @@ template <typename T>
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
platform::errors::InvalidArgument("It must use CPUPlace."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL ConvGrad must use CPUPlace"));
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......
......@@ -29,9 +29,9 @@ template <typename T>
class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
platform::errors::InvalidArgument("It must use CPUPlace."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL ConvTranspose must use CPUPlace"));
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(is_test, true,
platform::errors::InvalidArgument(
......
......@@ -27,10 +27,12 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
const bool is_float_type = std::is_same<T, float>::value;
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"MKLDNN LRN must use CPUPlace.");
PADDLE_ENFORCE_EQ(
is_float_type, true,
platform::errors::PreconditionNotMet("DNNL LRN must use float data."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL LRN must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto x = ctx.Input<Tensor>("X");
......@@ -93,12 +95,16 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
const bool is_float_type = std::is_same<T, float>::value;
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"MKLDNN LRN must use CPUPlace.");
PADDLE_ENFORCE(
!ctx.Attr<bool>("is_test"),
"is_test attribute should be set to False in training phase.");
PADDLE_ENFORCE_EQ(is_float_type, true,
platform::errors::PreconditionNotMet(
"DNNL LRN GradOpKernl must use float data."));
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL LRNGrad must use CPUPlace"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::PreconditionNotMet(
"is_test attribute should be set to False in training phase."));
auto x = ctx.Input<Tensor>("X");
auto mid = ctx.Input<Tensor>("MidOut");
......
......@@ -30,12 +30,8 @@ class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Input<framework::Tensor>("X") != nullptr,
"Cannot get input tensor X, variable name = %s",
context.InputName("X"));
PADDLE_ENFORCE(context.Output<framework::Tensor>("Out") != nullptr,
"Cannot find output tensor Out, variable name = %s",
context.OutputName("Out"));
OP_INOUT_CHECK(context.HasInput("X"), "Input", "X", "Activation");
OP_INOUT_CHECK(context.HasOutput("Out"), "Output", "Out", "Activation");
Functor functor;
auto attrs = functor.GetAttrs();
......
......@@ -333,9 +333,9 @@ template <typename XT, typename YT>
class MulMKLDNNKernel : public framework::OpKernel<XT> {
public:
void Compute(const ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Mul must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
......
......@@ -33,61 +33,19 @@ template <typename T>
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Pool must use CPUPlace"));
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor");
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
// Only 2D pooling is supported now
PADDLE_ENFORCE_EQ(ksize.size(), 2, "ksize must be 2D, i.e. 2D pooling");
PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true,
"pooling_type must be 'max' or 'avg'");
PADDLE_ENFORCE_EQ(input->dims().size(), 4,
"Input dim must be with 4, i.e. NCHW");
auto input_dims = input->dims();
framework::DDim data_dims =
framework::slice_ddim(input_dims, 2, input_dims.size());
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
strides, ksize);
auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
auto is_test = ctx.Attr<bool>("is_test");
platform::PoolingMKLDNNHandler<T> handler(
src_tz, dst_tz, ksize, strides, paddings, pooling_type,
ctx.Attr<bool>("ceil_mode"), input->format(),
paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
ctx.GetPlace(), ctx.OutputName("Out"), ctx.Attr<bool>("exclusive"));
platform::PoolingMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
ctx.GetPlace(), input, output,
ctx.OutputName("Out"));
auto src_memory = handler.AcquireSrcMemory(input);
auto dst_memory = handler.AcquireDstMemory(output);
......@@ -95,7 +53,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto pool_p = handler.AcquireForwardPrimitive();
mkldnn::stream astream(dev_ctx.GetEngine());
if ((is_test == false) && (pooling_type == "max")) {
if ((ctx.Attr<bool>("is_test") == false) &&
(ctx.Attr<std::string>("pooling_type") == "max")) {
// Training
auto workspace_memory = handler.AcquireWorkspaceMemory();
pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
......@@ -117,9 +76,9 @@ template <typename T>
class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL PoolGrad must use CPUPlace"));
const Tensor* in_x = ctx.Input<Tensor>("X");
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -129,9 +129,9 @@ template <typename T>
class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL SoftmaxGrad must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const Tensor* output = ctx.Input<Tensor>("Out");
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -49,8 +49,9 @@ template <typename T>
class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Sum must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto in_vars = ctx.MultiInputVar("X");
......
......@@ -28,8 +28,9 @@ template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -73,8 +74,9 @@ template <typename T>
class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace"));
auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
......
......@@ -51,7 +51,7 @@ void Communicator::InitAll(const std::vector<int>& gpus) {
for (size_t i = 0; i < gpus.size(); ++i) {
(*comm_id_map)[gpus[i]] = i;
}
PADDLE_ENFORCE(
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
inited = true;
}
......
......@@ -307,7 +307,7 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInference, "Bias");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NCEGradOpNoNeedBufferVarInferer, "Bias");
} // namespace operators
} // namespace paddle
......@@ -317,7 +317,7 @@ REGISTER_OPERATOR(nce, ops::NCEOp, ops::NCEOpMaker,
ops::NCEGradOpMaker<paddle::framework::OpDesc>,
ops::NCEGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(nce_grad, ops::NCEOpGrad, ops::NCEOpGradVarTypeInference,
ops::NCEGradOpNoNeedBufferVarInference);
ops::NCEGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
ops::NCEKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(nce_grad,
......
......@@ -656,7 +656,7 @@ class Pad2dOpGradMaker : public framework::SingleGradOpMaker<T> {
};
// TODO(zjl): Paddings can also be skipped!
DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad2dOpGradNoNeedBufferVarsInference, "X");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(Pad2dOpGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -667,7 +667,7 @@ REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
ops::Pad2dOpGradMaker<paddle::framework::OpDesc>,
ops::Pad2dOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
ops::Pad2dOpGradNoNeedBufferVarsInference);
ops::Pad2dOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>,
ops::Pad2dCPUKernel<double>, ops::Pad2dCPUKernel<int>,
ops::Pad2dCPUKernel<int64_t>);
......
......@@ -316,7 +316,7 @@ class MaxPoolWithIndexGradOpMaker : public framework::SingleGradOpMaker<T> {
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
MaxPoolWithIndexOpGradNoNeedBufferVarsInference, "X");
MaxPoolWithIndexOpGradNoNeedBufferVarsInferer, "X");
} // namespace operators
} // namespace paddle
......@@ -328,7 +328,7 @@ REGISTER_OPERATOR(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(max_pool2d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index,
......@@ -347,7 +347,7 @@ REGISTER_OPERATOR(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPoolWithIndexGradOpMaker<paddle::framework::OpDesc>,
ops::MaxPoolWithIndexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(max_pool3d_with_index_grad, ops::MaxPoolWithIndexOpGrad,
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInference);
ops::MaxPoolWithIndexOpGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index,
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_extended_sparse_op.h"
namespace paddle {
namespace operators {
class PullBoxExtendedSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(
ctx->Inputs("Ids").size(), 1UL,
platform::errors::InvalidArgument(
"Inputs(Ids) of PullBoxExtendedSparseOp should not be empty."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of PullBoxExtendedSparseOp should not be empty."));
PADDLE_ENFORCE_GE(ctx->Outputs("OutExtend").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(OutExtend) of PullBoxExtendedSparseOp "
"should not be empty."));
auto emb_size = static_cast<int64_t>(ctx->Attrs().Get<int>("emb_size"));
auto emb_extended_size =
static_cast<int64_t>(ctx->Attrs().Get<int>("emb_extended_size"));
auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size();
std::vector<framework::DDim> outs_dims;
std::vector<framework::DDim> outs_extended_dims;
outs_dims.resize(n_ids);
outs_extended_dims.resize(n_ids);
for (size_t i = 0; i < n_ids; ++i) {
const auto ids_dims = all_ids_dim[i];
int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
platform::errors::InvalidArgument(
"Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1.",
i));
auto out_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_dim.push_back(emb_size);
outs_dims[i] = framework::make_ddim(out_dim);
auto out_extended_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_extended_dim.push_back(emb_extended_size);
outs_extended_dims[i] = framework::make_ddim(out_extended_dim);
}
ctx->SetOutputsDim("Out", outs_dims);
ctx->SetOutputsDim("OutExtend", outs_extended_dims);
for (size_t i = 0; i < n_ids; ++i) {
ctx->ShareLoD("Ids", "Out", i, i);
ctx->ShareLoD("Ids", "OutExtend", i, i);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.device_context());
}
};
class PullBoxExtendedSparseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"Input tensors with type int32 or int64 "
"contains the ids to be looked up in BoxPS. "
"The last dimension size must be 1.")
.AsDuplicable();
AddOutput("Out", "The lookup results tensors.").AsDuplicable();
AddOutput("OutExtend", "The lookup extended results tensors.")
.AsDuplicable();
AddAttr<int>("emb_size", "(int, the embedding hidden size").SetDefault(1);
AddAttr<int>("emb_extended_size",
"(int, the extended_embedding hidden size")
.SetDefault(128);
AddComment(R"DOC(
Pull Box Extended Sparse Operator.
This operator is used to perform lookups on the BoxPS,
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
};
template <typename T>
class PushBoxExtendedSparseOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("push_box_extended_sparse");
op->SetInput("Ids", this->Input("Ids"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput(framework::GradVarName("OutExtend"),
this->OutputGrad("OutExtend"));
op->SetOutput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetAttrMap(this->Attrs());
}
};
class PushBoxExtendedSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
pull_box_extended_sparse, ops::PullBoxExtendedSparseOp,
ops::PullBoxExtendedSparseOpMaker,
ops::PushBoxExtendedSparseOpMaker<paddle::framework::OpDesc>,
ops::PushBoxExtendedSparseOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(push_box_extended_sparse, ops::PushBoxExtendedSparseOp);
REGISTER_OP_CPU_KERNEL(pull_box_extended_sparse,
ops::PullBoxExtendedSparseCPUKernel<float>,
ops::PullBoxExtendedSparseCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(push_box_extended_sparse,
ops::PushBoxExtendedSparseCPUKernel<float>,
ops::PushBoxExtendedSparseCPUKernel<double>);
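For reference, the InferShape logic of the new operator above replaces the trailing 1 of each Ids shape with the embedding size (and with the extended size for OutExtend). A worked sketch under that reading; the helper name is hypothetical:
#include <cassert>
#include <vector>
std::vector<int64_t> OutDim(std::vector<int64_t> ids_dims, int64_t emb_size) {
  assert(ids_dims.back() == 1);  // enforced by the op's InferShape
  ids_dims.back() = emb_size;    // [batch, 1] -> [batch, emb_size]
  return ids_dims;
}
int main() {
  auto out = OutDim({32, 1}, 8);    // emb_size = 8
  assert(out[0] == 32 && out[1] == 8);
  auto ext = OutDim({32, 1}, 128);  // emb_extended_size default
  assert(ext[1] == 128);
}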
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/pull_box_extended_sparse_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
template <typename T>
class PullBoxExtendedSparseCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PullBoxExtendedSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushBoxExtendedSparseCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PushBoxExtendedSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(pull_box_extended_sparse,
ops::PullBoxExtendedSparseCUDAKernel<float>,
ops::PullBoxExtendedSparseCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(push_box_extended_sparse,
ops::PushBoxExtendedSparseCUDAKernel<float>,
ops::PushBoxExtendedSparseCUDAKernel<double>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace operators {
template <typename T>
static void PullBoxExtendedSparseFunctor(
const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("Ids");
auto outputs = ctx.MultiOutput<framework::Tensor>("Out");
auto outputs_extend = ctx.MultiOutput<framework::Tensor>("OutExtend");
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
// BoxPS only supports float now
std::vector<float *> all_values(slot_size * 2);
std::vector<int64_t> slot_lengths(slot_size);
for (size_t i = 0; i < slot_size; i++) {
const auto *slot = inputs[i];
const uint64_t *single_slot_keys =
reinterpret_cast<const uint64_t *>(slot->data<int64_t>());
all_keys[i] = single_slot_keys;
slot_lengths[i] = slot->numel();
auto *output = outputs[i]->mutable_data<T>(ctx.GetPlace());
auto *output_extend = outputs_extend[i]->mutable_data<T>(ctx.GetPlace());
all_values[i] = reinterpret_cast<float *>(output);
all_values[i + slot_size] = reinterpret_cast<float *>(output_extend);
}
#ifdef PADDLE_WITH_BOX_PS
auto emb_size = ctx.Attr<int>("emb_size");
auto emb_extended_size = ctx.Attr<int>("emb_extended_size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
emb_size, emb_extended_size);
#endif
}
template <typename T>
static void PushBoxExtendedSparseFunctor(
const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::LoDTensor>("Ids");
auto d_output =
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"));
auto d_output_extend =
ctx.MultiInput<framework::Tensor>(framework::GradVarName("OutExtend"));
const auto slot_size = inputs.size();
std::vector<const uint64_t *> all_keys(slot_size);
std::vector<const float *> all_grad_values(slot_size * 2);
std::vector<int64_t> slot_lengths(slot_size);
int batch_size = -1;
for (size_t i = 0; i < slot_size; i++) {
const auto *slot = inputs[i];
const uint64_t *single_slot_keys =
reinterpret_cast<const uint64_t *>(slot->data<int64_t>());
all_keys[i] = single_slot_keys;
slot_lengths[i] = slot->numel();
int cur_batch_size =
slot->lod().size() ? slot->lod()[0].size() - 1 : slot->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
PADDLE_ENFORCE_EQ(batch_size, cur_batch_size,
platform::errors::PreconditionNotMet(
"The batch size of all input slots should be same,"
"please cheack"));
}
const float *grad_value = d_output[i]->data<float>();
const float *grad_value_extend = d_output_extend[i]->data<float>();
all_grad_values[i] = reinterpret_cast<const float *>(grad_value);
all_grad_values[i + slot_size] =
reinterpret_cast<const float *>(grad_value_extend);
}
#ifdef PADDLE_WITH_BOX_PS
auto emb_size = ctx.Attr<int>("emb_size");
auto emb_extended_size = ctx.Attr<int>("emb_extended_size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
slot_lengths, emb_size, emb_extended_size,
batch_size);
#endif
}
using LoDTensor = framework::LoDTensor;
template <typename T>
class PullBoxExtendedSparseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PullBoxExtendedSparseFunctor<T>(ctx);
}
};
template <typename T>
class PushBoxExtendedSparseCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PushBoxExtendedSparseFunctor<T>(ctx);
}
};
} // namespace operators
} // namespace paddle
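Both functors above pack base and extended embedding pointers into one flat array: slot i's base values sit at index i and its extended values at i + slot_size. A minimal sketch of that packing convention (helper name is illustrative):
#include <vector>
void PackValues(const std::vector<float*>& base,
                const std::vector<float*>& extended,
                std::vector<float*>* all_values) {
  const size_t slot_size = base.size();
  all_values->resize(slot_size * 2);
  for (size_t i = 0; i < slot_size; ++i) {
    (*all_values)[i] = base[i];                  // base embeddings first
    (*all_values)[i + slot_size] = extended[i];  // extended block follows
  }
}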
......@@ -44,7 +44,7 @@ static void PullBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto hidden_size = ctx.Attr<int>("size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PullSparse(ctx.GetPlace(), all_keys, all_values, slot_lengths,
hidden_size);
hidden_size, 0);
#endif
}
......@@ -81,7 +81,7 @@ static void PushBoxSparseFunctor(const framework::ExecutionContext &ctx) {
auto hidden_size = ctx.Attr<int>("size");
auto box_ptr = paddle::framework::BoxWrapper::GetInstance();
box_ptr->PushSparseGrad(ctx.GetPlace(), all_keys, all_grad_values,
slot_lengths, hidden_size, batch_size);
slot_lengths, hidden_size, 0, batch_size);
#endif
}
......
......@@ -56,7 +56,7 @@ The input gradients is all dense gradient tensors in a table.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInference, "Ids");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(PushDenseNoNeedBufferVarsInferer, "Ids");
} // namespace operators
} // namespace paddle
......@@ -66,5 +66,5 @@ REGISTER_OPERATOR(
push_dense, ops::PushDenseOp, ops::PushDenseOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::PushDenseNoNeedBufferVarsInference);
ops::PushDenseNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(push_dense, ops::PushDenseCPUKernel<float>)
......@@ -34,9 +34,11 @@ class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity, bool speed_test_mode = false)
: capacity_(capacity), speed_test_mode_(speed_test_mode) {
PADDLE_ENFORCE_GT(
capacity_, static_cast<size_t>(0),
"The capacity of a reader::BlockingQueue must be greater than 0.");
PADDLE_ENFORCE_GT(capacity_, static_cast<size_t>(0),
platform::errors::InvalidArgument(
"The capacity of a reader::BlockingQueue must be "
"greater than 0, but received capacity is %d.",
capacity_));
}
bool Send(const T& elem) {
......@@ -49,7 +51,10 @@ class BlockingQueue {
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
return false;
}
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
PADDLE_ENFORCE_LT(
queue_.size(), capacity_,
platform::errors::PermissionDenied(
"The queue size cannot exceed the set queue capacity."));
queue_.push_back(elem);
receive_cv_.notify_one();
return true;
......@@ -65,7 +70,10 @@ class BlockingQueue {
<< "WARNING: Sending an element to a closed reader::BlokcingQueue.";
return false;
}
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
PADDLE_ENFORCE_LT(
queue_.size(), capacity_,
platform::errors::PermissionDenied(
"The queue size cannot exceed the set queue capacity."));
queue_.emplace_back(std::move(elem));
receive_cv_.notify_one();
return true;
......@@ -77,7 +85,9 @@ class BlockingQueue {
[&] { return !queue_.empty() || closed_ || killed_; });
EnforceNotKilled();
if (!queue_.empty()) {
PADDLE_ENFORCE_NOT_NULL(elem);
PADDLE_ENFORCE_NOT_NULL(
elem, platform::errors::InvalidArgument(
"The holder to receive queue data is null pointer."));
*elem = queue_.front();
if (LIKELY(!speed_test_mode_)) {
queue_.pop_front();
......@@ -85,7 +95,10 @@ class BlockingQueue {
send_cv_.notify_one();
return true;
} else {
PADDLE_ENFORCE(closed_);
PADDLE_ENFORCE_EQ(closed_, true,
platform::errors::PermissionDenied(
"Blocking queue status error, if queue is empty "
"when pop data, it should be closed."));
VLOG(3) << "queue is closed! return nothing.";
return false;
}
......@@ -136,9 +149,9 @@ class BlockingQueue {
private:
inline void EnforceNotKilled() {
PADDLE_ENFORCE_NE(
killed_, true,
"Blocking queue is killed because the data reader raises an exception");
PADDLE_ENFORCE_NE(killed_, true, platform::errors::Fatal(
"Blocking queue is killed because the "
"data reader raises an exception."));
}
private:
......
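One behavior worth noting in the Receive path above: in speed_test_mode the front element is never popped, so a reader keeps receiving the same item. A simplified model with locking, capacity, and the closed_ check elided (a sketch, not the queue's real contract):
#include <deque>
template <typename T>
bool ReceiveModel(std::deque<T>* queue, bool speed_test_mode, T* elem) {
  if (queue->empty()) return false;  // the real code enforces closed_ here
  *elem = queue->front();
  if (!speed_test_mode) queue->pop_front();  // kept in place for speed tests
  return true;
}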
......@@ -62,7 +62,6 @@ BufferedReader::BufferedReader(
}
void BufferedReader::ReadTillBufferFullAsync() {
PADDLE_ENFORCE_EQ(position_.size(), 0U);
for (size_t i = 0; i < buffer_size_; ++i) {
ReadAsync(i);
}
......@@ -87,8 +86,10 @@ void BufferedReader::ReadAsync(size_t i) {
if (gpu.empty()) {
gpu.resize(cpu.size());
} else {
PADDLE_ENFORCE_EQ(gpu.size(), cpu.size(),
"Input tensor number not matched");
PADDLE_ENFORCE_EQ(
gpu.size(), cpu.size(),
platform::errors::InvalidArgument(
"Input tensor number on GPU and CPU devices are not matched."));
}
std::vector<void *> gpu_ptrs;
......
......@@ -36,8 +36,9 @@ class CreateCTRReaderOp : public framework::OperatorBase {
auto* queue_holder_var = scope.FindVar(queue_name);
PADDLE_ENFORCE_NOT_NULL(
queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name);
platform::errors::PreconditionNotMet(
"No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name));
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
......
......@@ -96,11 +96,14 @@ class CreateCustomReaderOpMaker : public DecoratedReaderMakerBase {
class CustomReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'CustomReaderInferShape' should only be invoked during "
"compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
PADDLE_ENFORCE_NE(
ctx->IsRuntime(), true,
platform::errors::PreconditionNotMet(
"'CustomReaderInferShape' should only be invoked during "
"compile time."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"The output decorated reader should not be null."));
const auto* sub_block =
ctx->Attrs().Get<framework::BlockDesc*>("sub_block");
const auto sink_var_names =
......@@ -109,7 +112,9 @@ class CustomReaderInferShape : public framework::InferShapeBase {
std::vector<int32_t> res_lod_levels;
for (const std::string& var_name : sink_var_names) {
auto* sink_var = sub_block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(sink_var);
PADDLE_ENFORCE_NOT_NULL(
sink_var, platform::errors::NotFound(
"The sink variable is not found in CustomReader."));
res_dims.emplace_back(sink_var->GetShape());
res_lod_levels.push_back(sink_var->GetLoDLevel());
}
......@@ -124,7 +129,9 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const override {
auto& out_var_name = ctx->Output("Out")[0];
PADDLE_ENFORCE(ctx->HasVar(out_var_name));
PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true,
platform::errors::NotFound(
"The output reader variable should not be null."));
ctx->SetType(out_var_name, framework::proto::VarType::READER);
auto sink_var_names = BOOST_GET_CONST(std::vector<std::string>,
......@@ -134,7 +141,9 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
std::vector<framework::proto::VarType::Type> res_data_types;
for (const std::string& var_name : sink_var_names) {
framework::VarDesc* var = sub_block->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound(
"The sink variable is not found in CustomReader."));
res_data_types.emplace_back(var->GetDataType());
}
ctx->SetDataTypes(out_var_name, res_data_types);
......@@ -149,11 +158,13 @@ void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
// There is not next data.
return;
}
PADDLE_ENFORCE(source_var_names_.size() == underlying_outs.size(),
"The size of source_var_names(%d) and the size of "
"underlying_outs(%d) are not consistent. Each feeding element "
"must have its own source variable.",
source_var_names_.size(), underlying_outs.size());
PADDLE_ENFORCE_EQ(
source_var_names_.size(), underlying_outs.size(),
platform::errors::InvalidArgument(
"The size of source_var_names(%d) and the size of "
"underlying_outs(%d) are not consistent. Each feeding element "
"must have its own source variable.",
source_var_names_.size(), underlying_outs.size()));
// The scope for CustomReader's sub-block should be independent and shouldn't
// be any other computation scope's child. Otherwise, data preprocessing and
// computation cannot be concurrent.
......
......@@ -201,9 +201,10 @@ class OrderedMultiDeviceLoDTensorBlockingQueue {
class LoDTensorBlockingQueueHolder {
public:
void InitOnce(size_t capacity, bool speed_test_mode = false) {
PADDLE_ENFORCE(
queue_ == nullptr,
"LoDTensorBlockingQueueHolder::InitOnce() can only be called once");
PADDLE_ENFORCE_EQ(
queue_, nullptr,
platform::errors::AlreadyExists("LoDTensorBlockingQueueHolder::"
"InitOnce() can only be called once"));
queue_.reset(new LoDTensorBlockingQueue(capacity, speed_test_mode));
}
......
......@@ -25,7 +25,9 @@ PyReader::PyReader(
const std::vector<framework::proto::VarType::Type>& var_types,
const std::vector<bool>& need_check_feed)
: framework::FileReader(dims, var_types, need_check_feed) {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
PADDLE_ENFORCE_NOT_NULL(queue,
platform::errors::PreconditionNotMet(
"LoDTensorBlockingQueue must not be null."));
queue_ = queue;
}
......
......@@ -78,7 +78,10 @@ class ReadInferVarType : public framework::StaticGraphVarTypeInference {
std::string reader_name = Input(ctx, "Reader")[0];
auto& out_names = Output(ctx, "Out");
auto dtypes = GetDataTypes(ctx, reader_name);
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size(),
platform::errors::InvalidArgument(
"The number of input reader's dtypes do not match "
"the output variable number."));
for (size_t i = 0; i < dtypes.size(); ++i) {
SetType(ctx, out_names[i], framework::proto::VarType::LOD_TENSOR);
SetDataType(ctx, out_names[i], dtypes[i]);
......
......@@ -62,12 +62,14 @@ void FileReaderMakerBase::Make() {
}
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(
!ctx->IsRuntime(),
"'FileReaderInferShape' should only be invoked during compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
PADDLE_ENFORCE_NE(
ctx->IsRuntime(), true,
platform::errors::PreconditionNotMet("'FileReaderInferShape' should only "
"be invoked during compile time."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("The output file reader should not be null."));
bool use_data_config = ctx->Attrs().Get<bool>("use_data_config");
if (use_data_config) {
const auto shape_concat =
......@@ -77,21 +79,26 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
ctx->SetReaderDims("Out", shapes);
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
PADDLE_ENFORCE_EQ(
lod_levels.size(), shapes.size(),
platform::errors::InvalidArgument(
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size()));
const auto dtypes = ctx->Attrs().Get<std::vector<int>>("dtypes");
PADDLE_ENFORCE_EQ(
dtypes.size(), shapes.size(),
"The number of 'dtypes'(%d) doesn't match the number of 'shapes'(%d).",
dtypes.size(), shapes.size());
platform::errors::InvalidArgument("The number of 'dtypes'(%d) doesn't "
"match the number of 'shapes'(%d).",
dtypes.size(), shapes.size()));
const auto need_check_feed =
ctx->Attrs().Get<std::vector<int>>("need_check_feed");
PADDLE_ENFORCE_EQ(need_check_feed.size(), shapes.size(),
"The number of 'need_check_feed'(%d) doesn't match the "
"number of 'shapes'(%d).",
need_check_feed.size(), shapes.size());
PADDLE_ENFORCE_EQ(
need_check_feed.size(), shapes.size(),
platform::errors::InvalidArgument(
"The number of 'need_check_feed'(%d) doesn't match the "
"number of 'shapes'(%d).",
need_check_feed.size(), shapes.size()));
framework::VarDesc* reader =
BOOST_GET(framework::VarDesc*, ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
......@@ -105,14 +112,18 @@ void FileReaderInferVarType::operator()(
void DecoratedReaderInferShape::operator()(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'DecoratedReaderInferShape' should only be invoked during "
"compile time.");
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
PADDLE_ENFORCE_NE(
ctx->IsRuntime(), true,
platform::errors::PreconditionNotMet(
"'DecoratedReaderInferShape' should only be invoked during "
"compile time."));
PADDLE_ENFORCE_EQ(ctx->HasInput("UnderlyingReader"), true,
platform::errors::NotFound(
"Input(UnderlyingReader) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"The output decorated reader should not be null."));
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
framework::VarDesc* in_reader = BOOST_GET(
......
......@@ -545,12 +545,12 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
}
};
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInToOut, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInToOut,
DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ReshapeGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInToOut, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInference,
DECLARE_INPLACE_OP_INFERER(ReshapeDoubleGradInplaceInferer, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ReshapeDoubleGradOpNoNeedBufferVarInferer,
"DOut");
} // namespace operators
......@@ -562,9 +562,9 @@ REGISTER_OPERATOR(
reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>,
ops::ReshapeOpInplaceInToOut);
ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
ops::ReshapeGradInplaceInToOut);
ops::ReshapeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
......@@ -576,14 +576,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
ops::Reshape2GradMaker<paddle::framework::OpDesc>,
ops::Reshape2GradMaker<paddle::imperative::OpBase>,
ops::ReshapeOpInplaceInToOut);
ops::ReshapeOpInplaceInferer);
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
ops::ReshapeGradInplaceInToOut);
ops::ReshapeGradInplaceInferer);
REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInToOut,
ops::ReshapeDoubleGradOpNoNeedBufferVarInference);
ops::ReshapeDoubleGradInplaceInferer,
ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
......
......@@ -104,7 +104,7 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_INPLACE_OP_INFERER(ScaleOpInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
......@@ -113,7 +113,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
ops::ScaleGradMaker<paddle::framework::OpDesc>,
ops::ScaleGradMaker<paddle::imperative::OpBase>,
ops::ScaleOpVarTypeInference, ops::ScaleOpInplace);
ops::ScaleOpVarTypeInference, ops::ScaleOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
......
......@@ -20,15 +20,23 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
template <typename T>
class ShapeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_t = ctx.Input<Tensor>("Input");
auto* in_var = ctx.InputVar("Input");
framework::DDim in_dims;
if (in_var->IsType<SelectedRows>()) {
in_dims = in_var->Get<SelectedRows>().value().dims();
} else {
in_dims = in_var->Get<LoDTensor>().dims();
}
auto* out_t = ctx.Output<Tensor>("Out");
out_t->Resize({in_dims.size()});
auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
auto in_dims = in_t->dims();
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = in_dims[i];
}
......
......@@ -287,10 +287,10 @@ class SoftmaxGradMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInference,
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyInplaceInferer,
{"Logits", "Softmax"});
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInference,
DECLARE_INPLACE_OP_INFERER(SoftmaxWithCrossEntropyGradInplaceInferer,
{"Softmax", framework::GradVarName("Logits")});
} // namespace operators
......@@ -302,10 +302,10 @@ REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker,
ops::SoftmaxGradMaker<paddle::framework::OpDesc>,
ops::SoftmaxGradMaker<paddle::imperative::OpBase>,
ops::SoftmaxWithCrossEntropyInplaceInference);
ops::SoftmaxWithCrossEntropyInplaceInferer);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad,
ops::SoftmaxWithCrossEntropyGradInplaceInference);
ops::SoftmaxWithCrossEntropyGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
ops::SoftmaxWithCrossEntropyKernel<float>,
ops::SoftmaxWithCrossEntropyKernel<double>);
......
......@@ -299,7 +299,7 @@ class SumGradOpBaseMaker : public imperative::GradOpBaseMakerBase {
}
};
DECLARE_INPLACE_OP_INFERER(SumInplace, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SumInplaceInferer, {"X", "Out"});
} // namespace operators
} // namespace paddle
......@@ -308,7 +308,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradDescMaker,
ops::SumGradOpBaseMaker, ops::SumOpVarTypeInference,
ops::SumInplace);
ops::SumInplaceInferer);
REGISTER_OP_CPU_KERNEL(
sum, ops::SumKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -40,6 +40,9 @@ namespace {
thread_local std::deque<int> block_id_stack;
// Tracking the nested event stacks.
thread_local std::deque<Event *> annotation_stack;
// Stack to store events such as PE and so on.
static std::deque<Event *> main_thread_annotation_stack{};
static std::deque<std::string> main_thread_annotation_stack_name{};
std::map<uint32_t, int32_t> system_thread_id_map;
......@@ -638,15 +641,49 @@ DeviceTracer *GetDeviceTracer() {
return tracer;
}
void SetCurAnnotation(Event *event) {
if (!annotation_stack.empty()) {
std::string SetCurAnnotation(Event *event) {
std::string ret;
if (!annotation_stack.empty() && event->role() != EventRole::kSpecial) {
event->set_parent(annotation_stack.back());
event->set_name(annotation_stack.back()->name() + "/" + event->name());
}
annotation_stack.push_back(event);
if (!main_thread_annotation_stack_name.empty() && !annotation_stack.empty() &&
main_thread_annotation_stack.back()->thread_id() !=
annotation_stack.back()->thread_id()) {
ret = main_thread_annotation_stack_name.back() + "/" + event->name();
} else {
ret = event->name();
}
if (event->role() == EventRole::kSpecial) {
std::string name = event->name();
if (!main_thread_annotation_stack_name.empty()) {
name = main_thread_annotation_stack_name.back() + "/" + event->name();
}
main_thread_annotation_stack_name.push_back(name);
main_thread_annotation_stack.push_back(event);
}
return ret;
}
void ClearCurAnnotation() { annotation_stack.pop_back(); }
void ClearCurAnnotation() {
if (!main_thread_annotation_stack_name.empty() && !annotation_stack.empty() &&
main_thread_annotation_stack.back()->thread_id() !=
annotation_stack.back()->thread_id()) {
annotation_stack.back()->set_name(main_thread_annotation_stack_name.back() +
"/" + annotation_stack.back()->name());
}
if (!main_thread_annotation_stack.empty() &&
main_thread_annotation_stack.back()->name() ==
annotation_stack.back()->name()) {
main_thread_annotation_stack_name.pop_back();
main_thread_annotation_stack.pop_back();
}
annotation_stack.pop_back();
}
Event *CurAnnotation() {
if (annotation_stack.empty()) return nullptr;
......
......@@ -137,7 +137,7 @@ class DeviceTracer {
DeviceTracer* GetDeviceTracer();
// Set a name for the cuda kernel operation being launched by the thread.
void SetCurAnnotation(Event* event);
std::string SetCurAnnotation(Event* event);
// Clear the name after the operation is done.
void ClearCurAnnotation();
// Current name of the operation being run in the thread.
......
......@@ -29,6 +29,7 @@ enum class EventRole {
kOrdinary, // only record op time with op type key
kInnerOp, // record op detail time with op type key
kUniqueOp, // record op detail time with op unique name key
kSpecial, // record events such as PE that live outside the thread-local stack
};
class Event {
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "boost/optional.hpp"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/place.h"
......@@ -592,41 +593,100 @@ template <typename T>
class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward> {
public:
PoolingMKLDNNHandler(
const std::vector<int64_t>& src_dims,
const std::vector<int64_t>& dst_dims, const std::vector<int64_t>& ksize,
const std::vector<int64_t>& strides, const std::vector<int64_t>& paddings,
const std::string& pooling_type, bool ceil_mode,
const MKLDNNMemoryFormat fmt, mkldnn::memory::data_type dt, bool is_test,
const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place,
const std::string& unique_name, bool exclude_padding)
PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
Tensor* output, const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(src_dims, dt, unique_name)) {
auto src_md = mkldnn::memory::desc(src_dims, dt, fmt);
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
*/
auto dst_md =
platform::MKLDNNMemDesc(dst_dims, dt, MKLDNNMemoryFormat::any);
platform::CreateKey(framework::vectorize(input->dims()),
framework::ToMKLDNNDataType(input->type()),
unique_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input tensor"));
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input tensor"));
const std::string pooling_type = ctx.Attr<std::string>("pooling_type");
std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
const bool global_pooling = ctx.Attr<bool>("global_pooling");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
// Only 2D pooling is supported now
PADDLE_ENFORCE_EQ(ksize.size(), 2,
platform::errors::InvalidArgument(
"ksize must be 2D, i.e. 2D pooling"));
PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true,
platform::errors::InvalidArgument(
"pooling_type must be 'max' or 'avg'"));
PADDLE_ENFORCE_EQ(input->dims().size(), 4,
platform::errors::InvalidArgument(
"Input must be a 4-D tensor, i.e. NCHW"));
const auto input_dims = input->dims();
framework::DDim data_dims =
framework::slice_ddim(input_dims, 2, input_dims.size());
if (global_pooling) {
operators::UpdateKsize(&ksize, data_dims);
}
auto mkldnn_paddings = ToMkldnnPadding(paddings);
operators::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm,
data_dims, strides, ksize);
const auto src_tz = paddle::framework::vectorize(input->dims());
const auto dst_tz = paddle::framework::vectorize(output->dims());
const auto is_test = ctx.Attr<bool>("is_test");
const auto dt = framework::ToMKLDNNDataType(input->type());
const auto fmt = input->format();
const auto exclude_padding = ctx.Attr<bool>("exclusive");
const auto src_md = mkldnn::memory::desc(src_tz, dt, fmt);
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
*/
const auto dst_md =
platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
if (ceil_mode) {
CorrectOutputSize(src_dims, dst_dims, ksize, paddings, strides,
mkldnn_paddings[1]);
auto mkldnn_paddings = ToMkldnnPadding(paddings);
const bool ceil_mode = ctx.Attr<bool>("ceil_mode");
if (ceil_mode) {
CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
mkldnn_paddings[1]);
}
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
pooling_type == "max"
? mkldnn::algorithm::pooling_max
: (exclude_padding
? mkldnn::algorithm::pooling_avg_exclude_padding
: mkldnn::algorithm::pooling_avg_include_padding),
src_md, dst_md, strides, ksize, mkldnn_paddings[0],
mkldnn_paddings[1]);
}
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
pooling_type == "max"
? mkldnn::algorithm::pooling_max
: (exclude_padding
? mkldnn::algorithm::pooling_avg_exclude_padding
: mkldnn::algorithm::pooling_avg_include_padding),
src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]);
}
PoolingMKLDNNHandler(
......@@ -1190,8 +1250,11 @@ static std::shared_ptr<mkldnn::memory> SetDstMemory(
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::vector<mkldnn::primitive>* pipeline) {
const T* residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(residual_param_data != nullptr,
"Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_NOT_NULL(
residual_param_data,
platform::errors::PreconditionNotMet("Residual parameter is required for "
"the DNNL conv+elementwise_add "
"fusion, but it is missing"));
std::shared_ptr<mkldnn::memory> user_residual_memory_p =
handler->AcquireResidualDataMemory(user_residual_md,
to_void_cast<T>(residual_param_data));
......
......@@ -73,8 +73,7 @@ RecordEvent::RecordEvent(const std::string &name, const EventRole role) {
// lock is not needed, the code below is thread-safe
Event *e = PushEvent(name, role);
// Maybe need the same push/pop behavior.
SetCurAnnotation(e);
name_ = e->name();
name_ = SetCurAnnotation(e);
}
RecordEvent::~RecordEvent() {
......@@ -86,7 +85,7 @@ RecordEvent::~RecordEvent() {
BlockDepth(), g_thread_id);
}
ClearCurAnnotation();
PopEvent(name_);
PopEvent(name_, role_);
}
void MemEvenRecorder::PushMemRecord(const void *ptr, const Place &place,
......@@ -187,8 +186,8 @@ Event *PushEvent(const std::string &name, const EventRole role) {
return GetEventList().Record(EventType::kPushRange, name, g_thread_id, role);
}
void PopEvent(const std::string &name) {
GetEventList().Record(EventType::kPopRange, name, g_thread_id);
void PopEvent(const std::string &name, const EventRole role) {
GetEventList().Record(EventType::kPopRange, name, g_thread_id, role);
}
void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE_NE(state, ProfilerState::kDisabled,
......
......@@ -197,7 +197,7 @@ void PushMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
void PopMemEvent(uint64_t start_ns, uint64_t end_ns, size_t bytes,
const Place& place, const std::string& annotation);
Event* PushEvent(const std::string& name, const EventRole role);
void PopEvent(const std::string& name);
void PopEvent(const std::string& name, const EventRole role);
// Return the event list of all threads. Assumed the returned value calls
// event_lists, event_lists[i][j] represents the j-th Event of i-th thread.
std::vector<std::vector<Event>> GetAllEvents();
......
......@@ -22,12 +22,12 @@ limitations under the License. */
#include <memory>
#include <mutex> // NOLINT
#include <random>
#include <set>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif // PADDLE_WITH_CUDA
......@@ -283,7 +283,8 @@ std::function<bool(const EventItem &, const EventItem &)> SetSortedFunc(
void SetEvent(bool merge_thread, const Event &analyze_event,
size_t *max_name_width, std::list<Event> *pushed_events,
std::vector<EventItem> *event_items,
std::unordered_map<std::string, int> *event_idx) {
std::unordered_map<std::string, int> *event_idx,
const std::set<std::string> &main_thread_event_name) {
if (analyze_event.type() == EventType::kPushRange) {
pushed_events->push_back(analyze_event);
} else if (analyze_event.type() == EventType::kPopRange) {
......@@ -313,8 +314,35 @@ void SetEvent(bool merge_thread, const Event &analyze_event,
if (merge_thread) {
event_name = rit->name();
} else {
event_name =
"thread" + std::to_string(rit->thread_id()) + "::" + rit->name();
if (!main_thread_event_name.empty()) {
auto origin_name = rit->name();
int index = 1;
int split_pos = 0;
while ((split_pos = FindNthReversePos(origin_name, '/', index)) !=
-1) {
auto prefix_str = origin_name.substr(0, split_pos);
if (main_thread_event_name.count(prefix_str)) {
break;
}
index++;
}
if (split_pos == -1 && !main_thread_event_name.count(rit->name())) {
event_name = "thread" + std::to_string(rit->thread_id()) + "::" +
rit->name();
} else {
if (!main_thread_event_name.count(rit->name())) {
event_name =
origin_name.substr(0, split_pos + 1) + "thread" +
std::to_string(rit->thread_id()) + "::" +
origin_name.substr(split_pos + 1, origin_name.length() - 1);
} else {
event_name = rit->name();
}
}
} else {
event_name =
"thread" + std::to_string(rit->thread_id()) + "::" + rit->name();
}
}
auto print_name_size = event_name.size();
int found_pos = 0;
......@@ -608,6 +636,16 @@ void AnalyzeEvent(
std::function<bool(const EventItem &, const EventItem &)> sorted_func,
EventSortingKey sorted_by, size_t *max_name_width, OverHead *overhead,
bool merge_thread) {
// In order to deal with special events in the main thread
std::set<std::string> main_thread_event_name;
for (size_t i = 0; i < (*analyze_events).size(); i++) {
for (size_t j = 0; j < (*analyze_events)[i].size(); j++) {
Event event = (*analyze_events)[i][j];
if (event.role() == EventRole::kSpecial) {
main_thread_event_name.insert(event.name());
}
}
}
for (size_t i = 0; i < (*analyze_events).size(); i++) {
double total = 0.; // the total time in one thread
std::list<Event> pushed_events;
......@@ -618,8 +656,10 @@ void AnalyzeEvent(
for (size_t j = 0; j < (*analyze_events)[i].size(); j++) {
Event analyze_event = (*analyze_events)[i][j];
SetEvent(merge_thread, analyze_event, max_name_width, &pushed_events,
&event_items, &event_idx);
if (!(analyze_event.role() == EventRole::kSpecial && !merge_thread)) {
SetEvent(merge_thread, analyze_event, max_name_width, &pushed_events,
&event_items, &event_idx, main_thread_event_name);
}
}
auto table_size = event_items.size();
......
......@@ -59,7 +59,7 @@ TEST(RecordEvent, RecordEvent) {
PushEvent(name, EventRole::kOrdinary);
int counter = 1;
while (counter != i * 1000) counter++;
PopEvent(name);
PopEvent(name, EventRole::kOrdinary);
}
}
......@@ -109,7 +109,7 @@ TEST(RecordEvent, RecordEvent) {
// Bad Usage:
PushEvent("event_without_pop", EventRole::kOrdinary);
PopEvent("event_without_push");
PopEvent("event_without_push", EventRole::kOrdinary);
std::vector<std::vector<Event>> events = paddle::platform::GetAllEvents();
int cuda_startup_count = 0;
......
......@@ -54,6 +54,8 @@ void BindBoxHelper(py::module* m) {
.def("preload_into_memory", &framework::BoxHelper::PreLoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("load_into_memory", &framework::BoxHelper::LoadIntoMemory,
py::call_guard<py::gil_scoped_release>())
.def("slots_shuffle", &framework::BoxHelper::SlotsShuffle,
py::call_guard<py::gil_scoped_release>());
} // end BoxHelper
......@@ -61,9 +63,9 @@ void BindBoxHelper(py::module* m) {
void BindBoxWrapper(py::module* m) {
py::class_<framework::BoxWrapper, std::shared_ptr<framework::BoxWrapper>>(
*m, "BoxWrapper")
.def(py::init([]() {
.def(py::init([](int embedx_dim, int expand_embed_dim) {
// return std::make_shared<paddle::framework::BoxHelper>(dataset);
return framework::BoxWrapper::GetInstance();
return framework::BoxWrapper::SetInstance(embedx_dim, expand_embed_dim);
}))
.def("save_base", &framework::BoxWrapper::SaveBase,
py::call_guard<py::gil_scoped_release>())
......@@ -76,13 +78,15 @@ void BindBoxWrapper(py::module* m) {
.def("initialize_gpu_and_load_model",
&framework::BoxWrapper::InitializeGPUAndLoadModel,
py::call_guard<py::gil_scoped_release>())
.def("initialize_auc_runner", &framework::BoxWrapper::InitializeAucRunner,
py::call_guard<py::gil_scoped_release>())
.def("init_metric", &framework::BoxWrapper::InitMetric,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_msg", &framework::BoxWrapper::GetMetricMsg,
py::call_guard<py::gil_scoped_release>())
.def("get_metric_name_list", &framework::BoxWrapper::GetMetricNameList,
py::call_guard<py::gil_scoped_release>())
.def("flip_pass_flag", &framework::BoxWrapper::FlipPassFlag,
.def("flip_phase", &framework::BoxWrapper::FlipPhase,
py::call_guard<py::gil_scoped_release>())
.def("init_afs_api", &framework::BoxWrapper::InitAfsAPI,
py::call_guard<py::gil_scoped_release>())
......
......@@ -291,6 +291,8 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("set_fleet_send_sleep_seconds",
&framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>());
py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
......
......@@ -116,7 +116,7 @@ python setup.py install
"""
self.cuda100 = r"""
- cudatoolkit>=10.0, <10.1
- cudnn>=7.3, <7.4
- cudnn>=7.6, <7.7
"""
self.cuda_info = [(self.cuda90, "cuda9.0", ".post97"),
(self.cuda100, "cuda10.0", ".post107")]
......
......@@ -59,9 +59,9 @@ function init() {
}
function cmake_base() {
# build script will not fail if *.deb does not exist
# Build script will not fail if *.deb does not exist
rm *.deb 2>/dev/null || true
# delete previous built whl packages
# Delete previous built whl packages
rm -rf python/dist 2>/dev/null || true
# Support build for all python versions, currently
......@@ -199,9 +199,7 @@ function cmake_base() {
-DWITH_DISTRIBUTE=${distibuted_flag}
-DWITH_MKL=${WITH_MKL:-ON}
-DWITH_AVX=${WITH_AVX:-OFF}
-DWITH_GOLANG=${WITH_GOLANG:-OFF}
-DCUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All}
-DCUDA_ARCH_BIN=${CUDA_ARCH_BIN}
-DWITH_PYTHON=${WITH_PYTHON:-ON}
-DCUDNN_ROOT=/usr/
-DWITH_TESTING=${WITH_TESTING:-ON}
......@@ -231,9 +229,7 @@ EOF
-DWITH_MKL=${WITH_MKL:-ON} \
-DWITH_AVX=${WITH_AVX:-OFF} \
-DNOAVX_CORE_FILE=${NOAVX_CORE_FILE:-""} \
-DWITH_GOLANG=${WITH_GOLANG:-OFF} \
-DCUDA_ARCH_NAME=${CUDA_ARCH_NAME:-All} \
-DCUDA_ARCH_BIN=${CUDA_ARCH_BIN} \
-DWITH_PYTHON=${WITH_PYTHON:-ON} \
-DCUDNN_ROOT=/usr/ \
-DWITH_TESTING=${WITH_TESTING:-ON} \
......@@ -1080,7 +1076,7 @@ EOF
if [[ "$1" != "" ]]; then
parallel_number=$1
fi
cmake .. -DWITH_DISTRIBUTE=OFF -DON_INFER=ON -DCUDA_ARCH_NAME=${CUDA_ARCH_NAME:-Auto} -DCUDA_ARCH_BIN=${CUDA_ARCH_BIN}
cmake .. -DWITH_DISTRIBUTE=OFF -DON_INFER=ON -DCUDA_ARCH_NAME=${CUDA_ARCH_NAME:-Auto}
make -j ${parallel_number} fluid_lib_dist
make -j ${parallel_number} inference_lib_dist
......
......@@ -34,7 +34,8 @@ __all__ = [
'fused_elemwise_activation', 'sequence_topk_avg_pooling', 'var_conv_2d',
'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc'
'partial_sum', 'tdm_child', 'rank_attention', 'tdm_sampler', 'batch_fc',
'_pull_box_extended_sparse'
]
......@@ -1361,3 +1362,50 @@ def batch_fc(input, param_size, param_attr, bias_size, bias_attr, act=None):
"Bias": b},
outputs={"Out": pre_act})
return helper.append_activation(pre_act)
def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'):
"""
**Pull Box Extended Sparse Layer**
This layer is used to look up embeddings of IDs, provided by :attr:`input`, in
the BoxPS lookup table. The result of this lookup is the embedding of each ID
in :attr:`input`.
Args:
input(Variable|list of Variable): Input is a Tensor<int64> Variable, which
contains the IDs information.
size(int): The embedding size, i.e. the size of each embedding vector.
extend_size(int): The embedding size in the extended dimension, i.e. the
size of each extended embedding vector.
dtype(str): The dtype refers to the data type of output tensor. Only supports
float32 now.
Returns:
Variable|list of Variable: The tensor variable storing the embeddings of the \
supplied inputs.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1)
emb, emb_ex = fluid.contrib.layers._pull_box_extended_sparse(input=data, size=8, extend_size=128)
"""
helper = LayerHelper('pull_box_extended_sparse', **locals())
helper.input_dtype()
inputs = helper.multiple_input()
outs = [
helper.create_variable_for_type_inference(dtype)
for i in range(len(inputs))
]
outs_extend = [
helper.create_variable_for_type_inference(dtype)
for i in range(len(inputs))
]
helper.append_op(
type='pull_box_extended_sparse',
inputs={'Ids': inputs},
outputs={'Out': outs,
'OutExtend': outs_extend},
attrs={'emb_size': size,
'emb_extended_size': extend_size})
if len(outs) == 1:
return outs[0], outs_extend[0]
return outs, outs_extend
......@@ -43,7 +43,7 @@ _fake_quant_dequant_op_list = [
_out_scale_op_list = [
"conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
"relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
"elementwise_add", "pool2d", "reshape2", "transpose2"
"elementwise_add", "pool2d", "reshape2", "transpose2", "concat"
]
# list op real input and output names, to avoid processing input such as AxisTensor.
......@@ -1156,14 +1156,13 @@ class OutScaleForTrainingPass(object):
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._is_test = graph.is_test()
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in self._teller_set:
if len(op_node.output_arg_names()) != 1:
continue
in_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0])
target_ops = []
for op in graph.all_op_nodes():
if op.name() in self._teller_set:
target_ops.append(op)
for op in target_ops:
for output_var_name in _get_op_output_var_names(op):
in_node = graph._find_node_by_name(op.outputs, output_var_name)
out_node = graph.create_var_node_from_desc(in_node.var())
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
......@@ -1263,13 +1262,13 @@ class OutScaleForInferencePass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in self._teller_set:
if len(op_node.output_arg_names()) != 1:
continue
scale_name = self._scale_name(op_node.output_arg_names()[0])
op_nodes = graph.all_op_nodes()
for op_node in op_nodes:
if op_node.name() in self._teller_set:
output_var_name = _get_op_output_var_names(op_node)
assert len(output_var_name) == 1, "Only ops with a single " \
"activation output are supported for now."
scale_name = self._scale_name(output_var_name[0])
scale_v = np.array(
self._scope.find_var(scale_name).get_tensor())[0]
op_node.op()._set_attr("out_threshold", float(scale_v))
......
......@@ -1079,3 +1079,24 @@ class BoxPSDataset(InMemoryDataset):
def _dynamic_adjust_after_train(self):
pass
def slots_shuffle(self, slots):
"""
Slots Shuffle
Slots Shuffle is a shuffle method at the slot level, usually used for sparse
features with a large number of instances. Shuffling one or several slots and
comparing a metric such as AUC against the baseline makes it possible to
evaluate the importance of those slots (features).
Args:
slots(list[string]): the set of slot names (strings) to shuffle.
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_merge_by_lineid()
# suppose there is a slot 0
dataset.slots_shuffle(['0'])
"""
slots_set = set(slots)
self.boxps.slots_shuffle(slots_set)
......@@ -32,12 +32,23 @@ class CallTransformer(gast.NodeTransformer):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def _is_builtin_call(self, node):
def _no_need_convert_call(self, node):
"""
Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when the function satisfies one of the following conditions:
1. It is a Paddle API.
2. It is a Python builtin function other than `len`.
"""
assert isinstance(node, gast.Call)
if is_paddle_api(node):
return True
func_str = ast_to_source_code(node.func).strip()
try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
return eval("is_builtin({})".format(func_str))
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin
is_builtin = eval("is_builtin({})".format(func_str))
is_builtin_len = eval("is_builtin_len({})".format(func_str))
return is_builtin and not is_builtin_len
except Exception:
return False
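As a hedged illustration of the rule above (expected outcomes only; not part of the diff):

# Assuming each call below is the `node` under inspection,
# _no_need_convert_call should return:
#   fluid.layers.relu(x) -> True   (a Paddle API: left untouched)
#   print(x)             -> True   (a builtin other than len: left untouched)
#   len(x)               -> False  (still converted, so it can map to convert_len)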
......@@ -46,10 +57,8 @@ class CallTransformer(gast.NodeTransformer):
def visit_Call(self, node):
self.generic_visit(node)
if is_paddle_api(node):
return node
if self._is_builtin_call(node):
if self._no_need_convert_call(node):
return node
func_str = ast_to_source_code(node.func).strip()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid.layers import nn
from paddle.fluid.layers import control_flow
def convert_len(var):
"""
Returns a variable holding the length, built from shape ops according to var.type.
Note: besides the AST transformation itself, the `len` transformation also adds
block-related operations, such as appending a `shape_op` to var.block.
"""
if isinstance(var, framework.Variable):
if var.type in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS
]:
# Note: Length of var may be known ahead of time in dygraph,
# but it probably represents batch size which can be variant.
# so we return a variable dynamically inferred from var.shape.
return nn.shape(var)[0]
elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return control_flow.array_length(var)
else:
raise TypeError(
'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
% type(var))
else:
return len(var)
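A minimal usage sketch of `convert_len` (illustrative, not part of the diff; the variable names are assumptions): for a LoDTensor variable it yields a shape-derived variable rather than a Python int, so the length can follow a variant batch size at runtime.

import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.convert_builtins_func import convert_len

x = fluid.data(name='x', shape=[-1, 16], dtype='float32')
x_len = convert_len(x)  # a variable computed from shape(x)[0], not a Python int
plain_len = convert_len([1, 2])  # non-Variable input falls back to builtin len -> 2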
......@@ -29,6 +29,7 @@ import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.convert_builtins_func import convert_len
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator()
......@@ -49,6 +50,12 @@ def is_builtin(func):
return False
def is_builtin_len(func):
if isinstance(func, types.BuiltinFunctionType) and func.__name__ == 'len':
return True
return False
def is_paddle_func(func):
m = inspect.getmodule(func)
return m is not None and m.__name__.startswith("paddle")
......@@ -91,10 +98,10 @@ def convert_call(func):
func_self = None
converted_call = None
if is_builtin(func):
return func
if is_builtin_len(func):
return convert_len
if is_paddle_func(func):
if is_builtin(func) or is_paddle_func(func):
return func
if inspect.isfunction(func):
......
......@@ -166,13 +166,19 @@ class NameVisitor(gast.NodeVisitor):
in_loop_vars = self.in_loop_vars[node]
in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
before_loop_body_vars = self.before_loop_body_vars[node]
before_loop_body_vars = self._remove_target_vars_of_for(
before_loop_body_vars, node)
before_loop_name_strs = self._var_nodes_to_names(before_loop_body_vars)
after_loop_vars = self.current_seen_vars - before_loop_body_vars - in_loop_vars
after_loop_vars = self._remove_target_vars_of_for(after_loop_vars, node)
after_loop_name_strs = self._var_nodes_to_names(after_loop_vars,
read_context)
condition_vars = self.condition_vars[node]
condition_names = self._var_nodes_to_names(condition_vars)
write_vars = self.write_in_loop[node]
write_names = self._var_nodes_to_names(write_vars)
......@@ -203,6 +209,7 @@ class NameVisitor(gast.NodeVisitor):
# vars out
loop_var_names.add(name)
create_var_names.add(name)
return loop_var_names, create_var_names
def visit_Name(self, node):
......@@ -221,8 +228,8 @@ class NameVisitor(gast.NodeVisitor):
self.in_loop_vars[loop_node].add(node)
if type(node.ctx) in write_context:
self.write_in_loop[loop_node].add(node)
if self.in_condition:
self.condition_vars[loop_node].add(node)
if self.in_condition:
self.condition_vars[loop_node].add(node)
self.generic_visit(node)
def visit_FunctionDef(self, node):
......@@ -309,11 +316,60 @@ class NameVisitor(gast.NodeVisitor):
return False
def _is_call_func_name_node(self, node):
parent_node = self.node_to_wrapper_map[node].parent.node
parent_node = self._get_parent_node(node)
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
return False
def _get_parent_node(self, node):
wrapper_node = self.node_to_wrapper_map.get(node)
if wrapper_node:
parent_node = wrapper_node.parent.node
return parent_node
return None
def _remove_target_vars_of_for(self, before_or_after_loop_vars, loop_node):
"""
Remove target vars of gast.For from before_loop_vars or after_loop_vars.
:param before_or_after_loop_vars: before_loop_vars or after_loop_vars of loop_node.
:param loop_node: Current loop node.
"""
removed_vars = set()
for name_node in before_or_after_loop_vars:
if not isinstance(name_node, gast.Name):
continue
parent_node = self._get_parent_node(name_node)
# NOTE: gast.For.target can be gast.Tuple.
# For example: `for i, j in enumerate(x)` has two target vars: i and j
if isinstance(parent_node, gast.Tuple):
parent_node = self._get_parent_node(parent_node)
if isinstance(parent_node,
gast.For) and parent_node is not loop_node:
target_node = parent_node.target
if isinstance(target_node, gast.Tuple):
target_vars = target_node.elts
else:
target_vars = [target_node]
if name_node in target_vars:
removed_vars.add(name_node)
removed_vars_name_strs = {var.id for var in removed_vars}
for var in before_or_after_loop_vars:
if not isinstance(var, gast.Name):
continue
if var.id in removed_vars_name_strs and var not in self.condition_vars[
loop_node]:
removed_vars.add(var)
return before_or_after_loop_vars - removed_vars
class LoopTransformer(gast.NodeTransformer):
"""
......
......@@ -771,14 +771,19 @@ class Pool2D(layers.Layer):
ceil_mode (bool, optional): Whether to use the ceil function to calculate output height and width.
False is the default. If it is set to False, the floor function will be used. Default: False.
exclusive (bool, optional): Whether to exclude padding points in average pooling mode. Default: True.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
``[batch_size, input_channels, input_height, input_width]``. When it is `"NHWC"`, the data is
stored in the order of: ``[batch_size, input_height, input_width, input_channels]``
Returns:
None
Raises:
ValueError: If 'pool_type' is not "max" nor "avg"
ValueError: If 'global_pooling' is False and 'pool_size' is -1
ValueError: If 'use_cudnn' is not a bool value.
ValueError: If ``pool_type`` is not "max" nor "avg".
ValueError: If ``global_pooling`` is False and ``pool_size`` is -1.
ValueError: If ``use_cudnn`` is not a bool value.
ValueError: If ``data_format`` is not "NCHW" nor "NHWC".
Examples:
......@@ -806,7 +811,10 @@ class Pool2D(layers.Layer):
global_pooling=False,
use_cudnn=True,
ceil_mode=False,
exclusive=True):
exclusive=True,
data_format="NCHW"):
data_format = data_format.upper()  # support NHWC, nhwc, etc.
pool_type = pool_type.lower()  # support max, Max, etc.
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
......@@ -820,6 +828,11 @@ class Pool2D(layers.Layer):
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
if data_format not in ["NCHW", "NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCHW' or 'NHWC'. Received "
"Attr(data_format): %s." % str(data_format))
super(Pool2D, self).__init__()
self._pool_type = pool_type
......@@ -831,6 +844,7 @@ class Pool2D(layers.Layer):
self._use_cudnn = use_cudnn
self._ceil_mode = ceil_mode
self._exclusive = exclusive
self._data_format = data_format
self._l_type = 'pool2d'
def forward(self, input):
......@@ -839,7 +853,8 @@ class Pool2D(layers.Layer):
'global_pooling', self._global_pooling, 'strides',
self._pool_stride, 'paddings', self._pool_padding,
'use_cudnn', self._use_cudnn, 'ceil_mode', self._ceil_mode,
'use_mkldnn', False, 'exclusive', self._exclusive)
'use_mkldnn', False, 'exclusive', self._exclusive,
'data_format', self._data_format)
return core.ops.pool2d(input, *attrs)
check_variable_and_dtype(
......@@ -856,6 +871,7 @@ class Pool2D(layers.Layer):
"ceil_mode": self._ceil_mode,
"use_mkldnn": False,
"exclusive": self._exclusive,
"data_format": self._data_format,
}
inputs = {"X": [input]}
......
......@@ -1536,9 +1536,11 @@ def teacher_student_sigmoid_loss(input,
cost = fluid.layers.teacher_student_sigmoid_loss(input=similarity, label=label)
"""
check_variable_and_dtype(input, "input", ['float32', 'float64'],
check_variable_and_dtype(input, "input",
['float32', 'float64', 'int32', 'int64'],
'teacher_student_sigmoid_loss')
check_variable_and_dtype(label, "label", ['float32', 'float64'],
check_variable_and_dtype(label, "label",
['float32', 'float64', 'int32', 'int64'],
'teacher_student_sigmoid_loss')
helper = LayerHelper('teacher_student_sigmoid_loss', **locals())
......
......@@ -1902,7 +1902,7 @@ def pool2d(input,
None by default.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is `true`.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
......@@ -11045,8 +11045,26 @@ def shape(input):
Get the shape of the input.
.. code-block:: text
Case1:
Given N-D Tensor:
input = [ [1, 2, 3, 4], [5, 6, 7, 8] ]
Then:
input.shape = [2, 4]
Case2:
Given SelectedRows:
input.rows = [0, 4, 19]
input.height = 20
input.value = [ [1, 2], [3, 4], [5, 6] ] # inner tensor
Then:
input.shape = [3, 2]
Args:
input (Variable): The input N-D Tensor. Datatype can be float32, float64, int32, int64.
input (Variable): The input can be an N-D Tensor or SelectedRows with data type float32, float64, int32 or int64.
If the input variable is of type SelectedRows, this returns the shape of its inner tensor.
Returns:
Variable (Tensor): The shape of the input variable.
......@@ -11057,7 +11075,7 @@ def shape(input):
import paddle.fluid as fluid
import numpy as np
inputs = fluid.layers.data(name="x", shape=[3, 100, 100], dtype="float32")
inputs = fluid.data(name="x", shape=[3, 100, 100], dtype="float32")
output = fluid.layers.shape(inputs)
exe = fluid.Executor(fluid.CPUPlace())
......
......@@ -49,10 +49,13 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
max_len = max([len(sent) for sent in batch_tokens])
mask_label = []
mask_pos = []
np.random.seed(SEED)
prob_mask = np.random.rand(total_token_num)
# NOTE: numpy's global random state is not thread-safe. With an async
# DataLoader, calling np.random.seed() directly is risky, so a private
# RandomState instance is used instead.
self_random = np.random.RandomState(SEED)
prob_mask = self_random.rand(total_token_num)
# Note: the first token is [CLS], so [low=1]
replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
replace_ids = self_random.randint(1, high=vocab_size, size=total_token_num)
pre_sent_len = 0
prob_index = 0
for sent_index, sent in enumerate(batch_tokens):
......@@ -85,7 +88,9 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
# ensure at least mask one word in a sentence
while not mask_flag:
token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
token_index = int(
self_random.randint(
1, high=len(sent) - 1, size=1))
if sent[token_index] != SEP and sent[token_index] != CLS:
mask_label.append(sent[token_index])
sent[token_index] = MASK
......@@ -244,13 +249,16 @@ class DataReader(object):
def build_fake_data(self):
for _ in range(1000000):
random.seed(SEED)
sent0_len = random.randint(50, 100)
sent1_len = random.randint(50, 100)
# NOTE: Python's random module is buggy under Python 2,
# so avoid it and use numpy.random instead.
self_random = np.random.RandomState(SEED)
sent0_len = self_random.randint(50, 100)
sent1_len = self_random.randint(50, 100)
token_ids = [1] \
+ [random.randint(0, 10000) for i in range(sent0_len-1)] \
+ [random.randint(0, 10000) for i in range(sent1_len-1)] \
+ [self_random.randint(0, 10000) for i in range(sent0_len-1)] \
+ [self_random.randint(0, 10000) for i in range(sent1_len-1)] \
+ [2]
sent_ids = [0 for i in range(sent0_len)
......
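A minimal sketch of the thread-safe pattern adopted above (the seed value and the sampling ranges are assumptions for illustration): each call owns a private numpy RandomState instead of reseeding the process-wide generator, so concurrent DataLoader workers cannot race on shared RNG state.

import numpy as np

SEED = 2020  # assumed seed, for illustration only

def draw(total_token_num, vocab_size):
    rng = np.random.RandomState(SEED)  # private state, safe across threads
    prob_mask = rng.rand(total_token_num)
    replace_ids = rng.randint(1, high=vocab_size, size=total_token_num)
    return prob_mask, replace_ids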
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph.dygraph_to_static import convert_call
SEED = 2020
np.random.seed(SEED)
def len_with_tensor(x):
x = fluid.dygraph.to_variable(x)
x_len = len(x)
return x_len
def len_with_lod_tensor_array(x):
x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
arr = fluid.layers.array_write(x, i=i)
arr_len = len(arr)
return arr_len
class TestLen(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.x_data = np.random.random([10, 16]).astype('float32')
self.init_func()
def init_func(self):
self.func = len_with_tensor
def _run(self, to_static):
with fluid.dygraph.guard(self.place):
if to_static:
out = declarative(self.func)(self.x_data)
else:
out = self.func(self.x_data)
if isinstance(out, fluid.core.VarBase):
out = out.numpy()
return out
def test_len(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
self.assertTrue(np.allclose(dygraph_res, static_res))
class TestLenWithTensorArray(TestLen):
def init_func(self):
self.func = len_with_lod_tensor_array
# Note: Variable(SelectedRows) is not exposed directly in dygraph.
# This unit test covers that code path with hand-written transformed code.
def len_with_selected_rows(place):
block = fluid.default_main_program().global_block()
# create selected_rows variable
var = block.create_var(
name="X",
dtype="float32",
persistable=True,
type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
# y is Variable(SelectedRows)
y = fluid.layers.merge_selected_rows(var)
y_len = convert_call(len)(y)
# z is inner tensor with shape [4, 2]
z = fluid.layers.get_tensor_from_selected_rows(y)
z_len = convert_call(len)(z)
# set data for selected_rows
x_rows = [0, 2, 2, 4, 19]
row_numel = 2
np_array = np.ones((len(x_rows), row_numel)).astype("float32")
x_var = fluid.global_scope().var("X").get_selected_rows()
x_var.set_rows(x_rows)
x_var.set_height(20)
x_tensor = x_var.get_tensor()
x_tensor.set(np_array, place)
exe = fluid.Executor(place=place)
result = exe.run(fluid.default_main_program(), fetch_list=[y_len, z_len])
return result
class TestLenWithSelectedRows(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def test_len(self):
selected_rows_var_len, var_tensor_len = len_with_selected_rows(
self.place)
self.assertEqual(selected_rows_var_len, var_tensor_len)
if __name__ == '__main__':
unittest.main()
......@@ -132,6 +132,19 @@ def var_create_in_for_loop(max_len):
return ret
def nested_for_loop_dyfunc():
two = fluid.layers.fill_constant(shape=[1], value=2, dtype="int32")
three = fluid.layers.fill_constant(shape=[1], value=3, dtype="int32")
for j in range(two):
for i in range(10):
a = 2
for i in range(three):
b = fluid.layers.zeros(shape=[1], dtype='float32')
return b
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [
......@@ -142,6 +155,8 @@ class TestNameVisitor(unittest.TestCase):
]
self.create_var_names = [set(), set(["ret"]), set()]
self.nested_for_loop_func = nested_for_loop_dyfunc
def test_loop_vars(self):
for i in range(len(self.loop_funcs)):
func = self.loop_funcs[i]
......@@ -155,6 +170,28 @@ class TestNameVisitor(unittest.TestCase):
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
def test_nested_loop_vars(self):
func = self.nested_for_loop_func
test_func = inspect.getsource(func)
gast_root = gast.parse(test_func)
name_visitor = NameVisitor(gast_root)
self.loop_var_names = [
set(["j", "two"]),
set(["i", "three", "b"]),
set(["i"]),
]
self.create_var_names = [set(), set(["b"]), set()]
i = 0
for node in gast.walk(gast_root):
if isinstance(node, (gast.While, gast.For)):
loop_var_names, create_var_names = name_visitor.get_loop_var_names(
node)
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])
i += 1
class TestTransformWhileLoop(unittest.TestCase):
def setUp(self):
......
......@@ -172,6 +172,7 @@ class TestBoxPSPreload(unittest.TestCase):
exe.run(fluid.default_startup_program())
datasets[0].load_into_memory()
datasets[0].begin_pass()
datasets[0].slots_shuffle([])
datasets[1].preload_into_memory()
exe.train_from_dataset(
program=fluid.default_main_program(),
......
......@@ -125,6 +125,7 @@ class TestDataset(unittest.TestCase):
dataset.set_trainer_num(4)
dataset.set_hdfs_config("my_fs_name", "my_fs_ugi")
dataset.set_download_cmd("./read_from_afs my_fs_name my_fs_ugi")
dataset.enable_pv_merge()
thread_num = dataset.get_thread_num()
self.assertEqual(thread_num, 12)
......@@ -231,7 +232,7 @@ class TestDataset(unittest.TestCase):
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.set_fea_eval(10000, True)
dataset.set_fea_eval(1, True)
dataset.slots_shuffle(["slot1"])
dataset.local_shuffle()
dataset.set_generate_unique_feasigns(True, 15)
......
......@@ -26,9 +26,9 @@ def dequantize_log(x, dict_data):
output_data_f = output_data.flatten()
for i in range(x_f.size):
if x_f[i] < 0:
output_data_f[i] = -np.power(2, dict_data[x_f[i] + 128])
output_data_f[i] = -dict_data[x_f[i] + 128]
else:
output_data_f[i] = np.power(2, dict_data[x_f[i]])
output_data_f[i] = dict_data[x_f[i]]
return output_data_f.reshape(x.shape)
......
......@@ -17,7 +17,6 @@ import paddle.fluid.core as core
import os
import unittest
import paddle.fluid.layers as layers
from paddle.fluid.layers.nn import _pull_box_sparse
class TestDataFeed(unittest.TestCase):
......@@ -57,9 +56,9 @@ class TestDataFeed(unittest.TestCase):
lod_level=0,
append_batch_size=False)
emb_x, emb_y = _pull_box_sparse([x, y], size=2)
emb_xp = _pull_box_sparse(x, size=2)
concat = layers.concat([emb_x, emb_y], axis=1)
emb_x, emb_y = fluid.contrib.layers._pull_box_extended_sparse(
[x, y], size=2, extend_size=128)
concat = layers.concat([emb_x[0], emb_x[1], emb_y[0], emb_y[1]], axis=1)
fc = layers.fc(input=concat,
name="fc",
size=1,
......
......@@ -1295,6 +1295,78 @@ class TestDygraphPool2DAPIError(unittest.TestCase):
name='x1', shape=[3, 32, 32, 5], dtype="int32")
self.assertRaises(TypeError, pool2d, data2)
def test_data_format_error(self):
with program_guard(Program(), Program()):
# the data_format must be 'NCHW' or 'NHWC'
data1 = np.random.random((3, 32, 32, 5)).astype('float32')
self.assertRaises(
ValueError,
fluid.dygraph.Pool2D,
pool_size=2,
pool_type='max',
pool_stride=1,
global_pooling=False,
data_format='NWHC')
class TestDygraphPool2DAPI(unittest.TestCase):
def test_nhwc(self):
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
x = fluid.dygraph.to_variable(data)
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='max',
pool_stride=1,
pool_padding=[0, 0],
global_pooling=False,
data_format='NHWC')
out1 = pool2d(x)
out2 = pool2D_forward_naive(
data, [2, 2], [1, 1],
paddings=[0, 0],
pool_type='max',
data_format='NHWC')
self.assertTrue(np.allclose(out1.numpy(), out2))
def test_lower_case(self):
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
x = fluid.dygraph.to_variable(data)
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='max',
pool_stride=1,
pool_padding=[0, 0],
global_pooling=False,
data_format='nhwc')
out1 = pool2d(x)
out2 = pool2D_forward_naive(
data, [2, 2], [1, 1],
paddings=[0, 0],
pool_type='max',
data_format='NHWC')
self.assertTrue(np.allclose(out1.numpy(), out2))
def test_upper_case(self):
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
x = fluid.dygraph.to_variable(data)
pool2d = fluid.dygraph.Pool2D(
pool_size=2,
pool_type='MAX',
pool_stride=1,
pool_padding=[0, 0],
global_pooling=False,
data_format='nhwc')
out1 = pool2d(x)
out2 = pool2D_forward_naive(
data, [2, 2], [1, 1],
paddings=[0, 0],
pool_type='max',
data_format='NHWC')
self.assertTrue(np.allclose(out1.numpy(), out2))
if __name__ == '__main__':
unittest.main()
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
class TestShapeOp(OpTest):
......@@ -45,5 +47,41 @@ class case2(TestShapeOp):
self.shape = [1, 2, 3]
class TestShapeWithSelectedRows(unittest.TestCase):
def get_places(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 1, 5, 4, 19]
height = 20
row_numel = 2
np_array = np.ones((len(x_rows), row_numel)).astype("float32")
# initialize input variable X
x = scope.var('X').get_selected_rows()
x.set_rows(x_rows)
x.set_height(height)
x_tensor = x.get_tensor()
x_tensor.set(np_array, place)
# initialize input variable Out
out_shape = scope.var("Out").get_tensor()
op = Operator("shape", Input="X", Out="Out")
op.run(scope, place)
out_shape = np.array(out_shape).tolist()
self.assertListEqual([5, 2], out_shape)
def test_check_output(self):
for place in self.get_places():
self.check_with_place(place)
if __name__ == '__main__':
unittest.main()
......@@ -50,7 +50,7 @@ class TestVarBase(unittest.TestCase):
def test_tensor_to_variable(self):
with fluid.dygraph.guard():
t = fluid.Tensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set(np.random.random((1024, 1024)), fluid.CPUPlace())
var = fluid.dygraph.to_variable(t)
self.assertTrue(np.array_equal(t, var.numpy()))
......
......@@ -314,7 +314,8 @@ class LocalSGD(Collective):
name=self.snapshot_name(param.name),
shape=param.shape,
persistable=True,
stop_gradient=True)
stop_gradient=True,
dtype=param.dtype)
block._insert_op(
idx + 1,
......
......@@ -283,6 +283,16 @@ if [ "${ADDED_OP_USE_DEFAULT_GRAD_MAKER}" != "" ]; then
check_approval 1 32832641 6836917
fi
# Get the list of PR authors with unresolved unit test issues
pip install PyGithub
# For getting PR related data
wget https://paddle-ci.gz.bcebos.com/blk/block.txt
HASUTFIXED=`python ${PADDLE_ROOT}/tools/check_ut.py | grep "has unit-test to be fixed" || true`
if [ "${HASUTFIXED}" != "" ]; then
echo_line="${HASUTFIXED} You must have one RD (chalsliu (Recommend) or kolinwei) approval.\n"
check_approval 1 45041955 22165420
fi
if [ -n "${echo_list}" ];then
echo "****************"
echo -e "${echo_list[@]}"
......
#!/bin/env python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Get pull requests. """
import os
import time
import os.path
from github import Github
class PRChecker(object):
""" PR Checker. """
def __init__(self):
self.github = Github(os.getenv('GITHUB_API_TOKEN'), timeout=60)
self.repo = None
def check(self):
""" check pr """
filename = 'block.txt'
pr_id = os.getenv('GIT_PR_ID')
if not pr_id:
print('No PR ID')
exit(0)
print(pr_id)
if not os.path.isfile(filename):
print('No author to check')
exit(0)
self.repo = self.github.get_repo('PaddlePaddle/Paddle')
pr = self.repo.get_pull(int(pr_id))
user = pr.user.login
with open(filename) as f:
for l in f:
if l.rstrip('\r\n') == user:
print('{} has unit-test to be fixed, so CI failed.'.format(user))
exit(1)
exit(0)
if __name__ == '__main__':
pr_checker = PRChecker()
pr_checker.check()
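A hedged sketch of a local dry run of the checker above (the token and PR id are placeholders; block.txt must sit in the working directory, matching the wget step in the CI script):

import os

os.environ['GITHUB_API_TOKEN'] = '<token>'  # placeholder: a valid GitHub token
os.environ['GIT_PR_ID'] = '12345'           # placeholder: the PR number to check
PRChecker().check()  # exits 1 if the PR author appears in block.txt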
......@@ -45,7 +45,7 @@ function walk_dir(){
if [ $level -le 1 ]; then
enforce_scan $1"/"$file total_check_cnt valid_check_cnt
dir_name=$1
echo "${dir_name#../}"/"$file - total: ${total_check_cnt}, valid: ${valid_check_cnt}, invalid: $(($total_check_cnt-$valid_check_cnt))"
echo "${dir_name#../}/"$file" | ${total_check_cnt} | ${valid_check_cnt} | $(($total_check_cnt-$valid_check_cnt))"
ALL_PADDLE_CHECK_CNT=$(($ALL_PADDLE_CHECK_CNT+$total_check_cnt))
VALID_PADDLE_CHECK_CNT=$(($VALID_PADDLE_CHECK_CNT+$valid_check_cnt))
walk_dir $1"/"$file $level
......
......@@ -29,6 +29,15 @@
ROOT_DIR=../paddle/fluid/operators
white_list_str="\
layer_norm_op.cc \
box_clip_op.cc \
box_clip_op.h \
random_crop_op.h \
elementwise_op_function.cu.h \
fused_elemwise_activation_op.cc \
auc_op.cu"
function enforce_scan(){
paddle_check=`grep -r -zoE "(PADDLE_ENFORCE[A-Z_]{0,9}|PADDLE_THROW)\(.[^,\);]*.[^;]*\);\s" $1 || true`
total_check_cnt=`echo "$paddle_check" | grep -cE "(PADDLE_ENFORCE|PADDLE_THROW)" || true`
......@@ -45,14 +54,16 @@ function walk_dir(){
for file in `ls $1`
do
if [ -f $1"/"$file ];then
enforce_scan $1"/"$file file_total_check_cnt file_valid_check_cnt
file_invalid_check_cnt=$(($total_check_cnt-$valid_check_cnt))
if [ $file_invalid_check_cnt -gt 0 ];then
echo "- $file | ${file_total_check_cnt} | ${file_valid_check_cnt} | ${file_invalid_check_cnt}"
in_white_list=$(echo $white_list_str | grep "${file}")
if [[ "$in_white_list" == "" ]];then
enforce_scan $1"/"$file file_total_check_cnt file_valid_check_cnt
file_invalid_check_cnt=$(($total_check_cnt-$valid_check_cnt))
if [ $file_invalid_check_cnt -gt 0 ];then
echo "- $file | ${file_total_check_cnt} | ${file_valid_check_cnt} | ${file_invalid_check_cnt}"
fi
fi
fi
if [ -d $1"/"$file ];then
dir_array[$i]=$1"/"$file
((i++))
fi
......
Dockerfile.cuda9_cudnn7_gcc48_py35_centos6
\ No newline at end of file
......@@ -3,7 +3,7 @@
# which requires some headers and symbols not present on CentOS-5 (e.g.,
# signalfd.h, pipe2, O_NONBLOCK, SOCK_NONBLOCK, etc.). See
# https://github.com/sandstorm-io/capnproto/issues/350.
FROM nvidia/cuda:9.0-cudnn7-devel-centos6
FROM nvidia/cuda:10.1-cudnn7-devel-centos6
MAINTAINER Numenta, based on the ManyLinux project
ENV LC_ALL en_US.UTF-8
......