提交 a8c076e5 编写于 作者: Y Yu Yang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/shuffle_reader

......@@ -244,11 +244,11 @@ function(cc_test TARGET_NAME)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS})
# Support linking flags: --whole-archive (Linux) / -force_load (MacOS)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
target_circle_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
if("${cc_test_DEPS}" MATCHES "ARCHIVE_START")
list(REMOVE_ITEM cc_test_DEPS ARCHIVE_START ARCHIVE_END)
endif()
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
......@@ -311,8 +311,8 @@ function(nv_test TARGET_NAME)
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags glog)
add_test(${TARGET_NAME} ${TARGET_NAME})
endif()
endfunction(nv_test)
......
......@@ -59,6 +59,17 @@ After converting:
queue. It will block until the queue has the required number of
tensors.
### Sparse Update
For embedding layers, the gradient may have many rows containing only 0 when training,
if the gradient uses a dense tensor to do parameter optimization,
it could spend unnecessary memory, slow down the calculations and waste
the bandwidth while doing distributed training.
In Fluid, we introduce [SelectedRows](../selected_rows.md) to represent a list of rows containing
non-zero gradient data. So when we do parameter optimization both locally and remotely,
we only need to send those non-zero rows to the optimizer operators:
<img src="src/sparse_update.png" width="700" />
### Benefits
......@@ -91,6 +102,6 @@ After converting:
`min_count` attribute), does our current design support it? (similar
question for the *Add* OP)
### References
### References:
[1] [TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
开发标准
========
PaddlePaddle遵守如下三个部分的代码和文档规范。
PaddlePaddle使用git做版本管理,docker作为构建和测试环境。代码中包含了Cuda, C++, Python, Shell等多种编程语言。语言规范遵守Google C++ Style, Pep-8, 代码库中包含自动化检查工具做风格检查。代码注释需要遵守Doxygen规范,不满足风格要求的代码会编译失败。关于如何使用git, 构建测试及代码开发, 我们提供了如下指南。
.. toctree::
:maxdepth: 1
contribute_to_paddle_cn.md
PaddlePaddle面向国内外用户,包含了中文和英文两部分的文档。设计文档和issue问题描述都推荐使用英文。对于设计文档,重在问题描述,背景阐述,然后才是解决方案。文档由Sphinx生成,因此代码注释也需要符合Sphinx文档标准。推荐本地使用paddlepaddle.org工具编译生成和预览文档,请参阅如下文档。
.. toctree::
:maxdepth: 1
write_docs_cn.rst
PaddlePaddle V2 使用新增Layer方式定义新的操作。组合基础API可以实现多种复杂Layer, 满足绝大多数应用。如需要定制Layer,请参阅如下文档,欢迎提交patch。
.. toctree::
:maxdepth: 1
new_layer_cn.rst
......@@ -71,6 +71,13 @@ paddle.init(
- trainer_id:**必选,默认0**,每个trainer的唯一ID,从0开始的整数
- pservers:**必选,默认127.0.0.1**,当前训练任务启动的pserver的IP列表,多个IP使用“,”隔开
```python
trainer = paddle.trainer.SGD(..., is_local=False)
```
参数说明
- is_local: **必选, 默认True**, 是否使用PServer更新参数
## 准备数据集
......
......@@ -73,6 +73,14 @@ Parameter Description
- trainer_id: **required, default 0**, ID for every trainer, start from 0.
- pservers: **required, default 127.0.0.1**, list of IPs of parameter servers, separated by ",".
```python
trainer = paddle.trainer.SGD(..., is_local=False)
```
Parameter Description
- is_local: **required, default True**, whether update parameters by PServer.
## Prepare Training Dataset
Here's some example code [prepare.py](https://github.com/PaddlePaddle/Paddle/tree/develop/doc/howto/usage/cluster/src/word2vec/prepare.py), it will download public `imikolov` dataset and split it into multiple files according to job parallelism(trainers count). Modify `SPLIT_COUNT` at the begining of `prepare.py` to change the count of output files.
......
......@@ -135,6 +135,14 @@ OpDesc *BlockDesc::PrependOp() {
return ops_.front().get();
}
OpDesc *BlockDesc::InsertOp(size_t index) {
need_update_ = true;
auto it = ops_.begin() + index;
std::unique_ptr<OpDesc> new_op(new OpDesc(this));
it = ops_.insert(it, std::move(new_op));
return (*it).get();
}
void BlockDesc::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return;
......
......@@ -87,6 +87,8 @@ class BlockDesc {
OpDesc *PrependOp();
OpDesc *InsertOp(size_t index);
void RemoveOp(size_t s, size_t e);
std::vector<OpDesc *> AllOps() const;
......
......@@ -34,6 +34,15 @@ DEFINE_bool(check_nan_inf, false,
namespace paddle {
namespace framework {
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
framework::ProgramDesc prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
Executor::Executor(const platform::Place& place) : place_(place) {}
static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
......@@ -85,73 +94,9 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), pdesc.Size());
auto& block = pdesc.Block(block_id);
Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
}
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
}
if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
}
if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------";
VLOG(2) << "Memory used after deleting local scope: "
<< memory::memory_usage(place_);
VLOG(2) << "-------------------------------------------------------";
}
auto* ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx, scope, create_local_scope, create_vars);
delete ctx;
}
// Check whether the block already has feed operators and feed_holder.
......@@ -313,5 +258,81 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
delete copy_program;
}
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id);
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
return ctx;
}
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);
Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_);
}
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
}
if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
}
if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------";
VLOG(2) << "Memory used after deleting local scope: "
<< memory::memory_usage(place_);
VLOG(2) << "-------------------------------------------------------";
}
}
} // namespace framework
} // namespace paddle
......@@ -22,7 +22,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
struct ExecutorPrepareContext;
class Executor {
public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed
......@@ -38,8 +38,8 @@ class Executor {
* ProgramDesc
* Scope
*/
void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
bool create_vars = true);
void Run(const ProgramDesc& prog, Scope* scope, int block_id,
bool create_local_scope = true, bool create_vars = true);
void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
......@@ -47,6 +47,13 @@ class Executor {
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
int block_id);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
bool create_vars = true);
private:
const platform::Place place_;
};
......
......@@ -74,6 +74,9 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
platform::SetDeviceId(dev_id);
#endif
}
// profile
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
RunImpl(scope, place);
}
......@@ -481,9 +484,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto dev_ctx = pool.Get(place);
// profile
platform::RecordEvent record_event(Type(), dev_ctx);
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
......
......@@ -17,5 +17,57 @@
namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}
std::vector<std::unique_ptr<ReaderBase>> ReaderBase::SplitReader(
const platform::PlaceList &places) {
std::vector<std::unique_ptr<ReaderBase>> readers;
auto mutex = std::make_shared<std::mutex>();
for (size_t i = 0; i < places.size(); ++i) {
readers.emplace_back(new ThreadSafeReader(this, mutex));
}
return readers;
}
void ThreadSafeReader::ReadNext(std::vector<LoDTensor> *out) {
std::lock_guard<std::mutex> guard(*mutex_);
reader_->ReadNext(out);
}
void ThreadSafeReader::ReInit() {
std::lock_guard<std::mutex> guard(*mutex_);
reader_->ReInit();
}
bool ThreadSafeReader::HasNext() const {
std::lock_guard<std::mutex> guard(*mutex_);
return reader_->HasNext();
}
std::vector<std::unique_ptr<ReaderBase>> ThreadSafeReader::SplitReader(
const platform::PlaceList &places) {
std::vector<std::unique_ptr<ReaderBase>> readers;
for (size_t i = 0; i < places.size(); ++i) {
readers.emplace_back(new ThreadSafeReader(reader_, mutex_));
}
return readers;
}
FileReaderBase::FileReaderBase(const std::vector<DDim> &dims) : dims_(dims) {}
void FileReaderBase::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];
PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
} // namespace framework
} // namespace paddle
......@@ -16,6 +16,11 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <thread>
#include <vector>
namespace paddle {
namespace framework {
......@@ -28,6 +33,9 @@ class ReaderBase {
virtual bool HasNext() const = 0;
virtual std::vector<std::unique_ptr<ReaderBase>> SplitReader(
const platform::PlaceList& places);
virtual ~ReaderBase();
};
......@@ -45,6 +53,37 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_;
};
class ThreadSafeReader : public DecoratedReader {
public:
ThreadSafeReader(ReaderBase* reader, const std::shared_ptr<std::mutex>& mutex)
: DecoratedReader(reader), mutex_(mutex) {}
void ReadNext(std::vector<LoDTensor>* out) override;
void ReInit() override;
bool HasNext() const override;
std::vector<std::unique_ptr<ReaderBase>> SplitReader(
const platform::PlaceList& places) override;
private:
std::shared_ptr<std::mutex> mutex_;
};
class FileReaderBase : public ReaderBase {
public:
explicit FileReaderBase(const std::vector<DDim>& dims);
void ReadNext(std::vector<LoDTensor>* out) override;
protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
private:
std::vector<DDim> dims_;
};
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
......@@ -53,8 +92,14 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); }
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
void ReInit() { reader_->ReInit(); }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
}
bool HasNext() const { return reader_->HasNext(); }
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"
......@@ -102,6 +103,18 @@ void Scope::DeleteScope(Scope* scope) {
}
}
void Scope::EraseVars(std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
delete it->second;
it = vars_.erase(it);
} else {
++it;
}
}
}
void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const {
auto origin_it = vars_.find(origin_name);
......
......@@ -51,6 +51,8 @@ class Scope {
/// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr);
void EraseVars(std::vector<std::string>& var_names);
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;
......
......@@ -115,11 +115,11 @@ void TestInference(const std::string& dirname,
#endif
}
// Enable the profiler
paddle::platform::EnableProfiler(state);
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
// Enable the profiler
paddle::platform::EnableProfiler(state);
{
paddle::platform::RecordEvent record_event(
"init_program",
......@@ -143,6 +143,10 @@ void TestInference(const std::string& dirname,
inference_program = paddle::inference::Load(executor, *scope, dirname);
}
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
"load_program_profiler.txt");
paddle::platform::ResetProfiler();
// 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names =
......@@ -165,6 +169,12 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
// Enable the profiler
paddle::platform::EnableProfiler(state);
// Run repeat times to profile the performance
for (int i = 0; i < repeat; ++i) {
paddle::platform::RecordEvent record_event(
......@@ -173,12 +183,13 @@ void TestInference(const std::string& dirname,
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
"profiler.txt");
paddle::platform::ResetProfiler();
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(
paddle::platform::EventSortingKey::kDefault,
"run_inference_profiler.txt");
paddle::platform::ResetProfiler();
}
delete scope;
}
......@@ -222,8 +222,6 @@ cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
if(WITH_GPU)
cc_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif()
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class DeleteVarOp : public framework::OperatorBase {
public:
DeleteVarOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
dev_ctx.Wait();
auto delete_var_names = Inputs("X");
const_cast<framework::Scope &>(scope).EraseVars(delete_var_names);
}
};
class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
public:
DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of delete op").AsDuplicable();
AddComment(R"DOC(
Delete Operator.
It should not be configured by users directly.
)DOC");
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(delete_var, paddle::operators::DeleteVarOp,
paddle::framework::EmptyGradOpMaker,
paddle::operators::DeleteVarOpInfoMaker);
......@@ -85,4 +85,4 @@ google::protobuf::int64 GrpcByteBufferSource::ByteCount() const {
} // namespace detail
} // namespace operators
} // namespace paddle
\ No newline at end of file
} // namespace paddle
......@@ -82,7 +82,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
DestroyCallback destroy_callback = [](void* backing) {};
void* buf = malloc(1024);
void* payload;
void* payload = nullptr;
size_t payload_size;
ProtoEncodeHelper e((char*)buf, 1024);
e.WriteString(VarMsg::kVarnameFieldNumber, name);
......@@ -297,4 +297,4 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
} // namespace detail
} // namespace operators
} // namespace paddle
\ No newline at end of file
} // namespace paddle
......@@ -273,7 +273,6 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
std::map<int, std::vector<std::pair<T, int>>>& false_pos,
const int class_num) const {
constexpr T kEPS = static_cast<T>(1e-6);
const int* pos_count_data = input_pos_count.data<int>();
for (int i = 0; i < class_num; ++i) {
label_pos_count[i] = pos_count_data[i];
......@@ -282,12 +281,11 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
auto SetData = [](const framework::LoDTensor& pos_tensor,
std::map<int, std::vector<std::pair<T, int>>>& pos) {
const T* pos_data = pos_tensor.data<T>();
auto pos_data_lod = pos_tensor.lod();
for (size_t i = 0; i < pos_data_lod.size(); ++i) {
for (size_t j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
auto pos_data_lod = pos_tensor.lod()[0];
for (size_t i = 0; i < pos_data_lod.size() - 1; ++i) {
for (size_t j = pos_data_lod[i]; j < pos_data_lod[i + 1]; ++j) {
T score = pos_data[j * 2];
int flag = 1;
if (pos_data[j * 2 + 1] < kEPS) flag = 0;
int flag = pos_data[j * 2 + 1];
pos[i].push_back(std::make_pair(score, flag));
}
}
......
......@@ -29,8 +29,11 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OPERATOR(elementwise_add, ops::ElementwiseOp,
ops::ElementwiseAddOpMaker, ops::ElementwiseOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -41,6 +41,16 @@ class ElementwiseOp : public framework::OperatorWithKernel {
}
};
class ElementwiseOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto x_var = op_desc.Input("X")[0];
auto out_var = op_desc.Output("Out")[0];
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
}
};
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
......@@ -45,6 +45,9 @@ void gemm<platform::CUDADeviceContext, float16>(
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, N));
......@@ -106,6 +109,9 @@ void gemm<platform::CUDADeviceContext, float16>(
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
h_A, lda, &h_beta, h_C, ldc));
......@@ -251,6 +257,9 @@ void batched_gemm<platform::CUDADeviceContext, float16>(
const half* h_B = reinterpret_cast<const half*>(B);
half* h_C = reinterpret_cast<half*>(C);
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas Hgemm requires GPU compute capability >= 53");
PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount));
......
......@@ -72,6 +72,11 @@ TEST(math_function, notrans_mul_trans_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
......@@ -149,6 +154,11 @@ TEST(math_function, trans_mul_notrans_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place);
fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5});
......@@ -248,6 +258,11 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
int m = 2;
int n = 3;
int k = 3;
......@@ -355,6 +370,11 @@ TEST(math_function, gemm_trans_cublas_fp16) {
CUDAPlace gpu_place(0);
CUDADeviceContext context(gpu_place);
// fp16 GEMM in cublas requires GPU compute capability >= 53
if (context.GetComputeCapability() < 53) {
return;
}
int m = 2;
int n = 3;
int k = 3;
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
namespace paddle {
namespace operators {
......
......@@ -14,19 +14,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -41,26 +37,35 @@ USE_CUDA_ONLY_OP(ncclBcast);
namespace f = paddle::framework;
namespace p = paddle::platform;
static std::vector<int> gpu_list;
// test data amount
const f::DDim kDims = {100, 100};
const f::DDim kDims = {20, 20};
// nccl op common tester, init communicator.
class NCCLTester : public ::testing::Test {
public:
virtual void SetUp() override {
int count = p::GetCUDADeviceCount();
if (count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< count;
exit(0);
}
for (int i = 0; i < count; ++i) {
gpu_list_.emplace_back(i);
}
paddle::platform::CPUPlace cpu_place;
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
p::CUDAPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
dev_ctxs_.emplace_back(new p::CUDADeviceContext(place));
}
NCCLInitOp();
}
virtual void TearDown() override {
for (auto &device_context : dev_ctxs) {
for (auto &device_context : dev_ctxs_) {
delete device_context;
}
}
......@@ -70,36 +75,40 @@ class NCCLTester : public ::testing::Test {
std::unique_ptr<f::OpDesc> op1(new f::OpDesc);
op1->SetType("ncclInit");
op1->SetInput("parallel_scopes", {"p_scopes"});
op1->SetOutput("Communicator", {"comm"});
op1->SetAttr("gpus", {gpu_list});
auto *var = g_scope.Var("comm");
auto *var = g_scope_.Var("comm");
var->GetMutable<p::Communicator>();
auto *scope_var = g_scope_.Var("p_scopes");
auto *p_scopes = scope_var->GetMutable<std::vector<f::Scope *>>();
(*p_scopes).resize(gpu_list_.size());
auto op = f::OpRegistry::CreateOp(*op1);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, cpu_place);
op->Run(g_scope_, cpu_place);
VLOG(1) << "NCCLInitOp finished.";
}
int GetGPUData(int gpu_id) { return gpu_id + 42; }
template <class T>
void PerThreadProgram(int gpu_id, const f::OpDesc &op_desc, f::Scope *scope) {
std::unique_lock<std::mutex> lk(mu);
std::unique_lock<std::mutex> lk(mu_);
const f::OpDesc *op1 = &op_desc;
p::CUDAPlace place(gpu_id);
auto &ctx = dev_ctxs.at(gpu_id);
auto &ctx = dev_ctxs_.at(gpu_id);
auto *send_tensor = scope->Var("st")->GetMutable<f::LoDTensor>();
auto *recv_tensor = scope->Var("rt")->GetMutable<f::LoDTensor>();
if (!send_tensor->numel()) {
send_tensor->Resize(kDims);
send_tensor->mutable_data<T>(kDims, place);
std::vector<T> send_vector(f::product(kDims), gpu_id);
std::vector<T> send_vector(f::product(kDims), GetGPUData(gpu_id));
paddle::framework::TensorFromVector<T>(send_vector, *ctx, send_tensor);
ctx->Wait();
VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel();
}
......@@ -118,30 +127,14 @@ class NCCLTester : public ::testing::Test {
}
public:
std::vector<p::DeviceContext *> dev_ctxs;
f::Scope g_scope;
std::mutex mu;
std::vector<p::DeviceContext *> dev_ctxs_;
f::Scope g_scope_;
std::mutex mu_;
std::vector<int> gpu_list_;
};
// ncclInitOp with desc
TEST(NCCL, ncclInitOp) {
std::unique_ptr<f::OpDesc> op_desc(new f::OpDesc);
op_desc->SetType("ncclInit");
op_desc->SetOutput("Communicator", {"x1"});
op_desc->SetAttr("gpus", {gpu_list});
f::Scope g_scope;
paddle::platform::CPUPlace cpu_place;
auto *var = g_scope.Var("x1");
var->GetMutable<p::Communicator>();
auto op = f::OpRegistry::CreateOp(*op_desc);
VLOG(1) << "invoke NCCLInitOp.";
op->Run(g_scope, cpu_place);
VLOG(1) << "NCCLInitOp finished.";
}
TEST_F(NCCLTester, ncclInitOp) {}
// ncclAllReduceOp with desc
TEST_F(NCCLTester, ncclAllReduceOp) {
......@@ -155,23 +148,25 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
for (size_t i = 0; i < gpu_list_.size(); ++i) {
dev_scopes.emplace_back(&g_scope_.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list_[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
ths[i].join();
}
// check results
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
float expected_result = 0.0;
for (int gpu_id : gpu_list_) {
expected_result = expected_result + GetGPUData(gpu_id);
}
for (size_t i = 0; i < dev_scopes.size(); ++i) {
p::CPUPlace cpu_place;
p::CUDAPlace gpu_place(gpu_list[i]);
p::CUDAPlace gpu_place(gpu_list_[i]);
auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -180,12 +175,12 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::CUDAPlace(gpu_list[i]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list_[i]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[i])->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[i])->stream());
for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
ASSERT_NEAR(ct[j], expected_result, 1e-5);
}
}
}
......@@ -204,22 +199,24 @@ TEST_F(NCCLTester, ncclReduceOp) {
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
for (size_t i = 0; i < gpu_list_.size(); ++i) {
dev_scopes.emplace_back(&g_scope_.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list_[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
ths[i].join();
}
// check results on
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
float expected_result = 0.0;
for (int gpu_id : gpu_list_) {
expected_result = expected_result + GetGPUData(gpu_id);
}
p::CPUPlace cpu_place;
p::CUDAPlace gpu_place(gpu_list[kRoot]);
p::CUDAPlace gpu_place(gpu_list_[kRoot]);
auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -229,12 +226,12 @@ TEST_F(NCCLTester, ncclReduceOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::CUDAPlace(gpu_list[kRoot]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list_[kRoot]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[kRoot])->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[kRoot])->stream());
for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
ASSERT_NEAR(ct[j], expected_result, 1e-5);
}
}
......@@ -252,23 +249,22 @@ TEST_F(NCCLTester, ncclBcastOp) {
std::vector<std::thread> ths;
for (size_t i = 0; i < gpu_list.size(); ++i) {
dev_scopes.emplace_back(&g_scope.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list[i],
for (size_t i = 0; i < gpu_list_.size(); ++i) {
dev_scopes.emplace_back(&g_scope_.NewScope());
std::thread th(&NCCLTester::PerThreadProgram<float>, this, gpu_list_[i],
*op2.get(), dev_scopes[i]);
ths.emplace_back(std::move(th));
}
for (size_t i = 0; i < gpu_list.size(); ++i) {
for (size_t i = 0; i < gpu_list_.size(); ++i) {
ths[i].join();
}
const int idx = 1;
// check results on
float result = kRoot;
float result = GetGPUData(kRoot);
p::CPUPlace cpu_place;
p::CUDAPlace gpu_place(gpu_list[idx]);
p::CUDAPlace gpu_place(gpu_list_[idx]);
auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -277,42 +273,11 @@ TEST_F(NCCLTester, ncclBcastOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::CUDAPlace(gpu_list[idx]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list_[idx]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[idx])->stream());
static_cast<p::CUDADeviceContext *>(dev_ctxs_[idx])->stream());
for (int64_t j = 0; j < f::product(kDims); ++j) {
ASSERT_NEAR(ct[j], result, 1e-5);
}
}
int main(int argc, char **argv) {
// FIXME(tonyyang-svail):
// Due to the driver issue on our CI, disable for now
return 0;
const int dev_count = p::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::CUDAPlace(i));
gpu_list.emplace_back(i);
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
// device context should be release before scope.
// otherwise driver will down.
return RUN_ALL_TESTS();
}
......@@ -111,7 +111,8 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
});
AddAttr<std::vector<float>>(
"max_sizes",
"(vector<float>) List of max sizes of generated prior boxes.");
"(vector<float>) List of max sizes of generated prior boxes.")
.SetDefault(std::vector<float>{});
AddAttr<std::vector<float>>(
"aspect_ratios",
"(vector<float>) List of aspect ratios of generated prior boxes.");
......
......@@ -97,9 +97,6 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
boxes->mutable_data<T>(ctx.GetPlace());
vars->mutable_data<T>(ctx.GetPlace());
T inv_img_width = 1.0 / img_width;
T inv_img_height = 1.0 / img_height;
auto e_boxes = framework::EigenTensor<T, 4>::From(*boxes);
for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) {
......@@ -110,36 +107,30 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
// first prior: aspect_ratio = 1, size = min_size
box_width = box_height = min_size;
box_width = box_height = min_size / 2.;
// xmin
e_boxes(h, w, idx, 0) = (center_x - box_width * 0.5) * inv_img_width;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
// ymin
e_boxes(h, w, idx, 1) =
(center_y - box_height * 0.5) * inv_img_height;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
// xmax
e_boxes(h, w, idx, 2) = (center_x + box_width * 0.5) * inv_img_width;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
// ymax
e_boxes(h, w, idx, 3) =
(center_y + box_height * 0.5) * inv_img_height;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
// second prior: aspect_ratio = 1,
// size = sqrt(min_size * max_size)
box_width = box_height = sqrt(min_size * max_size);
box_width = box_height = sqrt(min_size * max_size) / 2.;
// xmin
e_boxes(h, w, idx, 0) =
(center_x - box_width * 0.5) * inv_img_width;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
// ymin
e_boxes(h, w, idx, 1) =
(center_y - box_height * 0.5) * inv_img_height;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
// xmax
e_boxes(h, w, idx, 2) =
(center_x + box_width * 0.5) * inv_img_width;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
// ymax
e_boxes(h, w, idx, 3) =
(center_y + box_height * 0.5) * inv_img_height;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
......@@ -149,20 +140,16 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width = min_size * sqrt(ar);
box_height = min_size / sqrt(ar);
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
// xmin
e_boxes(h, w, idx, 0) =
(center_x - box_width * 0.5) * inv_img_width;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
// ymin
e_boxes(h, w, idx, 1) =
(center_y - box_height * 0.5) * inv_img_height;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
// xmax
e_boxes(h, w, idx, 2) =
(center_x + box_width * 0.5) * inv_img_width;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
// ymax
e_boxes(h, w, idx, 3) =
(center_y + box_height * 0.5) * inv_img_height;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
}
......
......@@ -18,22 +18,24 @@
namespace paddle {
namespace operators {
namespace reader {
class RecordIOFileReader : public framework::ReaderBase {
class RecordIOFileReader : public framework::FileReaderBase {
public:
explicit RecordIOFileReader(const std::string& filename)
: ReaderBase(),
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReaderBase(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
bool HasNext() const override { return scanner_.HasNext(); }
void ReInit() override { scanner_.Reset(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}
private:
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
......@@ -57,7 +59,8 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename));
out->Reset(
new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks)));
}
};
......@@ -83,3 +86,5 @@ namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker);
REGISTER_FILE_READER(recordio, reader::RecordIOFileReader);
......@@ -31,6 +31,11 @@ std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
return res;
}
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
static std::unordered_map<std::string, FileReaderCreator> regs;
return regs;
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
......@@ -49,6 +54,10 @@ FileReaderMakerBase::FileReaderMakerBase(
}
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(
!ctx->IsRuntime(),
"'FileReaderInferShape' should only be invoked during compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
......@@ -56,16 +65,14 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);
if (ctx->IsRuntime()) {
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
......@@ -77,19 +84,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
void DecoratedReaderInferShape::operator()(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'DecoratedReaderInferShape' should only be invoked during "
"compile time.");
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
if (ctx->IsRuntime()) {
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
......
......@@ -21,6 +21,20 @@ namespace paddle {
namespace operators {
namespace reader {
using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;
std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
};
return 0;
}
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
......@@ -73,3 +87,15 @@ class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
paddle::operators::reader::DecoratedReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::DecoratedReaderInferVarType)
#define REGISTER_FILE_READER(_filetype, _reader) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
_reg_file_reader_##_filetype, \
"Must use REGISTER_FILE_READER in global namespace"); \
int TouchFileReader##_filetype() { return 0; } \
int _reg_file_reader_entry_##filetype = \
paddle::operators::reader::RegisterFileReader<_reader>(#_filetype)
#define USE_FILE_READER(filetype) \
extern int TouchFileReader##filetype(); \
static int _use_##filetype = TouchFileReader##filetype()
......@@ -48,7 +48,6 @@ nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
nv_test(nccl_test SRCS nccl_test.cu DEPS dynload_cuda gpu_info device_context)
cc_library(device_tracer SRCS device_tracer.cc DEPS profiler_proto ${GPU_CTX_DEPS})
cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer)
......
......@@ -127,6 +127,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
SetDeviceId(place_.device);
compute_capability = GetCUDAComputeCapability(place_.device);
multi_process = GetCUDAMultiProcessors(place_.device);
max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
......@@ -162,6 +163,10 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaGetLastError());
}
int CUDADeviceContext::GetComputeCapability() const {
return compute_capability;
}
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
return multi_process * max_threads_per_mp;
}
......
......@@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return place in the device context. */
Place GetPlace() const override;
/*! \brief Return compute capability in the device context. */
int GetComputeCapability() const;
/*! \brief Return the max physical thread count in the device context */
int GetMaxPhysicalThreadCount() const;
......@@ -104,6 +107,7 @@ class CUDADeviceContext : public DeviceContext {
cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_;
int compute_capability;
int multi_process;
int max_threads_per_mp;
};
......
......@@ -33,6 +33,15 @@ int GetCUDADeviceCount() {
return count;
}
int GetCUDAComputeCapability(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
cudaDeviceProp device_prop;
PADDLE_ENFORCE(cudaGetDeviceProperties(&device_prop, id),
"cudaGetDeviceProperties failed in "
"paddle::platform::GetCUDAComputeCapability");
return device_prop.major * 10 + device_prop.minor;
}
int GetCUDAMultiProcessors(int id) {
PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
int count;
......
......@@ -30,6 +30,9 @@ const std::string kEnvFractionGpuMemoryToUse =
//! Get the total number of GPU devices in system.
int GetCUDADeviceCount();
//! Get the compute capability of the ith GPU (format: major * 10 + minor)
int GetCUDAComputeCapability(int i);
//! Get the MultiProcessors of the ith GPU.
int GetCUDAMultiProcessors(int i);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <memory>
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
static int dev_count = 0;
namespace paddle {
namespace platform {
TEST(NCCL, init) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
}
template <typename T>
struct PerThreadData {
thrust::device_vector<T> send_buff;
thrust::device_vector<T> recv_buff;
CUDADeviceContext dev_ctx;
T* SendBuff() { return thrust::raw_pointer_cast(send_buff.data()); }
T* RecvBuff() { return thrust::raw_pointer_cast(recv_buff.data()); }
PerThreadData(int gpu_id, size_t size) : dev_ctx(CUDAPlace(gpu_id)) {
send_buff.resize(size);
for (size_t i = 0; i < size; ++i) {
send_buff[i] = static_cast<T>(i);
}
recv_buff.resize(size);
}
};
static constexpr int ELEM_COUNT = 10000;
TEST(NCCL, all_reduce) {
std::vector<ncclComm_t> comms;
comms.resize(dev_count);
VLOG(1) << "Initializing ncclComm";
dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
VLOG(1) << "ncclComm initialized";
VLOG(1) << "Creating thread data";
std::vector<std::unique_ptr<PerThreadData<double>>> data;
data.reserve(dev_count);
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Creating thread data for device " << i;
SetDeviceId(i);
data.emplace_back(new PerThreadData<double>(i, ELEM_COUNT));
}
VLOG(1) << "Thread data created";
VLOG(1) << "Check send_buf data";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Check on device " << i;
SetDeviceId(i);
thrust::host_vector<double> tmp = data[i]->send_buff;
for (size_t j = 0; j < tmp.size(); ++j) {
ASSERT_NEAR(static_cast<double>(j), tmp[j], 1e-5);
}
}
VLOG(1) << "Invoking ncclAllReduce";
dynload::ncclGroupStart();
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Invoking ncclAllReduce with device " << i;
SetDeviceId(i);
PADDLE_ENFORCE(dynload::ncclAllReduce(
data[i]->SendBuff(), data[i]->RecvBuff(), ELEM_COUNT, ncclDouble,
ncclSum, comms[i], data[i]->dev_ctx.stream()));
VLOG(1) << "Invoked ncclAllReduce for device " << i;
}
dynload::ncclGroupEnd();
VLOG(1) << "Invoked ncclAllReduce";
VLOG(1) << "Sync devices";
for (int i = 0; i < dev_count; ++i) {
VLOG(1) << "Sync device " << i;
SetDeviceId(i);
data[i]->dev_ctx.Wait();
}
VLOG(1) << "device synced";
for (int i = 0; i < dev_count; ++i) {
SetDeviceId(i);
VLOG(1) << "Checking vector on device " << i;
thrust::host_vector<double> tmp = data[i]->recv_buff;
for (size_t j = 0; j < tmp.size(); ++j) {
auto elem = static_cast<double>(j);
elem *= dev_count;
ASSERT_NEAR(tmp[j], elem, 1e-4);
}
}
for (int i = 0; i < dev_count; ++i) {
dynload::ncclCommDestroy(comms[i]);
}
}
} // namespace platform
} // namespace paddle
int main(int argc, char** argv) {
dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING)
<< "Cannot test multi-gpu nccl, because the CUDA device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::CUDAPlace(i));
}
VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -161,6 +161,8 @@ void BindBlockDesc(py::module &m) {
py::return_value_policy::reference)
.def("prepend_op", &BlockDesc::PrependOp,
py::return_value_policy::reference)
.def("insert_op", &BlockDesc::InsertOp,
py::return_value_policy::reference)
.def("remove_op", &BlockDesc::RemoveOp)
.def("var",
[](BlockDesc &self, py::bytes byte_name) {
......
......@@ -37,7 +37,7 @@ from distribute_transpiler_simple import SimpleDistributeTranspiler
from concurrency import (Go, make_channel, channel_send, channel_recv,
channel_close)
import clip
from memory_optimization_transpiler import memory_optimize
from memory_optimization_transpiler import memory_optimize, release_memory
import profiler
import unique_name
import recordio_writer
......@@ -64,6 +64,7 @@ __all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + [
'SimpleDistributeTranspiler',
'DistributeTranspiler',
'memory_optimize',
'release_memory',
'profiler',
'unique_name',
'recordio_writer',
......
......@@ -457,7 +457,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
"Out": [_append_grad_suffix_(loss.name)]
}, {"shape": [1],
"value": 1.0,
"dtype": loss.dtype})
"dtype": loss.dtype,
"force_cpu": False})
root_block.desc.append_op().copy_from(op_desc)
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
......
......@@ -130,8 +130,13 @@ def detection_output(loc,
target_box=loc,
code_type='decode_center_size')
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
old_shape = scores.shape
scores = ops.reshape(x=scores, shape=(-1, old_shape[-1]))
scores = ops.softmax(x=scores)
scores = ops.reshape(x=scores, shape=old_shape)
scores = nn.transpose(scores, perm=[0, 2, 1])
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
helper.append_op(
type="multiclass_nms",
inputs={'Scores': scores,
......@@ -562,16 +567,16 @@ def multi_box_head(inputs,
base_size,
num_classes,
aspect_ratios,
min_ratio,
max_ratio,
min_ratio=None,
max_ratio=None,
min_sizes=None,
max_sizes=None,
steps=None,
step_w=None,
step_h=None,
offset=0.5,
variance=[0.1, 0.1, 0.1, 0.1],
flip=False,
variance=[0.1, 0.1, 0.2, 0.2],
flip=True,
clip=False,
kernel_size=1,
pad=0,
......@@ -614,7 +619,7 @@ def multi_box_head(inputs,
the inputs[i] will be automatically calculated. Default: None.
offset(float): Prior boxes center offset. Default: 0.5
variance(list|tuple): the variances to be encoded in prior boxes.
Default:[0.1, 0.1, 0.1, 0.1].
Default:[0.1, 0.1, 0.2, 0.2].
flip(bool): Whether to flip aspect ratios. Default:False.
clip(bool): Whether to clip out-of-boundary boxes. Default: False.
kernel_size(int): The kernel size of conv2d. Default: 1.
......@@ -668,6 +673,19 @@ def multi_box_head(inputs,
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
attrs = {
'min_sizes': min_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
}
if len(max_sizes) > 0 and max_sizes[0] > 0:
attrs['max_sizes'] = max_sizes
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
helper.append_op(
......@@ -676,17 +694,7 @@ def multi_box_head(inputs,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs={
'min_sizes': min_sizes,
'max_sizes': max_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
})
attrs=attrs, )
return box, var
def _reshape_with_axis_(input, axis=1):
......@@ -714,7 +722,7 @@ def multi_box_head(inputs,
if num_layer <= 2:
assert min_sizes is not None and max_sizes is not None
assert len(min_sizes) == num_layer and len(max_sizes) == num_layer
else:
elif min_sizes is None and max_sizes is None:
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
......@@ -759,9 +767,6 @@ def multi_box_head(inputs,
min_size = [min_size]
if not _is_list_or_tuple_(max_size):
max_size = [max_size]
if not (len(max_size) == len(min_size)):
raise ValueError(
'the length of max_size and min_size should be equal.')
aspect_ratio = []
if aspect_ratios is not None:
......@@ -779,7 +784,7 @@ def multi_box_head(inputs,
num_boxes = box.shape[2]
# get box_loc
# get loc
num_loc_output = num_boxes * 4
mbox_loc = nn.conv2d(
input=input,
......@@ -796,7 +801,7 @@ def multi_box_head(inputs,
mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten)
# get conf_loc
# get conf
num_conf_output = num_boxes * num_classes
conf_loc = nn.conv2d(
input=input,
......
......@@ -29,7 +29,10 @@ dtype_to_size = {
core.VarDesc.VarType.BOOL: 1
}
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
sub_block_ops = [
"while", "while_grad", "parallel_do", "parallel_do_grad",
"conditional_block", "conditional_block_grad"
]
PRINT_LOG = False
......@@ -122,36 +125,80 @@ class ControlFlowGraph(object):
else:
return block_desc.find_var_recursive(str(var_name))
def memory_optimize(self):
def check_var_validity(block_desc, x, is_forward):
if str(x) == "@EMPTY@":
return False
if not self._has_var(block_desc, x, is_forward):
return False
if self._find_var(block_desc, x, is_forward).persistable():
return False
if self._find_var(
block_desc, x,
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
return False
if x in self._skip_opt:
return False
if not self._find_var(block_desc, x, is_forward).shape():
return False
return True
def _check_var_validity(self, block_desc, x, is_forward):
if str(x) == "@EMPTY@":
return False
if not self._has_var(block_desc, x, is_forward):
return False
if self._find_var(block_desc, x, is_forward).persistable():
return False
if self._find_var(block_desc, x,
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
return False
if x in self._skip_opt:
return False
if not self._find_var(block_desc, x, is_forward).shape():
return False
return True
self._build_graph()
def _update_skip_opt_set(self):
for i in range(self.op_size):
op = self._ops[i]
if op.type() == "fill_constant" and op.attr("force_cpu") == True:
self._skip_opt.update(op.output_arg_names())
def release_memory(self):
self._dataflow_analyze()
self._update_skip_opt_set()
fwd_id = 0
bwd_id = 0
for i in range(self.op_size):
op = self._ops[i]
if op.type() in sub_block_ops:
continue
block_desc = op.block()
is_forward = i < self._forward_num
in_diff, out_diff = self._get_diff(self._live_in[i],
self._live_out[i])
can_optimize = filter(
lambda x: self._check_var_validity(block_desc, x, is_forward),
in_diff)
if can_optimize:
index = i + fwd_id + 1 if is_forward else i - self._forward_num + bwd_id + 1
delete_op = block_desc.insert_op(index)
delete_op.set_type("delete_var")
delete_op.set_input("X", can_optimize)
if is_forward:
fwd_id += 1
else:
bwd_id += 1
def memory_optimize(self, level=0):
def compare_shape(x_shape, cache_shape, opt_level):
if opt_level == 0:
return x_shape == cache_shape
if opt_level == 1:
if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
return False
x_size = abs(reduce(lambda x, y: x * y, x_shape))
cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
if x_size <= cache_size:
return True
return False
self._dataflow_analyze()
self._update_skip_opt_set()
self.pool = []
for i in range(self.op_size):
op = self._ops[i]
if op.type() in sub_block_ops:
continue
block_desc = op.block()
self.current_block_desc = block_desc
is_forward = i < self._forward_num
if self.pool:
defs_can_optimize = filter(
lambda x: check_var_validity(block_desc, x, is_forward),
lambda x: self._check_var_validity(block_desc, x, is_forward),
self._defs[i])
out_pair = [
(x, self._find_var(block_desc, x, is_forward).shape())
......@@ -164,7 +211,7 @@ class ControlFlowGraph(object):
for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0]
cache_shape = cache_pair[1]
if x_shape == cache_shape:
if compare_shape(x_shape, cache_shape, level):
if self._has_var(block_desc, cache_var, is_forward):
x_dtype = self._find_var(block_desc, x,
is_forward).dtype()
......@@ -196,7 +243,7 @@ class ControlFlowGraph(object):
in_diff, out_diff = self._get_diff(self._live_in[i],
self._live_out[i])
can_optimize = filter(
lambda x: check_var_validity(block_desc, x, is_forward),
lambda x: self._check_var_validity(block_desc, x, is_forward),
in_diff)
if can_optimize:
for var_name in can_optimize:
......@@ -270,7 +317,8 @@ def _get_cfgs(input_program):
([block_desc.op(i) for i in range(op_size)], op_size, set()))
sub_block_pair = [("while", "while_grad"), ("parallel_do",
"parallel_do_grad")]
"parallel_do_grad"),
("conditional_block", "conditional_block_grad")]
ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
......@@ -281,9 +329,15 @@ def _get_cfgs(input_program):
return cfgs
def memory_optimize(input_program, print_log=False):
def memory_optimize(input_program, print_log=False, level=0):
global PRINT_LOG
PRINT_LOG = print_log
cfgs = _get_cfgs(input_program)
for cfg in cfgs:
cfg.memory_optimize()
cfg.memory_optimize(level)
def release_memory(input_program):
cfgs = _get_cfgs(input_program)
for cfg in cfgs:
cfg.release_memory()
......@@ -92,7 +92,10 @@ class Optimizer(object):
# create learning rate variable for every parameter
param = param_and_grad[0]
param_lr = param.optimize_attr['learning_rate']
return self.global_learning_rate() * param_lr
if param_lr == 1.0:
return self.global_learning_rate()
else:
return self.global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters
......
......@@ -50,6 +50,7 @@ sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program(), print_log=True)
# fluid.release_memory(fluid.default_main_program())
BATCH_SIZE = 200
......@@ -69,8 +70,6 @@ exe.run(fluid.default_startup_program())
PASS_NUM = 100
for pass_id in range(PASS_NUM):
fluid.io.save_persistables(exe, "./fit_a_line.model/")
fluid.io.load_persistables(exe, "./fit_a_line.model/")
for data in train_reader():
avg_loss_value, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
......
......@@ -125,9 +125,10 @@ opts = optimizer.minimize(avg_cost)
batch_size = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)
fluid.memory_optimize(fluid.default_main_program())
# fluid.memory_optimize(fluid.default_main_program(), level=0)
fluid.release_memory(fluid.default_main_program())
BATCH_SIZE = 128
BATCH_SIZE = 16
PASS_NUM = 1
# fix the order of training data
......@@ -159,8 +160,7 @@ for pass_id in range(PASS_NUM):
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
pass_acc))
# this model is slow, so if we can train two mini batch, we think it works properly.
if i > 2:
if i > 0:
exit(0)
if math.isnan(float(loss)):
sys.exit("got NaN loss, training failed.")
......
......@@ -105,7 +105,8 @@ def main():
optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program())
# fluid.memory_optimize(fluid.default_main_program())
fluid.release_memory(fluid.default_main_program())
# fix the order of training data
train_data = paddle.batch(
......
......@@ -166,8 +166,6 @@ class TestDetectionMAPOp(OpTest):
elif not difficult:
label_count[label] += 1
true_pos = collections.defaultdict(list)
false_pos = collections.defaultdict(list)
for (label, score, tp, fp) in tf_pos:
true_pos[label].append([score, tp])
false_pos[label].append([score, fp])
......
......@@ -98,6 +98,9 @@ class TestLearningRateDecay(unittest.TestCase):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.memory_optimize(fluid.default_main_program())
for step in range(10):
lr_val, = exe.run(fluid.default_main_program(),
feed={},
......
......@@ -21,31 +21,43 @@ from paddle.fluid.backward import append_backward
class TestOptimizer(unittest.TestCase):
def test_sgd_optimizer(self):
init_program = framework.Program()
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
def check_sgd_optimizer(optimizer_attr):
init_program = framework.Program()
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr=optimizer_attr)
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mean_out = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="mean.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
block.append_op(
type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.01)
opts, _ = sgd_optimizer.minimize(mean_out, init_program)
return opts
opts = check_sgd_optimizer({'learning_rate': 1.1})
self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "sgd"])
opts = check_sgd_optimizer({'learning_rate': 1.0})
self.assertEqual(len(opts), 1)
self.assertEqual([op.type for op in opts], ["sgd"])
class TestMomentumOptimizer(unittest.TestCase):
class MockMomentum(optimizer.MomentumOptimizer):
......@@ -60,7 +72,11 @@ class TestMomentumOptimizer(unittest.TestCase):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
......@@ -110,7 +126,11 @@ class TestMomentumOptimizer(unittest.TestCase):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
......@@ -169,7 +189,11 @@ class TestAdagradOptimizer(unittest.TestCase):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
......@@ -229,7 +253,11 @@ class TestAdamOptimizer(unittest.TestCase):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
......@@ -292,7 +320,11 @@ class TestAdamaxOptimizer(unittest.TestCase):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
......@@ -352,7 +384,11 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
optimize_attr={'learning_rate': 1.1})
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册