未验证 提交 4f0913d2 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #12075 from JiayiFeng/fix_backward_bug

Fix backward bug
...@@ -103,6 +103,11 @@ if(ANDROID OR IOS) ...@@ -103,6 +103,11 @@ if(ANDROID OR IOS)
add_definitions(-DPADDLE_MOBILE_INFERENCE) add_definitions(-DPADDLE_MOBILE_INFERENCE)
endif() endif()
if (APPLE OR WIN32)
set(WITH_MKL OFF CACHE STRING
"Disable MKL for building on mac and windows" FORCE)
endif()
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.") "A path setting third party libraries download & build directories.")
......
...@@ -49,7 +49,9 @@ cc_library(paddle_inference_api ...@@ -49,7 +49,9 @@ cc_library(paddle_inference_api
# Here the shared library doesn't depend on other fluid libraries, or double free will occur. # Here the shared library doesn't depend on other fluid libraries, or double free will occur.
cc_library(paddle_inference_api_shared SHARED cc_library(paddle_inference_api_shared SHARED
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc) SRCS paddle_inference_api.cc paddle_inference_api_impl.cc)
add_dependencies(paddle_inference_api_shared ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
set_target_properties(paddle_inference_api_shared PROPERTIES OUTPUT_NAME paddle_inference_api) set_target_properties(paddle_inference_api_shared PROPERTIES OUTPUT_NAME paddle_inference_api)
if(NOT APPLE) if(NOT APPLE)
set(LINK_FLAGS "-fPIC -fvisibility=hidden") set(LINK_FLAGS "-fPIC -fvisibility=hidden")
set_target_properties(paddle_inference_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}") set_target_properties(paddle_inference_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
......
...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) ...@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
......
...@@ -13,29 +13,61 @@ ...@@ -13,29 +13,61 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include <deque>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ReaderBase::~ReaderBase() {}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {} void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> lock(mu_);
void FileReader::ReadNext(std::vector<LoDTensor> *out) { PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning);
ReadNextImpl(out); ReadNextImpl(out);
if (out->empty()) { }
return;
void ReaderBase::InsertDecoratedReader(
const std::shared_ptr<ReaderBase> &decorated_reader) {
std::lock_guard<std::mutex> guard(mu_);
decorated_readers_.emplace_back(decorated_reader);
}
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue;
queue.emplace_back(this);
while (!queue.empty()) { // BFS search
auto *front = queue.front();
queue.pop_front();
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (auto &reader : front->decorated_readers_) {
if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
}
}
} }
PADDLE_ENFORCE_EQ(out->size(), dims_.size()); return result;
for (size_t i = 0; i < dims_.size(); ++i) { }
auto &actual = (*out)[i].dims();
auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size()); void ReaderBase::Shutdown() {
for (int j = 0; j < actual.size(); ++j) { std::lock_guard<std::mutex> lock(mu_);
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1); if (status_ != ReaderStatus::kStopped) {
ShutdownImpl();
status_ = ReaderStatus::kStopped;
} }
}
void ReaderBase::Start() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kRunning) {
StartImpl();
status_ = ReaderStatus::kRunning;
} }
} }
ReaderBase::~ReaderBase() { Shutdown(); }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
...@@ -24,61 +25,116 @@ ...@@ -24,61 +25,116 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase { class ReaderBase {
public: public:
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
void Start();
virtual void ReInit() = 0; // Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
virtual ~ReaderBase(); virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ShutdownImpl() {}
virtual void StartImpl() {}
ReaderStatus status_{kRunning};
mutable std::mutex mu_;
private:
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(
const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader.
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
}; };
class DecoratedReader : public ReaderBase { class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public: public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader) explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) { : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
} }
void ReInit() override { reader_->ReInit(); } void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected: protected:
std::shared_ptr<ReaderBase> reader_; void ShutdownImpl() override { reader_->Shutdown(); }
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override; void StartImpl() override { reader_->Start(); }
protected: std::shared_ptr<ReaderBase> reader_;
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
private:
std::vector<DDim> dims_;
}; };
// FileReader is just a conceptual class.
class FileReader : public ReaderBase {};
// The ReaderHolder is used as reader' unified wrapper, // The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables. // making it easier to access different type reader in Variables.
class ReaderHolder { class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; } const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out); reader_->ReadNext(out);
} }
void ReInit() {
void ResetAll() {
auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) {
reader->Shutdown();
}
for (auto* reader : end_readers) {
reader->Start();
}
}
void Shutdown() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown();
}
void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit(); reader_->Start();
} }
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private: private:
std::shared_ptr<ReaderBase> reader_; std::shared_ptr<ReaderBase> reader_;
}; };
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {}
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
class StubRootReader : public paddle::framework::ReaderBase {
public:
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
{
auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}
...@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) ...@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
reader_library(create_py_reader_op SRCS create_py_reader_op.cc) reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
......
...@@ -20,15 +20,19 @@ namespace reader { ...@@ -20,15 +20,19 @@ namespace reader {
class BatchReader : public framework::DecoratedReader { class BatchReader : public framework::DecoratedReader {
public: public:
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size) BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
: DecoratedReader(reader), batch_size_(batch_size) { bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private: private:
int batch_size_; int batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_; std::vector<std::vector<framework::LoDTensor>> buffer_;
}; };
...@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase { ...@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<BatchReader>(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size"))); underlying_reader, Attr<int>("batch_size"),
Attr<bool>("discard_leftover")));
} }
}; };
...@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr<int>("batch_size", AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.") "How many instances the batch reader yields each time.")
.GreaterThan(0); .GreaterThan(0);
AddAttr<bool>("discard_leftover",
"If true, the leftover instances that are not enough for a "
"new batch will be discarded.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CreateBatchReader Operator CreateBatchReader Operator
...@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) { void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear(); buffer_.clear();
buffer_.reserve(batch_size_); buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) { for (int i = 0; i < batch_size_; ++i) {
...@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break; break;
} }
} }
if (discard_leftover_ && buffer_.size() < batch_size_) {
buffer_.clear();
}
// Concat instances // Concat instances
out->clear(); out->clear();
if (buffer_.empty()) { if (buffer_.empty()) {
......
...@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader { ...@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_(source_var_names), source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {} sink_var_names_(sink_var_names) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private: private:
const framework::ProgramDesc program_; const framework::ProgramDesc program_;
...@@ -60,8 +60,8 @@ class CreateCustomReaderOp : public framework::OperatorBase { ...@@ -60,8 +60,8 @@ class CreateCustomReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<CustomReader>(
new CustomReader(underlying_reader.Get(), *sub_block, underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"), Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names"))); Attr<std::vector<std::string>>("sink_var_names")));
} }
...@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference { ...@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
} }
}; };
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) { void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out->clear(); out->clear();
std::vector<framework::LoDTensor> underlying_outs; std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs); reader_->ReadNext(&underlying_outs);
......
...@@ -23,13 +23,13 @@ namespace reader { ...@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same // 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2. // time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 5; static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training: // There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel // 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by // 2. the one just be received from the channel, which is also being used by
// subsequent operators. // subsequent operators.
// So the channel size should be kChacheSize - 2 // So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2 static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader { class DoubleBufferReader : public framework::DecoratedReader {
public: public:
...@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader { ...@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher(); StartPrefetcher();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~DoubleBufferReader() { EndPrefetcher(); } ~DoubleBufferReader() { EndPrefetcher(); }
private: private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() { void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize); channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); }); prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
...@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num)); place = platform::CUDAPlace(static_cast<int>(num));
} }
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
} }
}; };
...@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id; size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) { if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { ...@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
} }
} }
void DoubleBufferReader::ReInit() {
reader_->ReInit();
EndPrefetcher();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() { void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts."; VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0; size_t cached_tensor_id = 0;
......
...@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader { ...@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num) MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out); reader_->ReadNext(out);
if (out->empty()) { if (out->empty() && pass_count_ < pass_num_ - 1) {
++pass_count_; reader_->Shutdown();
if (pass_count_ < pass_num_) { reader_->Start();
reader_->ReInit();
reader_->ReadNext(out); reader_->ReadNext(out);
} ++pass_count_;
} }
} }
void ReInit() override { private:
void StartImpl() override {
pass_count_ = 0; pass_count_ = 0;
reader_->ReInit(); reader_->Start();
} }
private:
int pass_num_; int pass_num_;
mutable int pass_count_; mutable int pass_count_;
}; };
...@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { ...@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num"); int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
} }
}; };
......
...@@ -19,22 +19,27 @@ namespace paddle { ...@@ -19,22 +19,27 @@ namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class PyReader : public framework::ReaderBase { class PyReader : public framework::FileReader {
public: public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) { explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue; queue_ = queue;
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
bool success; bool success;
*out = queue_->Pop(&success); *out = queue_->Pop(&success);
if (!success) out->clear(); if (!success) out->clear();
} }
void ReInit() override {}
private: private:
void ShutdownImpl() override { /* TODO */
}
void StartImpl() override { /* TODO */
}
std::shared_ptr<LoDTensorBlockingQueue> queue_; std::shared_ptr<LoDTensorBlockingQueue> queue_;
}; };
...@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase { ...@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const std::string& queue_name = Input("blocking_queue"); const std::string& queue_name = Input("blocking_queue");
auto* queue_holder_var = scope.FindVar(queue_name); auto* queue_holder_var = scope.FindVar(queue_name);
PADDLE_ENFORCE( PADDLE_ENFORCE_NOT_NULL(
queue_holder_var != nullptr, queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found", "No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name); queue_name);
auto* queue_holder = auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>(); queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue())); out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
} }
}; };
......
...@@ -19,11 +19,11 @@ namespace operators { ...@@ -19,11 +19,11 @@ namespace operators {
namespace reader { namespace reader {
template <typename T> template <typename T>
class RandomDataGenerator : public framework::ReaderBase { class RandomDataGenerator : public framework::FileReader {
public: public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low, RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
float high) float high)
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) { : framework::FileReader(), low_(low), high_(high), shapes_(shapes) {
PADDLE_ENFORCE_LE(low, high, PADDLE_ENFORCE_LE(low, high,
"'low' shouldn't be greater than 'high'.(%f vs %f)", low, "'low' shouldn't be greater than 'high'.(%f vs %f)", low,
high); high);
...@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_ = std::uniform_real_distribution<float>(low_, high_); dist_ = std::uniform_real_distribution<float>(low_, high_);
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear(); out->clear();
out->reserve(shapes_.size()); out->reserve(shapes_.size());
for (const framework::DDim& shape : shapes_) { for (const framework::DDim& shape : shapes_) {
...@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase { ...@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
} }
} }
void ReInit() override { return; }
private: private:
float low_; float low_;
float high_; float high_;
...@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase { ...@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"), out->Reset(std::make_shared<RandomDataGenerator<T>>(
Attr<float>("high"))); shapes, Attr<float>("low"), Attr<float>("high")));
} }
}; };
......
...@@ -21,10 +21,8 @@ namespace reader { ...@@ -21,10 +21,8 @@ namespace reader {
template <bool ThreadSafe> template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader { class RecordIOFileReader : public framework::FileReader {
public: public:
explicit RecordIOFileReader(const std::string& filename, explicit RecordIOFileReader(const std::string& filename)
const std::vector<framework::DDim>& dims) : scanner_(filename),
: FileReader(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get( dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) { platform::CPUPlace())) {
if (ThreadSafe) { if (ThreadSafe) {
...@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename; LOG(INFO) << "Creating file reader" << filename;
} }
void ReInit() override { scanner_.Reset(); }
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) { if (ThreadSafe) {
...@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
} }
} }
void StartImpl() override { scanner_.Reset(); }
private: private:
std::unique_ptr<std::mutex> mutex_; std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_; recordio::Scanner scanner_;
...@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase { ...@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename"); std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>( out->Reset(std::make_shared<RecordIOFileReader<true>>(filename));
filename, RestoreShapes(shape_concat, ranks)));
} }
}; };
......
...@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer(); ReloadBuffer();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear(); out->clear();
if (iteration_pos_ >= buffer_.size()) { if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer"; VLOG(10) << "Resetting shuffle buffer";
...@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
} }
private: private:
void ShutdownImpl() override {
buffer_.clear();
iteration_pos_ = 0;
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
ReloadBuffer();
}
void ReloadBuffer() { void ReloadBuffer() {
buffer_.clear(); buffer_.clear();
buffer_.reserve(buffer_size_); buffer_.reserve(buffer_size_);
...@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase { ...@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
} }
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>(); ->Get<framework::ReaderHolder>();
out->Reset( out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
new ShuffleReader(underlying_reader.Get(), underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
static_cast<size_t>(Attr<int>("buffer_size"))));
} }
}; };
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override { reader_->ReInit(); }
private:
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(new ThreadedReader(underlying_reader.Get()));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
protected:
void Apply() override {
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
...@@ -23,24 +23,26 @@ namespace reader { ...@@ -23,24 +23,26 @@ namespace reader {
class MultiFileReader : public framework::ReaderBase { class MultiFileReader : public framework::ReaderBase {
public: public:
MultiFileReader(const std::vector<std::string>& file_names, MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
const std::vector<framework::DDim>& dims, size_t thread_num,
size_t buffer_size) size_t buffer_size)
: buffer_size_(buffer_size) { : buffer_size_(buffer_size) {
readers_.reserve(file_names.size()); readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) { for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name, dims)); readers_.emplace_back(CreateReaderByFileName(f_name));
} }
prefetchers_.resize(thread_num); prefetchers_.resize(thread_num);
StartNewScheduler(); StartNewScheduler();
} }
void ReadNext(std::vector<framework::LoDTensor>* out) override; void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
~MultiFileReader() { EndScheduler(); } ~MultiFileReader() { EndScheduler(); }
private: private:
void ShutdownImpl() override { EndScheduler(); }
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler(); void StartNewScheduler();
void EndScheduler(); void EndScheduler();
void ScheduleThreadFunc(); void ScheduleThreadFunc();
...@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase { ...@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_; reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
}; };
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) { void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) { if (!buffer_->Receive(out)) {
out->clear(); out->clear();
} }
} }
void MultiFileReader::ReInit() {
EndScheduler();
StartNewScheduler();
}
void MultiFileReader::StartNewScheduler() { void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size(); size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size()); waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
...@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() { ...@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
} }
} }
} }
// If users invoke ReInit() when scheduler is running, it will close the // If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler // 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends. // to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) { for (auto& p : prefetchers_) {
...@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { ...@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins); reader->ReadNext(&ins);
if (ins.empty()) { if (ins.empty()) {
reader->ReInit(); reader->Shutdown();
reader->Start();
break; break;
} }
try { try {
...@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names, out->Reset(
RestoreShapes(shape_concat, ranks), std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
thread_num, buffer_size));
} }
}; };
......
...@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() { ...@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) { const std::string& file_name) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator); size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos, PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: " "File name illegal! A legal file name should be like: "
...@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( ...@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto itor = FileReaderRegistry().find(filetype); auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(), PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype); "No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims); framework::ReaderBase* reader = (itor->second)(file_name);
return std::unique_ptr<framework::ReaderBase>(reader); return std::unique_ptr<framework::ReaderBase>(reader);
} }
......
...@@ -25,22 +25,21 @@ namespace reader { ...@@ -25,22 +25,21 @@ namespace reader {
static constexpr char kFileFormatSeparator[] = "."; static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*( using FileReaderCreator =
const std::string&, const std::vector<framework::DDim>&)>; std::function<framework::ReaderBase*(const std::string&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry(); std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader> template <typename Reader>
int RegisterFileReader(const std::string& filetype) { int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = []( FileReaderRegistry()[filetype] = [](const std::string& fn) {
const std::string& fn, const std::vector<framework::DDim>& dims) { return new Reader(fn);
return new Reader(fn, dims);
}; };
return 0; return 0;
} }
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName( std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims); const std::string& file_name);
extern std::vector<framework::DDim> RestoreShapes( extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks); const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
...@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference); py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "") py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ReInit); .def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue = using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue; ::paddle::operators::reader::LoDTensorBlockingQueue;
......
...@@ -375,9 +375,6 @@ def open_recordio_file(filename, ...@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if pass_num > 1: if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num) main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
...@@ -529,9 +526,6 @@ def open_files(filenames, ...@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader = multi_pass( main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num) reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader) return monkey_patch_reader_methods(main_prog_reader)
...@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num): ...@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)}) 'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(reader): def read_file(reader):
""" """
Execute the given reader and get data via it. Execute the given reader and get data via it.
......
...@@ -156,12 +156,15 @@ if '${WITH_MKL}' == 'ON': ...@@ -156,12 +156,15 @@ if '${WITH_MKL}' == 'ON':
shutil.copy('${MKLML_IOMP_LIB}', libs_path) shutil.copy('${MKLML_IOMP_LIB}', libs_path)
package_data['paddle.libs']+=['libmklml_intel.so','libiomp5.so'] package_data['paddle.libs']+=['libmklml_intel.so','libiomp5.so']
if '${WITH_MKLDNN}' == 'ON': if '${WITH_MKLDNN}' == 'ON':
# TODO(typhoonzero): use install_name_tool to patch mkl libs once
# we can support mkl on mac.
#
# change rpath of libmkldnn.so.0, add $ORIGIN/ to it. # change rpath of libmkldnn.so.0, add $ORIGIN/ to it.
# The reason is that all thirdparty libraries in the same directory, # The reason is that all thirdparty libraries in the same directory,
# thus, libmkldnn.so.0 will find libmklml_intel.so and libiomp5.so. # thus, libmkldnn.so.0 will find libmklml_intel.so and libiomp5.so.
command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}" command = "patchelf --set-rpath '$ORIGIN/' ${MKLDNN_SHARED_LIB}"
if os.system(command) != 0: if os.system(command) != 0:
raise Exception("patchelf --set-rpath for libmkldnn.so.0 fails") raise Exception("patch libmkldnn.so failed, command: %s" % command)
package_data['paddle.libs']+=['libmkldnn.so.0'] package_data['paddle.libs']+=['libmkldnn.so.0']
shutil.copy('${MKLDNN_SHARED_LIB}', libs_path) shutil.copy('${MKLDNN_SHARED_LIB}', libs_path)
# remove unused paddle/libs/__init__.py # remove unused paddle/libs/__init__.py
...@@ -172,9 +175,12 @@ package_dir['paddle.libs']=libs_path ...@@ -172,9 +175,12 @@ package_dir['paddle.libs']=libs_path
# The reason is that libwarpctc.so, libiomp5.so etc are in paddle.libs, and # The reason is that libwarpctc.so, libiomp5.so etc are in paddle.libs, and
# core.so is in paddle.fluid, thus paddle/fluid/../libs will pointer to above libraries. # core.so is in paddle.fluid, thus paddle/fluid/../libs will pointer to above libraries.
# This operation will fix https://github.com/PaddlePaddle/Paddle/issues/3213 # This operation will fix https://github.com/PaddlePaddle/Paddle/issues/3213
command = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so" if "@APPLE@" == "1":
command = "install_name_tool -id \"@loader_path/../libs/\" ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so"
else:
command = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}/python/paddle/fluid/core.so"
if os.system(command) != 0: if os.system(command) != 0:
raise Exception("patchelf --set-rpath for core.so fails") raise Exception("patch core.so failed, command: %s" % command)
setup(name='${PACKAGE_NAME}', setup(name='${PACKAGE_NAME}',
version='${PADDLE_VERSION}', version='${PADDLE_VERSION}',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册