提交 2d9bd762 编写于 作者: L Luo Tao

Merge branch 'develop' into demo

......@@ -7,7 +7,17 @@ set(ANAKIN_INSTALL_DIR "${THIRD_PARTY_PATH}/install/anakin" CACHE PATH
set(ANAKIN_INCLUDE "${ANAKIN_INSTALL_DIR}" CACHE STRING "root of Anakin header files")
set(ANAKIN_LIBRARY "${ANAKIN_INSTALL_DIR}" CACHE STRING "path of Anakin library")
set(ANAKIN_COMPILE_EXTRA_FLAGS -Wno-error=unused-variable -Wno-error=format-extra-args -Wno-error=comment -Wno-error=format -Wno-error=switch -Wno-error=return-type -Wno-error=non-virtual-dtor -Wno-reorder -Wno-error=cpp)
set(ANAKIN_COMPILE_EXTRA_FLAGS
-Wno-error=unused-variable -Wno-unused-variable
-Wno-error=format-extra-args -Wno-format-extra-args
-Wno-error=comment -Wno-comment
-Wno-error=format -Wno-format
-Wno-error=switch -Wno-switch
-Wno-error=return-type -Wno-return-type
-Wno-error=non-virtual-dtor -Wno-non-virtual-dtor
-Wno-sign-compare
-Wno-reorder
-Wno-error=cpp)
set(ANAKIN_LIBRARY_URL "https://github.com/pangge/Anakin/releases/download/3.0/anakin_release_simple.tar.gz")
......
# Get the latest git tag.
set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set(tmp_version "HEAD")
set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
while ("${PADDLE_VERSION}" STREQUAL "")
execute_process(
COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 ${tmp_version}
COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 --always ${tmp_version}
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_TAG_NAME
RESULT_VARIABLE GIT_RESULT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT ${GIT_RESULT})
# Check the tag is a correct version
if (${GIT_TAG_NAME} MATCHES "v[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
# if no tag was found, set PADDLE_VERSION to latest
set(PADDLE_VERSION "latest")
elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
else() # otherwise, get the previous git tag name.
set(tmp_version "${GIT_TAG_NAME}~1")
......
......@@ -28,9 +28,9 @@
### 准备预测模型
准备预测模型部分,我们以手写数字识别任务为例进行介绍。手写数字识别任务定义了一个含有[两个隐层的简单全连接网络](https://github.com/PaddlePaddle/book/blob/develop/02.recognize_digits/README.cn.md#softmax回归softmax-regression),网络接受一幅图片作为输入,将图片分类到 0 ~ 9 类别标签之一。完整代码可以查看[此目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense) 中的相关脚本。
准备预测模型部分,我们以手写数字识别任务为例进行介绍。手写数字识别任务定义了一个含有[两个隐层的简单全连接网络](https://github.com/PaddlePaddle/book/blob/develop/02.recognize_digits/README.cn.md#softmax回归softmax-regression),网络接受一幅图片作为输入,将图片分类到 0 ~ 9 类别标签之一。完整代码可以查看[此目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense) 中的相关脚本。
调用C-API开发预测程序需要一个训练好的模型,运行[MNIST手写数字识别目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)下的[mnist_v2.py](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/examples/model_inference/dense/mnist_v2.py)脚本,在终端执行`python mnist_v2.py`,会使用 PaddlePaddle 内置的 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/)进行训练。训练好的模型默认保存在当前运行目录下的`models`目录中。
调用C-API开发预测程序需要一个训练好的模型,运行[MNIST手写数字识别目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense)下的[mnist_v2.py](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/examples/model_inference/dense/mnist_v2.py)脚本,在终端执行`python mnist_v2.py`,会使用 PaddlePaddle 内置的 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/)进行训练。训练好的模型默认保存在当前运行目录下的`models`目录中。
下面,我们将训练结束后存储下来的模型转换成预测模型。
......@@ -48,7 +48,7 @@
dump_v2_config(predict, "trainer_config.bin", True)
```
对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)这个示例,[`mnist_v2.py`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense/mnist_v2.py)脚本集成了序列化神经网络结构的过程,可以直接运行 `python mnist_v2.py --task dump_config` 对神经网络结构进行序列化,结果会写入当前运行目录下的`trainer_config.bin`文件中。
对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense)这个示例,[`mnist_v2.py`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense/mnist_v2.py)脚本集成了序列化神经网络结构的过程,可以直接运行 `python mnist_v2.py --task dump_config` 对神经网络结构进行序列化,结果会写入当前运行目录下的`trainer_config.bin`文件中。
使用这种方式,需要**在运行时将神经网络的多个可学习参数放在同一个目录中**,C-API可以通过分别指定序列化后的网络结构文件和参数目录来加载训练好的模型。
......@@ -68,7 +68,7 @@
merge_v2_model(net, param_file, output_file)
```
对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)这个示例,可直接运行 `python` [merge_v2_model.py](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense/merge_v2_model.py)。序列化结果会写入当前运行目录下的`output.paddle.model`文件中。使用这种方式,运行时C-API可以通过指定`output.paddle.model`文件的路径来加载预测模型。
对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense)这个示例,可直接运行 `python` [merge_v2_model.py](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense/merge_v2_model.py)。序列化结果会写入当前运行目录下的`output.paddle.model`文件中。使用这种方式,运行时C-API可以通过指定`output.paddle.model`文件的路径来加载预测模型。
#### 注意事项
1. 为使用C-API,在调用`dump_v2_config`序列化神经网络结构时,参数`binary`必须指定为`True`
......@@ -77,10 +77,10 @@
### 编写预测代码
预测代码更多详细示例代码请参考[C-API使用示例](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference) 目录下的代码示例。这一节对图1中预测代码编写的5个步骤进行介绍和说明。
预测代码更多详细示例代码请参考[C-API使用示例](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference) 目录下的代码示例。这一节对图1中预测代码编写的5个步骤进行介绍和说明。
#### step 1. 初始化PaddlePaddle运行环境
第一步需调用[`paddle_init`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/main.h#L27) 初始化PaddlePaddle运行环境,该接口接受两个参数:参数的个数和参数列表。
第一步需调用[`paddle_init`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/main.h#L27) 初始化PaddlePaddle运行环境,该接口接受两个参数:参数的个数和参数列表。
#### step2. 加载模型
......@@ -88,8 +88,8 @@
概念上,在 PaddlePaddle 内部,一个GradientMachine类的对象管理着一组计算层(PaddlePaddle Layers)来完成前向和反向计算,并处理与之相关的所有细节。在调用C-API预测时,只需进行前向计算而无需调用反向计算。这篇文档之后部分会使用`gradient machine`来特指调用PaddlePaddle C-API创建的GradientMachine类的对象。每一个 `gradient machine` 都会管理维护一份训练好的模型,下面是C-API提供的,两种常用的模型加载方式:
1. 调用[`paddle_gradient_machine_load_parameter_from_disk`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/gradient_machine.h#L61)接口,从磁盘加载预测模型。这时`gradient machine`会独立拥有一份训练好的模型;
1. 调用[`paddle_gradient_machine_create_shared_param`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/gradient_machine.h#L88)接口,与其它`gradient machine`的共享已经加载的预测模型。这种情况多出现在使用多线程预测时,通过多个线程共享同一个模型来减少内存开销。可参考[此示例](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/examples/model_inference/multi_thread/main.c)
1. 调用[`paddle_gradient_machine_load_parameter_from_disk`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/gradient_machine.h#L61)接口,从磁盘加载预测模型。这时`gradient machine`会独立拥有一份训练好的模型;
1. 调用[`paddle_gradient_machine_create_shared_param`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/gradient_machine.h#L88)接口,与其它`gradient machine`的共享已经加载的预测模型。这种情况多出现在使用多线程预测时,通过多个线程共享同一个模型来减少内存开销。可参考[此示例](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/examples/model_inference/multi_thread/main.c)
- 注意事项
......@@ -117,7 +117,7 @@ C-API支持的所有输入数据类型和他们的组织方式,请参考“输
#### step 4. 前向计算
完成上述准备之后,通过调用 [`paddle_gradient_machine_forward`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/gradient_machine.h#L73) 接口完成神经网络的前向计算。
完成上述准备之后,通过调用 [`paddle_gradient_machine_forward`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/gradient_machine.h#L73) 接口完成神经网络的前向计算。
#### step 5. 清理
......
......@@ -249,7 +249,7 @@ void MainThreadsImageClassification(bool use_gpu) {
const size_t len = local_outputs[0].data.length();
float* data = static_cast<float*>(local_outputs[0].data.data());
float* ref_data = refs[tid].data<float>();
EXPECT_EQ(refs[tid].numel(), len / sizeof(float));
EXPECT_EQ((size_t)refs[tid].numel(), len / sizeof(float));
for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3);
}
......
......@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc)
......
......@@ -21,8 +21,8 @@ namespace framework {
// a static local variable is already being initialized.
// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
OpInfoMap& OpInfoMap::Instance() {
static OpInfoMap* g_op_info_map = new OpInfoMap();
return *g_op_info_map;
static OpInfoMap g_op_info_map;
return g_op_info_map;
}
} // namespace framework
} // namespace paddle
......@@ -13,29 +13,61 @@
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <deque>
namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReader::ReadNext(std::vector<LoDTensor> *out) {
void ReaderBase::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> lock(mu_);
PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning);
ReadNextImpl(out);
if (out->empty()) {
return;
}
}
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = (*out)[i].dims();
auto &expect = dims_[i];
void ReaderBase::InsertDecoratedReader(
const std::shared_ptr<ReaderBase> &decorated_reader) {
std::lock_guard<std::mutex> guard(mu_);
decorated_readers_.emplace_back(decorated_reader);
}
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
// PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
std::unordered_set<ReaderBase *> ReaderBase::GetEndPoints() {
std::unordered_set<ReaderBase *> result;
std::deque<ReaderBase *> queue;
queue.emplace_back(this);
while (!queue.empty()) { // BFS search
auto *front = queue.front();
queue.pop_front();
if (front->decorated_readers_.empty()) {
result.emplace(front);
} else {
for (auto &reader : front->decorated_readers_) {
if (auto *reader_ptr = reader.lock().get()) {
queue.emplace_back(reader_ptr);
}
}
}
}
return result;
}
void ReaderBase::Shutdown() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kStopped) {
ShutdownImpl();
status_ = ReaderStatus::kStopped;
}
}
void ReaderBase::Start() {
std::lock_guard<std::mutex> lock(mu_);
if (status_ != ReaderStatus::kRunning) {
StartImpl();
status_ = ReaderStatus::kRunning;
}
}
ReaderBase::~ReaderBase() { Shutdown(); }
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@
#pragma once
#include <memory>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
......@@ -24,61 +25,116 @@
namespace paddle {
namespace framework {
enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
void ReadNext(std::vector<LoDTensor>* out);
void Shutdown();
virtual void ReInit() = 0;
void Start();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
std::unordered_set<ReaderBase*> GetEndPoints();
virtual ~ReaderBase();
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
virtual void ShutdownImpl() {}
virtual void StartImpl() {}
ReaderStatus status_{kRunning};
mutable std::mutex mu_;
private:
friend class DecoratedReader;
// These methods can be only invoked inside DecoratedReader to record the
// decorating chain.
void InsertDecoratedReader(
const std::shared_ptr<ReaderBase>& decorated_reader);
// A set of which readers that decorated this reader.
std::vector<std::weak_ptr<ReaderBase>> decorated_readers_;
};
class DecoratedReader : public ReaderBase {
class DecoratedReader : public ReaderBase,
public std::enable_shared_from_this<DecoratedReader> {
public:
explicit DecoratedReader(const std::shared_ptr<ReaderBase>& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
void ReInit() override { reader_->ReInit(); }
void RegisterDecorateChain() {
reader_->InsertDecoratedReader(shared_from_this());
}
protected:
std::shared_ptr<ReaderBase> reader_;
};
class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
void ShutdownImpl() override { reader_->Shutdown(); }
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
void StartImpl() override { reader_->Start(); }
private:
std::vector<DDim> dims_;
std::shared_ptr<ReaderBase> reader_;
};
// FileReader is just a conceptual class.
class FileReader : public ReaderBase {};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
public:
void Reset(ReaderBase* reader) { reader_.reset(reader); }
template <typename T>
void Reset(const std::shared_ptr<T>& reader) {
auto reader_base = std::dynamic_pointer_cast<ReaderBase>(reader);
PADDLE_ENFORCE_NOT_NULL(reader_base);
reader_ = reader_base;
}
std::shared_ptr<ReaderBase> Get() const { return reader_; }
const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
void ResetAll() {
auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) {
reader->Shutdown();
}
for (auto* reader : end_readers) {
reader->Start();
}
}
void Shutdown() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown();
}
void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
reader_->Start();
}
operator const std::shared_ptr<ReaderBase>&() const { return this->reader_; }
private:
std::shared_ptr<ReaderBase> reader_;
};
template <typename T, typename... ARGS>
inline std::shared_ptr<DecoratedReader> MakeDecoratedReader(ARGS&&... args) {
std::shared_ptr<DecoratedReader> reader(new T(std::forward<ARGS>(args)...));
reader->RegisterDecorateChain();
return reader;
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
#include <memory>
#include "gtest/gtest.h"
class StubDecoratedReader : public paddle::framework::DecoratedReader {
public:
explicit StubDecoratedReader(const std::shared_ptr<ReaderBase> &reader)
: DecoratedReader(reader) {}
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
class StubRootReader : public paddle::framework::ReaderBase {
public:
void ReadNextImpl(std::vector<paddle::framework::LoDTensor> *out) override {}
};
TEST(READER, decorate_chain) {
auto root = std::make_shared<StubRootReader>();
auto end_point1 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
auto end_point2 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
{
auto end_point3 =
paddle::framework::MakeDecoratedReader<StubDecoratedReader>(root);
ASSERT_EQ(root->GetEndPoints().size(), 3U);
}
{ ASSERT_EQ(root->GetEndPoints().size(), 2U); }
}
......@@ -27,7 +27,7 @@ TEST_F(DFG_Tester, Init) {
DataFlowGraph graph;
pass.Run(&graph);
// Analysis is sensitive to ProgramDesc, careful to change the original model.
ASSERT_EQ(graph.nodes.size(), 37);
ASSERT_EQ(graph.nodes.size(), 37UL);
pass.Finalize();
LOG(INFO) << '\n' << graph.DotString();
}
......
......@@ -82,7 +82,7 @@ TEST_F(DFG_Tester, Fuse) {
// At least one nodes should be deleted.
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
ASSERT_EQ(6UL, count1);
ASSERT_EQ(6, count1);
}
} // namespace analysis
......
......@@ -19,8 +19,9 @@ namespace paddle {
namespace memory {
namespace detail {
BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
size_t min_chunk_size, size_t max_chunk_size)
BuddyAllocator::BuddyAllocator(
std::unique_ptr<SystemAllocator> system_allocator, size_t min_chunk_size,
size_t max_chunk_size)
: min_chunk_size_(min_chunk_size),
max_chunk_size_(max_chunk_size),
cache_(system_allocator->UseGpu()),
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <tuple>
......@@ -32,8 +33,8 @@ namespace detail {
class BuddyAllocator {
public:
BuddyAllocator(SystemAllocator* system_allocator, size_t min_chunk_size,
size_t max_chunk_size);
BuddyAllocator(std::unique_ptr<SystemAllocator> system_allocator,
size_t min_chunk_size, size_t max_chunk_size);
~BuddyAllocator();
......@@ -103,7 +104,7 @@ class BuddyAllocator {
private:
/*! Allocate CPU/GPU memory from system */
SystemAllocator* system_allocator_;
std::unique_ptr<SystemAllocator> system_allocator_;
std::mutex mutex_;
};
......
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "glog/logging.h"
......@@ -34,12 +36,15 @@ namespace memory {
using BuddyAllocator = detail::BuddyAllocator;
BuddyAllocator* GetCPUBuddyAllocator() {
static std::once_flag init_flag;
static detail::BuddyAllocator* a = nullptr;
if (a == nullptr) {
a = new detail::BuddyAllocator(new detail::CPUAllocator,
platform::CpuMinChunkSize(),
platform::CpuMaxChunkSize());
}
std::call_once(init_flag, []() {
a = new detail::BuddyAllocator(
std::unique_ptr<detail::SystemAllocator>(new detail::CPUAllocator),
platform::CpuMinChunkSize(), platform::CpuMaxChunkSize());
});
return a;
}
......@@ -68,27 +73,33 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
#ifdef PADDLE_WITH_CUDA
BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
static BuddyAllocator** as = NULL;
if (as == NULL) {
static std::once_flag init_flag;
static detail::BuddyAllocator** a_arr = nullptr;
std::call_once(init_flag, [gpu_id]() {
int gpu_num = platform::GetCUDADeviceCount();
as = new BuddyAllocator*[gpu_num];
for (int gpu = 0; gpu < gpu_num; gpu++) {
as[gpu] = nullptr;
PADDLE_ENFORCE(gpu_id < gpu_num, "gpu_id:%d should < gpu_num:%d", gpu_id,
gpu_num);
a_arr = new BuddyAllocator*[gpu_num];
for (int i = 0; i < gpu_num; i++) {
a_arr[i] = nullptr;
platform::SetDeviceId(i);
a_arr[i] = new BuddyAllocator(
std::unique_ptr<detail::SystemAllocator>(new detail::GPUAllocator(i)),
platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100
<< "% of GPU memory.\n"
<< "You can set GFlags environment variable '"
<< "FLAGS_fraction_of_gpu_memory_to_use"
<< "' to change the fraction of GPU usage.\n\n";
}
}
});
platform::SetDeviceId(gpu_id);
if (!as[gpu_id]) {
as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator(gpu_id),
platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize());
VLOG(10) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100
<< "% of GPU memory.\n"
<< "You can set GFlags environment variable '"
<< "FLAGS_fraction_of_gpu_memory_to_use"
<< "' to change the fraction of GPU usage.\n\n";
}
return as[gpu_id];
return a_arr[gpu_id];
}
template <>
......@@ -125,12 +136,16 @@ void Free<platform::CUDAPlace>(platform::CUDAPlace place, void* p) {
}
BuddyAllocator* GetCUDAPinnedBuddyAllocator() {
static BuddyAllocator* ba = NULL;
if (ba == NULL) {
ba = new BuddyAllocator(new detail::CUDAPinnedAllocator,
static std::once_flag init_flag;
static BuddyAllocator* ba = nullptr;
std::call_once(init_flag, []() {
ba = new BuddyAllocator(std::unique_ptr<detail::SystemAllocator>(
new detail::CUDAPinnedAllocator),
platform::CUDAPinnedMinChunkSize(),
platform::CUDAPinnedMaxChunkSize());
}
});
return ba;
}
......
......@@ -216,6 +216,18 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean_e.setZero();
saved_variance_e.setZero();
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), C);
if ((N * sample_size) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y);
return;
}
switch (data_layout) {
case DataLayout::kNCHW: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
......@@ -247,10 +259,6 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
PADDLE_THROW("Unknown storage order: %s", data_layout_str);
}
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), C);
running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr =
......@@ -427,6 +435,11 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_bias_arr.setZero();
d_scale_arr.setZero();
if ((N * sample_size) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
return;
}
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
switch (data_layout) {
......
......@@ -72,6 +72,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
auto *y = ctx.Output<Tensor>("Y");
y->mutable_data<T>(ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
......@@ -93,7 +96,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
VLOG(1) << "Setting descriptors.";
VLOG(3) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) {
......@@ -113,11 +116,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
auto *y = ctx.Output<Tensor>("Y");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
......@@ -162,22 +160,28 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
if ((N * H * W * D) == 1) {
LOG(WARNING) << "Only 1 element in normalization dimension, "
<< "we skip the batch norm calculation, let y = x.";
framework::TensorCopySync(*x, ctx.GetPlace(), y);
} else {
double this_factor = 1. - momentum;
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<BatchNormParamType<T>>(),
bias->template data<BatchNormParamType<T>>(), this_factor,
mean_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
variance_out->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace())));
}
}
// clean when exit.
......@@ -209,6 +213,25 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if ((N * H * W * D) == 1) {
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
return;
}
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
......@@ -247,21 +270,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
const void *saved_mean_data = saved_mean->template data<T>();
const void *saved_var_data = saved_var->template data<T>();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
......
......@@ -205,9 +205,10 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
context->SetOutputsDim(framework::GradVarName("Params"),
context->GetInputsDim("Params"));
}
PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X")));
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
if (context->HasOutputs(framework::GradVarName("X"))) {
context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X"));
}
}
};
......
......@@ -124,8 +124,7 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"Tensor<float/double> with shape [N x D].");
AddOutput("Y",
"(Tensor, default Tensor<float>), a 2-D tensor with shape "
"[N x 1]. The cross entropy loss.")
.Reuse("X");
"[N x 1]. The cross entropy loss.");
AddAttr<bool>("soft_label",
"(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels.")
......
......@@ -44,8 +44,10 @@ class MergeLoDTensorOp : public framework::OperatorBase {
scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
auto level = static_cast<size_t>(Attr<int>("level"));
auto &mask_dim = mask.dims();
PADDLE_ENFORCE(in_true.numel() || in_false.numel(),
"Input(InTrue) or Input(InFalse) should be initialized.");
auto &mask_dim = mask.dims();
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) {
cpu_mask->ShareDataWith(mask);
......@@ -59,19 +61,27 @@ class MergeLoDTensorOp : public framework::OperatorBase {
}
auto *mask_data = cpu_mask->data<bool>();
int rank = in_true.dims().size();
platform::Place place = in_true.place();
std::type_index data_type = in_true.type();
framework::DDim in_true_dims =
framework::slice_ddim(in_true.dims(), 1, rank);
platform::Place place = dev_place;
int64_t batch_size = in_true.dims()[0] + in_false.dims()[0];
auto in_true_dim_vec = framework::vectorize(in_true_dims);
in_true_dim_vec.insert(in_true_dim_vec.begin(), batch_size);
std::type_index data_type =
in_true.IsInitialized() ? in_true.type() : in_false.type();
int rank;
framework::DDim in_dims;
if (in_true.IsInitialized()) {
rank = in_true.dims().size();
in_dims = framework::slice_ddim(in_true.dims(), 1, rank);
} else {
rank = in_false.dims().size();
in_dims = framework::slice_ddim(in_false.dims(), 1, rank);
}
auto in_dim_vec = framework::vectorize(in_dims);
in_dim_vec.insert(in_dim_vec.begin(), batch_size);
framework::DDim out_dims = framework::make_ddim(in_true_dim_vec);
framework::DDim out_dims = framework::make_ddim(in_dim_vec);
out->Resize(out_dims);
out->mutable_data(place, data_type);
auto *out_lod = out->mutable_lod();
......
......@@ -22,7 +22,6 @@ reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_threaded_reader_op SRCS create_threaded_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
......
......@@ -20,15 +20,19 @@ namespace reader {
class BatchReader : public framework::DecoratedReader {
public:
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
BatchReader(const std::shared_ptr<ReaderBase>& reader, int batch_size,
bool discard_leftover)
: DecoratedReader(reader),
batch_size_(batch_size),
discard_leftover_(discard_leftover) {
buffer_.reserve(batch_size_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
int batch_size_;
bool discard_leftover_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
};
......@@ -46,8 +50,9 @@ class CreateBatchReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
out->Reset(framework::MakeDecoratedReader<BatchReader>(
underlying_reader, Attr<int>("batch_size"),
Attr<bool>("discard_leftover")));
}
};
......@@ -57,6 +62,10 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
AddAttr<bool>("discard_leftover",
"If true, the leftover instances that are not enough for a "
"new batch will be discarded.")
.SetDefault(true);
AddComment(R"DOC(
CreateBatchReader Operator
......@@ -66,7 +75,7 @@ class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
......@@ -77,6 +86,9 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
break;
}
}
if (discard_leftover_ && buffer_.size() < batch_size_) {
buffer_.clear();
}
// Concat instances
out->clear();
if (buffer_.empty()) {
......
......@@ -33,7 +33,7 @@ class CustomReader : public framework::DecoratedReader {
source_var_names_(source_var_names),
sink_var_names_(sink_var_names) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
const framework::ProgramDesc program_;
......@@ -60,10 +60,10 @@ class CreateCustomReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new CustomReader(underlying_reader.Get(), *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
out->Reset(framework::MakeDecoratedReader<CustomReader>(
underlying_reader, *sub_block,
Attr<std::vector<std::string>>("source_var_names"),
Attr<std::vector<std::string>>("sink_var_names")));
}
};
......@@ -143,7 +143,7 @@ class CustomReaderInferVarType : public framework::VarTypeInference {
}
};
void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out->clear();
std::vector<framework::LoDTensor> underlying_outs;
reader_->ReadNext(&underlying_outs);
......
......@@ -23,13 +23,13 @@ namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 5;
static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 3; // kCacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
......@@ -50,12 +50,21 @@ class DoubleBufferReader : public framework::DecoratedReader {
StartPrefetcher();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~DoubleBufferReader() { EndPrefetcher(); }
private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
......@@ -109,7 +118,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num));
}
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>(
underlying_reader, place));
}
};
......@@ -136,7 +146,7 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
}
};
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) {
......@@ -150,12 +160,6 @@ void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
}
}
void DoubleBufferReader::ReInit() {
reader_->ReInit();
EndPrefetcher();
StartPrefetcher();
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0;
......
......@@ -24,23 +24,22 @@ class MultiPassReader : public framework::DecoratedReader {
MultiPassReader(const std::shared_ptr<ReaderBase>& reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
reader_->ReadNext(out);
if (out->empty()) {
if (out->empty() && pass_count_ < pass_num_ - 1) {
reader_->Shutdown();
reader_->Start();
reader_->ReadNext(out);
++pass_count_;
if (pass_count_ < pass_num_) {
reader_->ReInit();
reader_->ReadNext(out);
}
}
}
void ReInit() override {
private:
void StartImpl() override {
pass_count_ = 0;
reader_->ReInit();
reader_->Start();
}
private:
int pass_num_;
mutable int pass_count_;
};
......@@ -60,7 +59,8 @@ class CreateMultiPassReaderOp : public framework::OperatorBase {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
int pass_num = Attr<int>("pass_num");
out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num));
out->Reset(framework::MakeDecoratedReader<MultiPassReader>(
underlying_reader, pass_num));
}
};
......
......@@ -19,22 +19,27 @@ namespace paddle {
namespace operators {
namespace reader {
class PyReader : public framework::ReaderBase {
class PyReader : public framework::FileReader {
public:
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue) {
explicit PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
: framework::FileReader() {
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
queue_ = queue;
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
bool success;
*out = queue_->Pop(&success);
if (!success) out->clear();
}
void ReInit() override {}
private:
void ShutdownImpl() override { /* TODO */
}
void StartImpl() override { /* TODO */
}
std::shared_ptr<LoDTensorBlockingQueue> queue_;
};
......@@ -51,14 +56,14 @@ class CreatePyReaderOp : public framework::OperatorBase {
const std::string& queue_name = Input("blocking_queue");
auto* queue_holder_var = scope.FindVar(queue_name);
PADDLE_ENFORCE(
queue_holder_var != nullptr,
PADDLE_ENFORCE_NOT_NULL(
queue_holder_var,
"No LoDTensorBlockingQueueHolder variable with name %s found",
queue_name);
auto* queue_holder =
queue_holder_var->template GetMutable<LoDTensorBlockingQueueHolder>();
out->Reset(new PyReader(queue_holder->GetQueue()));
out->Reset(std::make_shared<PyReader>(queue_holder->GetQueue()));
}
};
......
......@@ -19,11 +19,11 @@ namespace operators {
namespace reader {
template <typename T>
class RandomDataGenerator : public framework::ReaderBase {
class RandomDataGenerator : public framework::FileReader {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
float high)
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) {
: framework::FileReader(), low_(low), high_(high), shapes_(shapes) {
PADDLE_ENFORCE_LE(low, high,
"'low' shouldn't be greater than 'high'.(%f vs %f)", low,
high);
......@@ -32,7 +32,7 @@ class RandomDataGenerator : public framework::ReaderBase {
dist_ = std::uniform_real_distribution<float>(low_, high_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear();
out->reserve(shapes_.size());
for (const framework::DDim& shape : shapes_) {
......@@ -51,8 +51,6 @@ class RandomDataGenerator : public framework::ReaderBase {
}
}
void ReInit() override { return; }
private:
float low_;
float high_;
......@@ -79,8 +77,8 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"),
Attr<float>("high")));
out->Reset(std::make_shared<RandomDataGenerator<T>>(
shapes, Attr<float>("low"), Attr<float>("high")));
}
};
......
......@@ -21,10 +21,8 @@ namespace reader {
template <bool ThreadSafe>
class RecordIOFileReader : public framework::FileReader {
public:
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReader(dims),
scanner_(filename),
explicit RecordIOFileReader(const std::string& filename)
: scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {
if (ThreadSafe) {
......@@ -33,8 +31,6 @@ class RecordIOFileReader : public framework::FileReader {
LOG(INFO) << "Creating file reader" << filename;
}
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
if (ThreadSafe) {
......@@ -45,6 +41,8 @@ class RecordIOFileReader : public framework::FileReader {
}
}
void StartImpl() override { scanner_.Reset(); }
private:
std::unique_ptr<std::mutex> mutex_;
recordio::Scanner scanner_;
......@@ -58,20 +56,11 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
static_cast<int>(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::string filename = Attr<std::string>("filename");
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader<true>(
filename, RestoreShapes(shape_concat, ranks)));
out->Reset(std::make_shared<RecordIOFileReader<true>>(filename));
}
};
......
......@@ -34,7 +34,7 @@ class ShuffleReader : public framework::DecoratedReader {
ReloadBuffer();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
out->clear();
if (iteration_pos_ >= buffer_.size()) {
VLOG(10) << "Resetting shuffle buffer";
......@@ -47,6 +47,17 @@ class ShuffleReader : public framework::DecoratedReader {
}
private:
void ShutdownImpl() override {
buffer_.clear();
iteration_pos_ = 0;
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
ReloadBuffer();
}
void ReloadBuffer() {
buffer_.clear();
buffer_.reserve(buffer_size_);
......@@ -86,9 +97,8 @@ class CreateShuffleReaderOp : public framework::OperatorBase {
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(
new ShuffleReader(underlying_reader.Get(),
static_cast<size_t>(Attr<int>("buffer_size"))));
out->Reset(framework::MakeDecoratedReader<ShuffleReader>(
underlying_reader, static_cast<size_t>(Attr<int>("buffer_size"))));
}
};
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ThreadedReader : public framework::DecoratedReader {
public:
explicit ThreadedReader(const std::shared_ptr<ReaderBase>& reader)
: DecoratedReader(reader) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
std::lock_guard<std::mutex> lock(mutex_);
reader_->ReadNext(out);
}
void ReInit() override { reader_->ReInit(); }
private:
std::mutex mutex_;
};
class CreateThreadedReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
auto* out = detail::Ref(scope.FindVar(Output("Out")))
.GetMutable<framework::ReaderHolder>();
if (out->Get() != nullptr) {
return;
}
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
out->Reset(new ThreadedReader(underlying_reader.Get()));
}
};
class CreateThreadedReaderOpMaker : public DecoratedReaderMakerBase {
protected:
void Apply() override {
AddComment(R"DOC(
CreateThreadedReader Operator
This operator creates a threaded reader. A threaded reader's
'ReadNext()' can be invoked by several threads at the same
time.
When the attribute 'safe_mode' is true, the threaded reader's
'ReInit()' is disabled to avoid unexpected bugs in multi-thread
environment.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace reader = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_threaded_reader,
reader::CreateThreadedReaderOp,
reader::CreateThreadedReaderOpMaker);
......@@ -23,24 +23,26 @@ namespace reader {
class MultiFileReader : public framework::ReaderBase {
public:
MultiFileReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num,
MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num,
size_t buffer_size)
: buffer_size_(buffer_size) {
readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) {
readers_.emplace_back(CreateReaderByFileName(f_name, dims));
readers_.emplace_back(CreateReaderByFileName(f_name));
}
prefetchers_.resize(thread_num);
StartNewScheduler();
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
void ReInit() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~MultiFileReader() { EndScheduler(); }
private:
void ShutdownImpl() override { EndScheduler(); }
void StartImpl() override { StartNewScheduler(); }
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
......@@ -55,17 +57,12 @@ class MultiFileReader : public framework::ReaderBase {
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
};
void MultiFileReader::ReadNext(std::vector<framework::LoDTensor>* out) {
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
if (!buffer_->Receive(out)) {
out->clear();
}
}
void MultiFileReader::ReInit() {
EndScheduler();
StartNewScheduler();
}
void MultiFileReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size());
......@@ -120,7 +117,7 @@ void MultiFileReader::ScheduleThreadFunc() {
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
......@@ -138,7 +135,8 @@ void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (ins.empty()) {
reader->ReInit();
reader->Shutdown();
reader->Start();
break;
}
try {
......@@ -180,9 +178,8 @@ class OpenFilesOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultiFileReader(file_names,
RestoreShapes(shape_concat, ranks),
thread_num, buffer_size));
out->Reset(
std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
}
};
......
......@@ -39,7 +39,7 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
const std::string& file_name) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
......@@ -49,7 +49,7 @@ std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
framework::ReaderBase* reader = (itor->second)(file_name);
return std::unique_ptr<framework::ReaderBase>(reader);
}
......
......@@ -25,22 +25,21 @@ namespace reader {
static constexpr char kFileFormatSeparator[] = ".";
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
using FileReaderCreator =
std::function<framework::ReaderBase*(const std::string&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dims);
FileReaderRegistry()[filetype] = [](const std::string& fn) {
return new Reader(fn);
};
return 0;
}
std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);
const std::string& file_name);
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......
......@@ -296,7 +296,7 @@ All parameter, weight, gradient are variables in Paddle.
py::return_value_policy::reference);
py::class_<framework::ReaderHolder>(m, "Reader", "")
.def("reset", &framework::ReaderHolder::ReInit);
.def("reset", &framework::ReaderHolder::ResetAll);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
......
......@@ -375,9 +375,6 @@ def open_recordio_file(filename,
if pass_num > 1:
main_prog_var = multi_pass(reader=main_prog_var, pass_num=pass_num)
if for_parallel:
main_prog_var = parallel(reader=main_prog_var)
return monkey_patch_reader_methods(main_prog_var)
......@@ -529,9 +526,6 @@ def open_files(filenames,
main_prog_reader = multi_pass(
reader=main_prog_reader, pass_num=pass_num)
if for_parallel:
main_prog_reader = parallel(reader=main_prog_reader)
return monkey_patch_reader_methods(main_prog_reader)
......@@ -647,11 +641,6 @@ def multi_pass(reader, pass_num):
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def parallel(reader):
return __create_shared_decorated_reader__('create_threaded_reader', reader,
{})
def read_file(reader):
"""
Execute the given reader and get data via it.
......
......@@ -14,10 +14,11 @@
import paddle
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard, default_main_program, default_startup_program
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import MomentumOptimizer
import paddle.fluid.core as core
import paddle.fluid as fluid
import unittest
import numpy as np
......@@ -31,14 +32,13 @@ class TestMNISTIfElseOp(unittest.TestCase):
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant_batch_size_like(
input=label, dtype='int64', shape=[1], value=5.0)
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = layers.less_than(x=label, y=limit)
true_image, false_image = layers.split_lod_tensor(
input=image, mask=cond)
true_out = layers.create_tensor(dtype='float32')
true_cond = layers.ConditionalBlock([true_image])
true_cond = layers.ConditionalBlock([cond])
with true_cond.block():
hidden = layers.fc(input=true_image, size=100, act='tanh')
......@@ -46,7 +46,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
layers.assign(input=prob, output=true_out)
false_out = layers.create_tensor(dtype='float32')
false_cond = layers.ConditionalBlock([false_image])
false_cond = layers.ConditionalBlock([cond])
with false_cond.block():
hidden = layers.fc(input=false_image, size=200, act='tanh')
......@@ -64,7 +64,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=200)
batch_size=10)
place = core.CPUPlace()
exe = Executor(place)
......@@ -94,8 +94,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant_batch_size_like(
input=label, dtype='int64', shape=[1], value=5.0)
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = layers.less_than(x=label, y=limit)
ie = layers.IfElse(cond)
......@@ -125,7 +124,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
place = core.CPUPlace()
exe = Executor(place)
exe.run(kwargs['startup_program'])
exe.run(startup_prog)
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
......@@ -133,7 +132,7 @@ class TestMNISTIfElseOp(unittest.TestCase):
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape((y_data.shape[0], 1))
outs = exe.run(kwargs['main_program'],
outs = exe.run(prog,
feed={'x': x_data,
'y': y_data},
fetch_list=[avg_loss])
......@@ -143,6 +142,67 @@ class TestMNISTIfElseOp(unittest.TestCase):
self.assertFalse(True)
class TestIfElse(unittest.TestCase):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 0.5
self.data = np.random.rand(25, 1).astype(np.float32)
def compare_ifelse_op_and_numpy(self, place):
self.set_test_case()
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
src = layers.data(name='data', shape=[1], dtype='float32')
cond = layers.fill_constant(
[1], dtype='float32', value=self.cond_value)
ifcond = layers.less_than(x=src, y=cond)
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
ie.output(true_target)
with ie.false_block():
false_target = ie.input(src)
ie.output(false_target)
if_out = ie()
out = layers.reduce_sum(if_out)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fetch_list = [out]
o1, = exe.run(fluid.default_main_program(),
feed={'data': self.data},
fetch_list=[out])
o2 = np.sum(self.data)
self.assertTrue(
np.allclose(
o1, o2, atol=1e-8),
"IfElse result : " + str(o1) + "\n Numpy result :" + str(o2))
def test_cpu(self):
self.compare_ifelse_op_and_numpy(fluid.CPUPlace())
def test_cuda(self):
if not core.is_compiled_with_cuda():
return
self.compare_ifelse_op_and_numpy(fluid.CUDAPlace(0))
class TestIfElseTrueBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 10.
self.data = np.random.rand(25, 1).astype(np.float32)
class TestIfElseFalseBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = -10.
self.data = np.random.rand(25, 1).astype(np.float32)
if __name__ == '__main__':
# temp disable if else unittest since it could be buggy.
exit(0)
unittest.main()
......@@ -40,7 +40,6 @@ class TestFakeDequantizeMaxAbsOp(OpTest):
self.op_type = "fake_dequantize_max_abs"
x = np.random.randn(31, 65).astype("float32")
yq, scale = quantize_max_abs(x, self.num_bits)
print 'scale ', scale
ydq = dequantize_max_abs(yq, self.num_bits, scale)
self.inputs = {'X': yq}
......
......@@ -113,7 +113,9 @@ class BaseParallelForTest(unittest.TestCase):
generator = callback()
# Automatically insert parallel do if use_parallel = True
if use_parallel:
places = fluid.layers.get_places()
thread_num = fluid.core.get_cuda_device_count(
) if use_gpu else 8
places = fluid.layers.get_places(thread_num)
pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
data = next(generator)
......
from setuptools import setup, Distribution, Extension
import subprocess
import shutil
import os
import re
import shutil
class BinaryDistribution(Distribution):
def has_ext_modules(foo):
return True
MAJOR = 0
MINOR = 14
PATCH = 0
RC = 0
ISTAGED = False
......@@ -22,14 +19,47 @@ def git_commit():
git_commit = 'Unknown'
return git_commit
def _get_version_detail(idx):
assert idx < 3, "vesion info consists of %(major)d.%(minor)d.%(patch)d, \
so detail index must less than 3"
if re.match('@TAG_VERSION_REGEX@', '@PADDLE_VERSION@'):
version_details = '@PADDLE_VERSION@'.split('.')
if len(version_details) == 3:
return version_details[idx]
return 0
def get_major():
return int(_get_version_detail(0))
def get_minor():
return int(_get_version_detail(1))
def get_patch():
return str(_get_version_detail(2))
def is_taged():
try:
cmd = ['git', 'describe', '--exact-match', '--tags']
git_tag = subprocess.Popen(cmd, stdout = subprocess.PIPE).communicate()[0].strip()
except:
return False
if git_tag.replace('v', '') == '@PADDLE_VERSION@':
return True
else:
return False
def write_version_py(filename='paddle/version.py'):
cnt = '''
# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY
#
full_version = '%(major)d.%(minor)d.%(patch)d'
full_version = '%(major)d.%(minor)d.%(patch)s'
major = '%(major)d'
minor = '%(minor)d'
patch = '%(patch)d'
patch = '%(patch)s'
rc = '%(rc)d'
istaged = %(istaged)s
commit = '%(commit)s'
......@@ -51,13 +81,13 @@ def mkl():
commit = git_commit()
with open(filename, 'w') as f:
f.write(cnt % {
'major': MAJOR,
'minor': MINOR,
'patch': PATCH,
'major': get_major(),
'minor': get_minor(),
'patch': get_patch(),
'rc': RC,
'version': '${PADDLE_VERSION}',
'commit': commit,
'istaged': ISTAGED,
'istaged': is_taged(),
'with_mkl': '@WITH_MKL@'})
write_version_py(filename='@PADDLE_BINARY_DIR@/python/paddle/version.py')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册