Commit 4cba5500 authored by F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_lr_decay

@@ -136,6 +136,12 @@ else()
set(THIRD_PARTY_BUILD_TYPE Release)
endif()
+if(WITH_MKL)
+option(MKL_SPLIT_GEMM "PaddlePaddle MKL gemm would split to small ones" OFF)
+if (MKL_SPLIT_GEMM)
+add_definitions(-DPADDLE_MKL_SPLIT_GEMM)
+endif()
+endif()
set(WITH_MKLML ${WITH_MKL})
if (NOT DEFINED WITH_MKLDNN)
if (WITH_MKL AND AVX2_FOUND)
......
#!/bin/bash
set -e
function train() {
......
#!/bin/bash
set -e
function clock_to_seconds() {
......
#!/bin/bash
set -e
function train() {
......
#!/bin/bash
set -e
function clock_to_seconds() {
......
#!/bin/bash
set -e
function train() {
......
#!/bin/bash
set -e
function train() {
......
#!/bin/bash
set -e
function test() {
......
#!/bin/bash
set -e
function test() {
......
#!/bin/bash
set -e
function test() {
......
#!/bin/bash
set -e
function test() {
......
@@ -180,13 +180,13 @@ paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, default
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
-paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, None, 1, True))
+paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.shuffle ArgSpec(args=['reader', 'buffer_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.batch ArgSpec(args=['reader', 'batch_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.double_buffer ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.random_data_generator ArgSpec(args=['low', 'high', 'shapes', 'lod_levels', 'for_parallel'], varargs=None, keywords=None, defaults=(True,))
-paddle.fluid.layers.py_reader ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lod_levels'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.py_reader ArgSpec(args=['capacity', 'shapes', 'dtypes', 'lod_levels', 'name', 'use_double_buffer'], varargs=None, keywords=None, defaults=(None, None, True))
paddle.fluid.layers.Preprocessor.__init__ ArgSpec(args=['self', 'reader', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.Preprocessor.block ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.layers.Preprocessor.inputs ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
@@ -209,9 +209,6 @@ paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.split_lod_tensor ArgSpec(args=['input', 'mask', 'level'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.merge_lod_tensor ArgSpec(args=['in_true', 'in_false', 'x', 'mask', 'level'], varargs=None, keywords=None, defaults=(0,))
-paddle.fluid.layers.BlockGuard.__init__ ArgSpec(args=['self', 'main_program'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.BlockGuardWithCompletion.__init__ ArgSpec(args=['self', 'rnn'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.WhileGuard.__init__ ArgSpec(args=['self', 'while_op'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.While.complete ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
......
-cc_library(var_handle SRCS var_handle.cc DEPS place)
+cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
......
@@ -333,7 +333,7 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
* Only variables should be the leaves of graph.
*/
AddOutputToLeafOps(&result);
-return std::move(graph);
+return graph;
}
bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
......
@@ -35,14 +35,16 @@ struct ReduceLoDTensor {
PADDLE_ENFORCE(!src_tensors_.empty());
auto &t0 = *src_tensors_[0];
PADDLE_ENFORCE_NE(t0.numel(), 0);
dst_tensor_.Resize(t0.dims());
T *dst = dst_tensor_.mutable_data<T>(platform::CPUPlace());
+if (dst != t0.data<T>()) {
+std::copy(t0.data<T>(), t0.data<T>() + t0.numel(), dst);
+}
-for (size_t i = 1; i < src_tensors_.size(); ++i) {
+for (size_t i = 0; i < src_tensors_.size(); ++i) {
auto &t = *src_tensors_[i];
+if (dst == t.data<T>()) {
+continue;
+}
PADDLE_ENFORCE_EQ(t.dims(), t0.dims());
PADDLE_ENFORCE_EQ(t.type(), t0.type());
std::transform(t.data<T>(), t.data<T>() + t.numel(), dst, dst,
......
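The ReduceLoDTensor change above copies the first source only when the destination buffer does not already alias it, then accumulates every non-aliasing source in place. A minimal standalone sketch of that accumulation pattern, with plain std::vector<float> standing in for the LoDTensor buffers (illustrative only, not part of the patch):

#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  // Three "device" buffers to be reduced into one destination buffer.
  std::vector<std::vector<float>> srcs = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  std::vector<float> dst(srcs[0]);  // plays the role of the guarded std::copy of t0
  for (size_t i = 1; i < srcs.size(); ++i) {
    // dst[j] = srcs[i][j] + dst[j], the same element-wise accumulation as the
    // std::transform call in the hunk above.
    std::transform(srcs[i].begin(), srcs[i].end(), dst.begin(), dst.begin(),
                   std::plus<float>());
  }
  for (float v : dst) std::cout << v << ' ';  // prints: 12 15 18
  std::cout << '\n';
  return 0;
}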
@@ -31,7 +31,7 @@ class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
PADDLE_ENFORCE(IsValidGraph(new_graph.get()));
-return std::move(new_graph);
+return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
......
@@ -53,7 +53,7 @@ class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const override {
auto new_graph = builder_->Apply(std::move(graph));
printer_->Print(*new_graph, stream_ref_);
-return std::move(new_graph);
+return new_graph;
}
int GetVarDeviceID(const std::string& var_name) const override {
......
@@ -171,7 +171,12 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
-auto &vars = fetched_vars.at(var_name);
+auto fetched_var_it = fetched_vars.find(var_name);
+PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
+"Cannot find fetched variable.(Perhaps the main_program "
+"is not set to ParallelExecutor)");
+auto &vars = fetched_var_it->second;
temp_nodes->emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
auto *op = new FetchOpHandle(temp_nodes->back().get(), fetch_data, i,
......
-cc_library(graph SRCS graph.cc DEPS node)
cc_library(node SRCS node.cc DEPS proto_desc)
+cc_library(graph SRCS graph.cc DEPS node)
cc_library(pass SRCS pass.cc DEPS graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry)
@@ -21,6 +21,7 @@ namespace framework {
// NOTE(paddle-dev): This graph contains circle.
Graph::Graph(const ProgramDesc &program) : program_(program) {
+VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
......
@@ -312,19 +312,22 @@ void WriteToRecordIO(recordio::Writer *writer,
writer->Write(buffer.str());
}
-std::vector<LoDTensor> ReadFromRecordIO(
+bool ReadFromRecordIO(recordio::Scanner *scanner,
-recordio::Scanner *scanner, const platform::DeviceContext &dev_ctx) {
+const platform::DeviceContext &dev_ctx,
-std::vector<LoDTensor> result;
+std::vector<LoDTensor> *result_ptr) {
-if (scanner->HasNext()) {
+if (!scanner->HasNext()) {
+return false;
+}
std::istringstream sin(scanner->Next());
uint32_t sz;
sin.read(reinterpret_cast<char *>(&sz), sizeof(uint32_t));
+auto &result = *result_ptr;
result.resize(sz);
for (uint32_t i = 0; i < sz; ++i) {
DeserializeFromStream(sin, &result[i], dev_ctx);
}
-}
-return result;
+return true;
}
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
......
@@ -223,8 +223,9 @@ extern void WriteToRecordIO(recordio::Writer* writer,
const std::vector<LoDTensor>& tensor,
const platform::DeviceContext& dev_ctx);
-extern std::vector<LoDTensor> ReadFromRecordIO(
+extern bool ReadFromRecordIO(recordio::Scanner* scanner,
-recordio::Scanner* scanner, const platform::DeviceContext& dev_ctx);
+const platform::DeviceContext& dev_ctx,
+std::vector<LoDTensor>* result_ptr);
/*
* Convert between length-based LoD and offset-based LoD.
......
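With this change ReadFromRecordIO reports end-of-data through its bool return value and fills a caller-owned vector instead of returning one. A minimal caller-side sketch of the new contract, assuming a build inside the Paddle tree with the headers below available (illustrative only, not part of the patch):

#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/recordio/scanner.h"

// Hypothetical helper: drains a scanner with the new bool-returning API.
void ConsumeAllRecords(paddle::recordio::Scanner *scanner,
                       const paddle::platform::DeviceContext &ctx) {
  std::vector<paddle::framework::LoDTensor> tensors;
  // Old API: a vector was returned by value, empty when no record was left.
  // New API: the caller owns the vector and the return value signals whether
  // another record was actually read.
  while (paddle::framework::ReadFromRecordIO(scanner, ctx, &tensors)) {
    // ... consume `tensors` for the current record ...
  }
}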
@@ -301,11 +301,12 @@ static void TestRecordIO() {
{
std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr));
-auto tensors = ReadFromRecordIO(&scanner, ctx);
+std::vector<framework::LoDTensor> tensors;
+ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors));
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
-tensors = ReadFromRecordIO(&scanner, ctx);
+ASSERT_TRUE(ReadFromRecordIO(&scanner, ctx, &tensors));
ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]);
......
@@ -67,7 +67,8 @@ void ReaderBase::Start() {
}
}
-ReaderBase::~ReaderBase() { Shutdown(); }
+ReaderBase::~ReaderBase() {}
+DecoratedReader::~DecoratedReader() { reader_->Shutdown(); }
} // namespace framework
} // namespace paddle
@@ -25,8 +25,6 @@
namespace paddle {
namespace framework {
-enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
virtual void ReadNext(std::vector<LoDTensor>* out);
@@ -48,6 +46,8 @@ class ReaderBase {
virtual void StartImpl() {}
+enum ReaderStatus { kRunning, kStopped };
ReaderStatus status_{kRunning};
mutable std::mutex mu_;
@@ -74,6 +74,8 @@ class DecoratedReader : public ReaderBase,
reader_->InsertDecoratedReader(shared_from_this());
}
+~DecoratedReader();
protected:
void ShutdownImpl() override { reader_->Shutdown(); }
......
@@ -15,6 +15,7 @@
#include <algorithm>
#include <limits>
#include <vector>
+#include "paddle/fluid/framework/data_type.h"
namespace paddle {
namespace framework {
@@ -261,7 +262,8 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
-uint64_t size = tensor.memory_size();
+uint64_t size = tensor.numel() * framework::SizeOfType(tensor.type());
auto* data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
@@ -331,6 +333,9 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims));
void* buf;
auto ctx = platform::CPUDeviceContext();
+size_t size =
+tensor->numel() *
+framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor;
@@ -338,7 +343,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
-is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
+is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
@@ -348,7 +353,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
framework::VisitDataType(
desc.data_type(),
DeserializedDataFunctor(&buf, tensor, ctx.GetPlace()));
-is.read(static_cast<char*>(buf), tensor->memory_size());
+is.read(static_cast<char*>(buf), size);
}
}
}
......
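The serialization hunks above size the tensor-data field as element count times element size instead of the tensor's allocated memory_size. A tiny standalone illustration of that arithmetic for a float tensor of shape {3, 4} (illustrative only, not part of the patch):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> dims = {3, 4};  // tensor shape
  // numel = 3 * 4 = 12 elements
  int64_t numel = std::accumulate(dims.begin(), dims.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  // payload bytes = element count * element size, independent of how much
  // memory the tensor happens to have allocated
  uint64_t size = static_cast<uint64_t>(numel) * sizeof(float);
  std::cout << size << " bytes\n";  // prints: 48 bytes
  return 0;
}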
@@ -38,4 +38,6 @@ if(WITH_TESTING)
# both tests/book and analysis depends the models that generated by python/paddle/fluid/tests/book
add_subdirectory(tests/book)
endif()
-add_subdirectory(api)
+if(NOT APPLE)
+add_subdirectory(api)
+endif()
@@ -22,8 +22,6 @@
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
namespace paddle {
-namespace inference {
-namespace analysis {
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
"Enable subgraph to TensorRT engine for acceleration");
@@ -31,6 +29,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
DEFINE_string(inference_analysis_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs.");
+namespace inference {
+namespace analysis {
class DfgPassManagerImpl final : public DfgPassManager {
public:
DfgPassManagerImpl() {
......
@@ -45,14 +45,15 @@ limitations under the License. */
#include "paddle/fluid/inference/analysis/pass_manager.h"
namespace paddle {
-namespace inference {
-namespace analysis {
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root);
+namespace inference {
+namespace analysis {
class Analyzer : public OrderedRegistry<PassManager> {
public:
// Register all the pass-managers.
......
@@ -13,13 +13,21 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
+#include <google/protobuf/text_format.h>
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
-TEST_F(DFG_Tester, main) {
+TEST_F(DFG_Tester, analysis_without_tensorrt) {
+FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false;
+Analyzer analyser;
+analyser.Run(&argument);
+}
+TEST_F(DFG_Tester, analysis_with_tensorrt) {
+FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true;
Analyzer analyser;
analyser.Run(&argument);
}
......
@@ -222,10 +222,19 @@ Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
return stack_.top();
}
+inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
+return node.inlinks.size() == n;
+}
GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
const std::vector<Node *> &source) {
PADDLE_ENFORCE(!source.empty(),
"Start points of topological sorting should not be empty!");
+// CHECK all the inputs' in-degree is 0
+for (auto *node : source) {
+PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
+}
std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
@@ -233,6 +242,11 @@ GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
while (!to_visit.empty()) {
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
for (auto *p : queue) {
+if (p->deleted()) {
+visited.insert(p);
+to_visit.erase(p);
+continue;
+}
inlink_visited.clear();
std::copy_if(p->inlinks.begin(), p->inlinks.end(),
@@ -292,6 +306,37 @@ Node *GraphTraits<DataFlowGraph>::NodesTSIterator::operator->() {
return sorted_[cursor_];
}
+std::pair<std::vector<Node *>, std::vector<Node *>>
+ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
+std::unordered_set<Node *> nodes(graph.begin(), graph.end());
+std::unordered_set<Node *> inputs;
+std::unordered_set<Node *> outputs;
+// Input a Value, check whether its inlink is in the subgraph.
+auto inlink_in_subgraph = [&](Node *n) {
+for (auto *in : n->inlinks) {
+if (nodes.count(in)) return true;
+}
+return false;
+};
+for (auto &node : graph) {
+for (auto *in : node->inlinks) {
+// The Value that is written by nodes inside a sub-graph shouldn't be the
+// input of the sub-graph.
+if (!nodes.count(in) && in->type() == Node::Type::kValue &&
+!inlink_in_subgraph(in)) {
+inputs.insert(in);
+}
+}
+for (auto *out : node->outlinks) {
+if (!nodes.count(out) && out->type() == Node::Type::kValue) {
+outputs.insert(out);
+}
+}
+}
+return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
+std::vector<Node *>(outputs.begin(), outputs.end()));
+}
} // namespace analysis
} // namespace inference
} // namespace paddle
@@ -133,7 +133,7 @@ struct GraphTraits<DataFlowGraph> {
private:
std::vector<Node *> sorted_;
-int cursor_{0};
+size_t cursor_{0};
};
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
@@ -173,36 +173,8 @@ struct GraphTraits<DataFlowGraph> {
// Extract the inputs and outputs of a graph. The inputs and outputs of a
// sub-graph is the inputs nodes and output nodes that doesn't inside the
// sub-graph.
-static std::pair<std::vector<Node *>, std::vector<Node *>>
+std::pair<std::vector<Node *>, std::vector<Node *>>
-ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
+ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph);
-std::unordered_set<Node *> nodes(graph.begin(), graph.end());
-std::unordered_set<Node *> inputs;
-std::unordered_set<Node *> outputs;
-// Input a Value, check whether its inlink is in the subgraph.
-auto inlink_in_subgraph = [&](Node *n) {
-for (auto *in : n->inlinks) {
-if (nodes.count(in)) return true;
-}
-return false;
-};
-for (auto &node : graph) {
-for (auto *in : node->inlinks) {
-// The Value that is written by nodes inside a sub-graph shouldn't be the
-// input of the sub-graph.
-if (!nodes.count(in) && in->type() == Node::Type::kValue &&
-!inlink_in_subgraph(in)) {
-inputs.insert(in);
-}
-}
-for (auto *out : node->outlinks) {
-if (!nodes.count(out) && out->type() == Node::Type::kValue) {
-outputs.insert(out);
-}
-}
-}
-return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
-std::vector<Node *>(outputs.begin(), outputs.end()));
-}
} // namespace analysis
} // namespace inference
......
@@ -22,14 +22,18 @@
namespace paddle {
namespace inference {
+DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size");
+DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
namespace analysis {
using framework::proto::ProgramDesc;
std::vector<std::string> ExtractParameters(
-const std::vector<std::unique_ptr<Node>>& nodes);
+const std::vector<std::unique_ptr<Node>> &nodes);
-bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
+bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
PADDLE_ENFORCE(!argument->transformed_program_desc);
@@ -47,32 +51,34 @@ bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
bool DataFlowGraphToFluidPass::Finalize() { return true; }
-void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
+void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
-auto traits = GraphTraits<DataFlowGraph>(graph);
+LOG(INFO) << "graph.inputs " << graph->inputs.size();
-for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
+for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
-if (it->deleted()) continue;
+if (node.deleted()) continue;
-switch (it->type()) {
+switch (node.type()) {
case Node::Type::kFunction: {
-LOG(INFO) << "add function " << it->repr();
+LOG(INFO) << "add function " << node.repr();
-AddFluidOp(&(*it));
+AddFluidOp(&node);
} break;
case Node::Type::kFunctionBlock: {
-LOG(INFO) << "add engine op " << it->repr() << " , "
+LOG(INFO) << "add engine op " << node.repr() << " , "
-<< static_cast<FunctionBlock*>(&(*it))->subgraph.size();
+<< static_cast<FunctionBlock *>(&node)->subgraph.size();
-AddEngineOp(&(*it));
+AddEngineOp(&node);
} break;
default:
continue;
}
}
+PADDLE_ENFORCE(argument_->transformed_program_desc.get());
}
-void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
+void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
-auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
+auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc());
// currently only the main block is analyzed.
-auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
+auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
-auto* op = main_block->add_ops();
+auto *op = main_block->add_ops();
*op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt
@@ -80,43 +86,42 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
// NOTE It might be changed by other passes in the long run.
}
-void CreateTrtEngineOp(Node* node, const DataFlowGraph& graph,
+void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
-const framework::proto::BlockDesc& block) {
+const framework::proto::BlockDesc &block) {
static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc;
-auto* func = static_cast<FunctionBlock*>(node);
+auto *func = static_cast<FunctionBlock *>(node);
// collect inputs
std::vector<std::string> io;
-for (auto* x : func->inlinks) {
+for (auto *x : func->inlinks) {
io.push_back(x->name());
}
desc.SetInput("Xs", io);
// collect outputs
io.clear();
-for (auto* x : func->outlinks) {
+for (auto *x : func->outlinks) {
io.push_back(x->name());
}
desc.SetOutput("Ys", io);
desc.SetType("tensorrt_engine");
+PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
// Set attrs
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
-SetAttr(desc.Proto(), "engine_unique_key",
-"trt-" + std::to_string(counter++));
-SetAttr(desc.Proto(), "max_batch", 100); // TODO(Superjomn) add config latter
-SetAttr(desc.Proto(), "max_workspace",
-1024); // TODO(Superjomn) add config latter
+SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++));
+SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
+SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
node->SetPbMsg(desc.Proto()->SerializeAsString());
}
std::vector<std::string> ExtractParameters(
-const std::vector<std::unique_ptr<Node>>& nodes) {
+const std::vector<std::unique_ptr<Node>> &nodes) {
std::vector<std::string> parameters;
-for (const auto& node : nodes) {
+for (const auto &node : nodes) {
if (!node->IsValue()) continue;
PADDLE_ENFORCE(!node->pb_msg().empty(), "pb_msg should be set first");
framework::proto::VarDesc var;
@@ -128,21 +133,30 @@ std::vector<std::string> ExtractParameters(
return parameters;
}
-void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
+void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
// TODO(Superjomn) Here need to expose some arguments for default setting.
PADDLE_ENFORCE(node->IsFunctionBlock());
-auto* block_node = static_cast<FunctionBlock*>(node);
+auto *block_node = static_cast<FunctionBlock *>(node);
framework::proto::BlockDesc proto;
framework::BlockDesc block_desc(nullptr, &proto);
+block_desc.Proto()->set_parent_idx(-1);
+block_desc.Proto()->set_idx(0);
+LOG(INFO) << "origin variable size: "
+<< argument_->origin_program_desc->blocks(0).vars().size();
+LOG(INFO) << "transformed variable size: "
+<< block_desc.Proto()->vars().size();
// copy ops.
-for (auto* node : block_node->subgraph) {
+for (auto *node : block_node->subgraph) {
-auto* op = block_desc.AppendOp();
+auto *op = block_desc.AppendOp();
PADDLE_ENFORCE(!node->pb_msg().empty());
op->Proto()->ParseFromString(node->pb_msg());
}
+*block_desc.Proto()->mutable_vars() =
+argument_->origin_program_desc->blocks(0).vars();
+PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
-auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
+auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
-auto* op = main_block->add_ops();
+auto *op = main_block->add_ops();
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
op->ParseFromString(node->pb_msg());
}
@@ -151,7 +165,7 @@ namespace {
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
public:
using Config = DFG_GraphvizDrawPass::Config;
-explicit DFG_DebuggerPass(const Config& config)
+explicit DFG_DebuggerPass(const Config &config)
: DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "dfg-to-fluid-debuger-pass"; }
@@ -160,7 +174,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
};
} // namespace
-Pass* DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
+Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger"));
......
@@ -26,6 +26,10 @@
namespace paddle {
namespace inference {
+DECLARE_int32(tensorrt_max_batchsize);
+DECLARE_int32(tensorrt_workspace_size);
namespace analysis {
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
public:
......
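The two gflags declared above replace the previously hard-coded max_batch and max_workspace attributes of the generated tensorrt_engine op. A minimal sketch of overriding them from calling code before the analysis pass runs; the flag names come from this commit, the surrounding function is hypothetical (illustrative only, not part of the patch):

#include <gflags/gflags.h>

// Declared and defined by the pass files in this commit.
DECLARE_int32(tensorrt_max_batchsize);
DECLARE_int32(tensorrt_workspace_size);

// Hypothetical helper: tighten the TensorRT limits before analysis runs.
void ConfigureTensorRTLimits() {
  FLAGS_tensorrt_max_batchsize = 16;        // commit default: 300
  FLAGS_tensorrt_workspace_size = 1 << 20;  // commit default: 2048
}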
@@ -40,7 +40,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
no++;
}
// DFG is sensitive to ProgramDesc, be careful to change the existing models.
-ASSERT_EQ(no, 82);
+ASSERT_EQ(no, 83);
}
} // namespace analysis
......
@@ -28,7 +28,6 @@ bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc);
PADDLE_ENFORCE(argument);
if (!argument->main_dfg) {
-LOG(INFO) << "Init DFG";
argument->main_dfg.reset(new DataFlowGraph);
}
desc_ = argument->origin_program_desc.get();
@@ -51,6 +50,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i);
auto *o = graph->nodes.Create(Node::Type::kFunction);
@@ -62,19 +62,31 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
-// TODO(Superjomn) make sure the InputNames is the real variable name.
+std::unordered_set<Node *> inlinks;
for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o);
o->inlinks.push_back(in);
+inlinks.insert(in);
}
}
for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
+if (inlinks.count(out)) {
+// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
+auto *out_alias = graph->nodes.Create(Node::Type::kValue);
+out_alias->SetName(out->name());
+out_alias->SetPbDesc(out->pb_desc());
+out_alias->SetPbMsg(out->pb_msg());
+var2id[out_alias->name()] = out_alias->id(); // update a -> a0
+LOG(INFO) << "loop found in graph, create SSA alias node ["
+<< out_alias->repr() << "] for [" << out->repr() << "]";
+out = out_alias;
+}
out->inlinks.push_back(o);
o->outlinks.push_back(out);
}
......
@@ -24,12 +24,12 @@ namespace analysis {
TEST_F(DFG_Tester, Init) {
FluidToDataFlowGraphPass pass;
pass.Initialize(&argument);
-DataFlowGraph graph;
-pass.Run(&graph);
+pass.Run(argument.main_dfg.get());
// Analysis is sensitive to ProgramDesc, careful to change the original model.
-ASSERT_EQ(graph.nodes.size(), 37UL);
+ASSERT_EQ(argument.main_dfg->nodes.size(), 38UL);
pass.Finalize();
-LOG(INFO) << '\n' << graph.DotString();
+ASSERT_FALSE(argument.main_dfg->DotString().empty());
+EXPECT_FALSE(argument.main_dfg->inputs.empty());
}
} // namespace analysis
......
@@ -25,6 +25,9 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_)();
+VLOG(4) << "debug info "
+<< graph->HumanReadableInfo(false /*show_values*/,
+true /*show_functions*/);
}
} // namespace analysis
......
@@ -82,7 +82,7 @@ inference_api_test(test_api_impl
if(WITH_GPU AND TENSORRT_FOUND)
cc_library(paddle_inference_tensorrt_subgraph_engine
SRCS api_tensorrt_subgraph_engine.cc
-DEPS paddle_inference_api analysis tensorrt_engine paddle_fluid_api)
+DEPS paddle_inference_api analysis tensorrt_engine paddle_inference_api paddle_fluid_api tensorrt_converter)
inference_api_test(test_api_tensorrt_subgraph_engine ARGS test_word2vec)
endif()
......
@@ -39,7 +39,7 @@ bool PaddleInferenceAnakinPredictor::Init(const AnakinConfig &config) {
bool PaddleInferenceAnakinPredictor::Run(
const std::vector<PaddleTensor> &inputs,
-std::vector<PaddleTensor> *output_data) {
+std::vector<PaddleTensor> *output_data, int batch_size) {
for (const auto &input : inputs) {
if (input.dtype != PaddleDType::FLOAT32) {
LOG(ERROR) << "Only support float type inputs. " << input.name
......
@@ -37,7 +37,8 @@ class PaddleInferenceAnakinPredictor : public PaddlePredictor {
// NOTE Unlike the native engine, the buffers of anakin engine's output_data
// should be allocated first.
bool Run(const std::vector<PaddleTensor>& inputs,
-std::vector<PaddleTensor>* output_data) override;
+std::vector<PaddleTensor>* output_data,
+int batch_size = -1) override;
std::unique_ptr<PaddlePredictor> Clone() override;
......
@@ -108,7 +108,8 @@ NativePaddlePredictor::~NativePaddlePredictor() {
}
bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
-std::vector<PaddleTensor> *output_data) {
+std::vector<PaddleTensor> *output_data,
+int batch_size) {
VLOG(3) << "Predictor::predict";
Timer timer;
timer.tic();
......
@@ -38,7 +38,8 @@ class NativePaddlePredictor : public PaddlePredictor {
bool Init(std::shared_ptr<framework::Scope> parent_scope);
bool Run(const std::vector<PaddleTensor> &inputs,
-std::vector<PaddleTensor> *output_data) override;
+std::vector<PaddleTensor> *output_data,
+int batch_size = -1) override;
std::unique_ptr<PaddlePredictor> Clone() override;
......
@@ -16,6 +16,7 @@
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
+#include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace paddle {
@@ -64,16 +65,7 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
return false;
}
-// Analyze inference_program
+OptimizeInferenceProgram();
-Argument argument;
-argument.origin_program_desc.reset(
-new ProgramDesc(*inference_program_->Proto()));
-Singleton<Analyzer>::Global().Run(&argument);
-CHECK(argument.transformed_program_desc);
-VLOG(5) << "transformed program:\n"
-<< argument.transformed_program_desc->SerializeAsString();
-VLOG(5) << "to prepare executor";
-*inference_program_->Proto() = *argument.transformed_program_desc;
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
@@ -86,6 +78,29 @@ class TensorRTSubgraphPredictor : public NativePaddlePredictor {
return true;
}
+bool Run(const std::vector<PaddleTensor>& inputs,
+std::vector<PaddleTensor>* output_data,
+int batch_size = -1) override {
+PADDLE_ENFORCE_GT(batch_size, 0,
+"TensorRT engine needs the argument batch_size set");
+FLAGS_tensorrt_engine_batch_size = batch_size;
+return NativePaddlePredictor::Run(inputs, output_data, batch_size);
+}
+void OptimizeInferenceProgram() {
+// Analyze inference_program
+Argument argument;
+argument.origin_program_desc.reset(
+new ProgramDesc(*inference_program_->Proto()));
+Singleton<Analyzer>::Global().Run(&argument);
+CHECK(argument.transformed_program_desc);
+VLOG(5) << "transformed program:\n"
+<< argument.transformed_program_desc->SerializeAsString();
+VLOG(5) << "to prepare executor";
+inference_program_.reset(
+new framework::ProgramDesc(*argument.transformed_program_desc));
+}
private:
TensorRTConfig config_;
};
......
@@ -98,7 +98,8 @@ class PaddlePredictor {
// responsible for the output tensor's buffer, either allocated or passed from
// outside.
virtual bool Run(const std::vector<PaddleTensor>& inputs,
-std::vector<PaddleTensor>* output_data) = 0;
+std::vector<PaddleTensor>* output_data,
+int batch_size = -1) = 0;
// Clone a predictor that share the model weights, the Cloned predictor should
// be thread-safe.
......
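PaddlePredictor::Run now takes an optional batch_size argument (default -1); the TensorRT sub-graph predictor requires a positive value, while the native predictor keeps its old behaviour. A caller-side sketch using the types introduced in this commit; the model path and config values are placeholders (illustrative only, not part of the patch):

#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Hypothetical helper: run the TensorRT sub-graph predictor with an explicit
// batch size, as required by the new Run() overload in this commit.
void RunWithBatchSize(const std::vector<paddle::PaddleTensor> &inputs) {
  paddle::TensorRTConfig config;
  config.model_dir = "/path/to/word2vec.inference.model";  // placeholder
  config.use_gpu = true;
  config.fraction_of_gpu_memory = 0.3;
  config.device = 0;
  auto predictor = paddle::CreatePaddlePredictor<
      paddle::TensorRTConfig, paddle::PaddleEngineKind::kAutoMixedTensorRT>(
      config);
  std::vector<paddle::PaddleTensor> outputs;
  // batch_size must be > 0 for the TensorRT engine; the native predictor
  // still accepts the default of -1.
  predictor->Run(inputs, &outputs, /*batch_size=*/10);
}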
@@ -35,7 +35,8 @@ class DemoPredictor : public PaddlePredictor {
LOG(INFO) << "I get other_config " << config.other_config;
}
bool Run(const std::vector<PaddleTensor> &inputs,
-std::vector<PaddleTensor> *output_data) override {
+std::vector<PaddleTensor> *output_data,
+int batch_size = 0) override {
LOG(INFO) << "Run";
return false;
}
......
@@ -15,50 +15,79 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
+#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle {
DEFINE_string(dirname, "", "Directory of the inference model.");
-void Main(bool use_gpu) {
+void CompareTensorRTWithFluid(bool enable_tensorrt) {
+FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = enable_tensorrt;
//# 1. Create PaddlePredictor with a config.
-TensorRTConfig config;
+NativeConfig config0;
-config.model_dir = FLAGS_dirname + "word2vec.inference.model";
+config0.model_dir = FLAGS_dirname + "word2vec.inference.model";
-config.use_gpu = use_gpu;
+config0.use_gpu = true;
-config.fraction_of_gpu_memory = 0.15;
+config0.fraction_of_gpu_memory = 0.3;
-config.device = 0;
+config0.device = 0;
-auto predictor =
+TensorRTConfig config1;
+config1.model_dir = FLAGS_dirname + "word2vec.inference.model";
+config1.use_gpu = true;
+config1.fraction_of_gpu_memory = 0.3;
+config1.device = 0;
+auto predictor0 =
+CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
+auto predictor1 =
CreatePaddlePredictor<TensorRTConfig,
-PaddleEngineKind::kAutoMixedTensorRT>(config);
+PaddleEngineKind::kAutoMixedTensorRT>(config1);
-for (int batch_id = 0; batch_id < 3; batch_id++) {
+for (int batch_id = 0; batch_id < 1; batch_id++) {
//# 2. Prepare input.
-int64_t data[4] = {1, 2, 3, 4};
+std::vector<int64_t> data(20);
+for (int i = 0; i < 20; i++) data[i] = i;
-PaddleTensor tensor{.name = "",
+PaddleTensor tensor{
-.shape = std::vector<int>({4, 1}),
+.name = "",
-.data = PaddleBuf(data, sizeof(data)),
+.shape = std::vector<int>({10, 1}),
+.data = PaddleBuf(data.data(), data.size() * sizeof(int64_t)),
.dtype = PaddleDType::INT64};
// For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> slots(4, tensor);
//# 3. Run
-std::vector<PaddleTensor> outputs;
+std::vector<PaddleTensor> outputs0;
-CHECK(predictor->Run(slots, &outputs));
+std::vector<PaddleTensor> outputs1;
+CHECK(predictor0->Run(slots, &outputs0));
+CHECK(predictor1->Run(slots, &outputs1, 10));
//# 4. Get output.
-ASSERT_EQ(outputs.size(), 1UL);
+ASSERT_EQ(outputs0.size(), 1UL);
-LOG(INFO) << "output buffer size: " << outputs.front().data.length();
+ASSERT_EQ(outputs1.size(), 1UL);
-const size_t num_elements = outputs.front().data.length() / sizeof(float);
-// The outputs' buffers are in CPU memory.
+const size_t num_elements = outputs0.front().data.length() / sizeof(float);
-for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
+const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
-LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
+EXPECT_EQ(num_elements, num_elements1);
+auto *data0 = static_cast<float *>(outputs0.front().data.data());
+auto *data1 = static_cast<float *>(outputs1.front().data.data());
+ASSERT_GT(num_elements, 0UL);
+for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
+EXPECT_NEAR(data0[i], data1[i], 1e-3);
}
}
}
-TEST(paddle_inference_api_tensorrt_subgraph_engine, main) { Main(true); }
+TEST(paddle_inference_api_tensorrt_subgraph_engine, without_tensorrt) {
+CompareTensorRTWithFluid(false);
+}
+TEST(paddle_inference_api_tensorrt_subgraph_engine, with_tensorrt) {
+CompareTensorRTWithFluid(true);
+}
} // namespace paddle
@@ -93,6 +93,10 @@ class OpConverter {
framework::Scope* scope_{nullptr};
};
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \
trt_##op_type__##_converter() { \
@@ -111,7 +115,3 @@ class OpConverter {
extern int TouchConverterRegister_##op_type__(); \
static int use_op_converter_trt_##op_type__ __attribute__((unused)) = \
TouchConverterRegister_##op_type__();
-} // namespace tensorrt
-} // namespace inference
-} // namespace paddle
...@@ -26,18 +26,20 @@ namespace paddle { ...@@ -26,18 +26,20 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
void TensorRTEngine::Build(const DescType& paddle_model) { void TensorRTEngine::Build(const DescType &paddle_model) {
PADDLE_ENFORCE(false, "not implemented"); PADDLE_ENFORCE(false, "not implemented");
} }
void TensorRTEngine::Execute(int batch_size) { void TensorRTEngine::Execute(int batch_size) {
std::vector<void*> buffers; batch_size_ = batch_size;
for (auto& buf : buffers_) { std::vector<void *> buffers;
for (auto &buf : buffers_) {
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated"); PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
PADDLE_ENFORCE_GT(buf.max_size, 0); PADDLE_ENFORCE_GT(buf.max_size, 0);
PADDLE_ENFORCE(buf.device == DeviceType::GPU); PADDLE_ENFORCE(buf.device == DeviceType::GPU);
buffers.push_back(buf.buffer); buffers.push_back(buf.buffer);
} }
PADDLE_ENFORCE_NOT_NULL(stream_);
infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr); infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
cudaStreamSynchronize(*stream_); cudaStreamSynchronize(*stream_);
} }
...@@ -45,7 +47,7 @@ void TensorRTEngine::Execute(int batch_size) { ...@@ -45,7 +47,7 @@ void TensorRTEngine::Execute(int batch_size) {
TensorRTEngine::~TensorRTEngine() { TensorRTEngine::~TensorRTEngine() {
cudaStreamSynchronize(*stream_); cudaStreamSynchronize(*stream_);
// clean buffer // clean buffer
for (auto& buf : buffers_) { for (auto &buf : buffers_) {
if (buf.device == DeviceType::GPU && buf.buffer != nullptr) { if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer)); PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
buf.buffer = nullptr; buf.buffer = nullptr;
...@@ -70,32 +72,37 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -70,32 +72,37 @@ void TensorRTEngine::FreezeNetwork() {
// allocate GPU buffers. // allocate GPU buffers.
buffers_.resize(buffer_sizes_.size()); buffers_.resize(buffer_sizes_.size());
for (auto& item : buffer_sizes_) { for (auto &item : buffer_sizes_) {
// The output buffers are not set in the network building phase, so they need
// to be inferred from the TensorRT network.
if (item.second == 0) { if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str()); auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
auto dims = infer_engine_->getBindingDimensions(slot_offset); auto dims = infer_engine_->getBindingDimensions(slot_offset);
item.second = kDataTypeSize[static_cast<int>( item.second = kDataTypeSize[static_cast<int>(
infer_engine_->getBindingDataType(slot_offset))] * infer_engine_->getBindingDataType(slot_offset))] *
analysis::AccuDims(dims.d, dims.nbDims); analysis::AccuDims(dims.d, dims.nbDims);
PADDLE_ENFORCE_GT(item.second, 0);
} }
auto& buf = buffer(item.first);
auto &buf = buffer(item.first);
buf.max_size = item.second * max_batch_;
CHECK(buf.buffer == nullptr); // buffer should be allocated only once. CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second)); PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, buf.max_size));
VLOG(4) << "buffer malloc " << item.first << " " << item.second << " " PADDLE_ENFORCE_LE(buf.max_size, 1 << 30); // 10G
<< buf.buffer; // buf.size will changed in the runtime.
buf.size = buf.max_size = item.second; buf.size = 0;
buf.device = DeviceType::GPU; buf.device = DeviceType::GPU;
} }
} }
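A sketch of the buffer-size arithmetic used in FreezeNetwork() above, assuming item.second holds the per-sample byte count (the binding's element size times the product of its dimensions, i.e. kDataTypeSize[...] * AccuDims(...)) and max_size scales that by the engine's max batch. The concrete numbers below are illustrative only:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Per-sample bytes for one binding: sizeof(dtype) * prod(dims).
size_t BytesPerSample(size_t dtype_size, const std::vector<int>& dims) {
  return dtype_size * std::accumulate(dims.begin(), dims.end(),
                                      static_cast<size_t>(1),
                                      std::multiplies<size_t>());
}

int main() {
  // e.g. an FP32 binding with dims {2, 1}: 4 * 2 * 1 = 8 bytes per sample.
  size_t per_sample = BytesPerSample(4, {2, 1});
  // With max_batch = 16 the device buffer would be 8 * 16 = 128 bytes,
  // well under the 1 << 30 cap enforced above.
  size_t max_size = per_sample * 16;
  return max_size == 128 ? 0 : 1;
}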
nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
nvinfer1::DataType dtype, nvinfer1::DataType dtype,
const nvinfer1::Dims& dims) { const nvinfer1::Dims &dims) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s", PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name); name);
PADDLE_ENFORCE(infer_network_ != nullptr, "should init network first"); PADDLE_ENFORCE(infer_network_ != nullptr, "should init network first");
auto* input = infer_network_->addInput(name.c_str(), dtype, dims); auto *input = infer_network_->addInput(name.c_str(), dtype, dims);
PADDLE_ENFORCE(input, "infer network add input %s failed", name); PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
analysis::AccuDims(dims.d, dims.nbDims); analysis::AccuDims(dims.d, dims.nbDims);
...@@ -104,12 +111,12 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, ...@@ -104,12 +111,12 @@ nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
return input; return input;
} }
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
const std::string& name) { const std::string &name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name); name);
auto* output = layer->getOutput(offset); auto *output = layer->getOutput(offset);
SetITensor(name, output); SetITensor(name, output);
PADDLE_ENFORCE(output != nullptr); PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str()); output->setName(name.c_str());
...@@ -121,11 +128,11 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, ...@@ -121,11 +128,11 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
buffer_sizes_[name] = 0; buffer_sizes_[name] = 0;
} }
void TensorRTEngine::DeclareOutput(const std::string& name) { void TensorRTEngine::DeclareOutput(const std::string &name) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
name); name);
auto* output = TensorRTEngine::GetITensor(name); auto *output = TensorRTEngine::GetITensor(name);
PADDLE_ENFORCE(output != nullptr); PADDLE_ENFORCE(output != nullptr);
output->setName(name.c_str()); output->setName(name.c_str());
PADDLE_ENFORCE(!output->isNetworkInput()); PADDLE_ENFORCE(!output->isNetworkInput());
...@@ -135,38 +142,45 @@ void TensorRTEngine::DeclareOutput(const std::string& name) { ...@@ -135,38 +142,45 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
buffer_sizes_[name] = 0; buffer_sizes_[name] = 0;
} }
void* TensorRTEngine::GetOutputInGPU(const std::string& name) { void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
return buffer(name).buffer; return buffer(name).buffer;
} }
void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst, void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
size_t max_size) { size_t max_size) {
// determine data size // determine data size
auto it = buffer_sizes_.find(name); auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end()); PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0); PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second); PADDLE_ENFORCE_GE(max_size, it->second);
auto& buf = buffer(name); auto &buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before"); PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second, PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
cudaMemcpyDeviceToDevice, *stream_), cudaMemcpyDeviceToDevice, *stream_),
0); 0);
} }
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst, void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
size_t max_size) { size_t max_size) {
VLOG(4) << "get output in cpu";
auto &buf = buffer(name);
// Update needed buffer size.
auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
auto dims = infer_engine_->getBindingDimensions(slot_offset);
buf.size = kDataTypeSize[static_cast<int>(
infer_engine_->getBindingDataType(slot_offset))] *
analysis::AccuDims(dims.d, dims.nbDims);
PADDLE_ENFORCE_LE(buf.size, buf.max_size);
// determine data size // determine data size
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before"); PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, it->second, // DEBUG
cudaMemcpyDeviceToHost, *stream_)); memset(dst, 0, buf.size);
PADDLE_ENFORCE_EQ(
0, cudaMemcpy(dst, buf.buffer, buf.size, cudaMemcpyDeviceToHost));
} }
Buffer& TensorRTEngine::buffer(const std::string& name) { Buffer &TensorRTEngine::buffer(const std::string &name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name); auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end()); PADDLE_ENFORCE(it != buffer_sizes_.end());
...@@ -174,19 +188,23 @@ Buffer& TensorRTEngine::buffer(const std::string& name) { ...@@ -174,19 +188,23 @@ Buffer& TensorRTEngine::buffer(const std::string& name) {
return buffers_[slot_offset]; return buffers_[slot_offset];
} }
void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data, void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
size_t size) { size_t size) {
auto& buf = buffer(name); auto &buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer); PADDLE_ENFORCE_NOT_NULL(buf.buffer);
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE_NOT_NULL(stream_);
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small"); PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
PADDLE_ENFORCE(buf.device == DeviceType::GPU); PADDLE_ENFORCE(buf.device == DeviceType::GPU);
buf.size = size;
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size, PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
cudaMemcpyHostToDevice, *stream_)); cudaMemcpyHostToDevice, *stream_));
} }
void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data, void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
size_t size) { size_t size) {
auto& buf = buffer(name); auto &buf = buffer(name);
buf.size = size;
PADDLE_ENFORCE_NOT_NULL(buf.buffer); PADDLE_ENFORCE_NOT_NULL(buf.buffer);
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small"); PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
PADDLE_ENFORCE(buf.device == DeviceType::GPU); PADDLE_ENFORCE(buf.device == DeviceType::GPU);
...@@ -194,15 +212,15 @@ void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data, ...@@ -194,15 +212,15 @@ void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
cudaMemcpyDeviceToDevice, *stream_)); cudaMemcpyDeviceToDevice, *stream_));
} }
void TensorRTEngine::SetITensor(const std::string& name, void TensorRTEngine::SetITensor(const std::string &name,
nvinfer1::ITensor* tensor) { nvinfer1::ITensor *tensor) {
PADDLE_ENFORCE(tensor != nullptr); PADDLE_ENFORCE(tensor != nullptr);
PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s", PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
name); name);
itensor_map_[name] = tensor; itensor_map_[name] = tensor;
} }
nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) { nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name); PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name);
return itensor_map_[name]; return itensor_map_[name];
} }
......
...@@ -57,7 +57,9 @@ class TensorRTEngine : public EngineBase { ...@@ -57,7 +57,9 @@ class TensorRTEngine : public EngineBase {
: max_batch_(max_batch), : max_batch_(max_batch),
max_workspace_(max_workspace), max_workspace_(max_workspace),
stream_(stream ? stream : &default_stream_), stream_(stream ? stream : &default_stream_),
logger_(logger) {} logger_(logger) {
cudaStreamCreate(&default_stream_);
}
virtual ~TensorRTEngine(); virtual ~TensorRTEngine();
...@@ -121,6 +123,9 @@ class TensorRTEngine : public EngineBase { ...@@ -121,6 +123,9 @@ class TensorRTEngine : public EngineBase {
int max_batch_; int max_batch_;
// the max memory size the engine uses // the max memory size the engine uses
int max_workspace_; int max_workspace_;
// batch size of the current data; updated on each call to Execute().
int batch_size_{-1};
cudaStream_t* stream_; cudaStream_t* stream_;
// If stream_ is not set from outside, hold its own stream. // If stream_ is not set from outside, hold its own stream.
cudaStream_t default_stream_; cudaStream_t default_stream_;
......
...@@ -103,6 +103,10 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) { ...@@ -103,6 +103,10 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
LOG(INFO) << "to get output"; LOG(INFO) << "to get output";
float y_cpu[2] = {-1., -1.}; float y_cpu[2] = {-1., -1.};
auto dims = engine_->GetITensor("y")->getDimensions();
ASSERT_EQ(dims.nbDims, 3);
ASSERT_EQ(dims.d[0], 2);
ASSERT_EQ(dims.d[1], 1);
engine_->GetOutputInCPU("y", &y_cpu[0], sizeof(float) * 2); engine_->GetOutputInCPU("y", &y_cpu[0], sizeof(float) * 2);
ASSERT_EQ(y_cpu[0], 4.5); ASSERT_EQ(y_cpu[0], 4.5);
ASSERT_EQ(y_cpu[1], 14.5); ASSERT_EQ(y_cpu[1], 14.5);
......
...@@ -168,6 +168,8 @@ function(op_library TARGET) ...@@ -168,6 +168,8 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(relu);\n") file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "fake_dequantize") elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n") file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
elseif(${TARGET} STREQUAL "tensorrt_engine_op")
message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
else() else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif() endif()
...@@ -237,9 +239,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) ...@@ -237,9 +239,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax) op_library(softmax_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax) op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND) if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine) op_library(tensorrt_engine_op DEPS tensorrt_engine tensorrt_converter)
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter DEPS tensorrt_engine_op
analysis) analysis)
else() else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
......
...@@ -24,15 +24,16 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -24,15 +24,16 @@ class AucOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input of Out should not be null."); PADDLE_ENFORCE(ctx->HasInput("Predict"),
PADDLE_ENFORCE(ctx->HasInput("Indices"), "Input of Out should not be null.");
"Input of Indices should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input of Label should not be null."); "Input of Label should not be null.");
auto inference_height = ctx->GetInputDim("Out")[0]; auto predict_width = ctx->GetInputDim("Predict")[1];
PADDLE_ENFORCE_EQ(predict_width, 2, "Only support binary classification");
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0]; auto label_height = ctx->GetInputDim("Label")[0];
PADDLE_ENFORCE_EQ(inference_height, label_height, PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height."); "Out and Label should have same height.");
int num_thres = ctx->Attrs().Get<int>("num_thresholds"); int num_thres = ctx->Attrs().Get<int>("num_thresholds");
...@@ -43,14 +44,14 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -43,14 +44,14 @@ class AucOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("FPOut", {num_thres}); ctx->SetOutputDim("FPOut", {num_thres});
ctx->SetOutputDim("FNOut", {num_thres}); ctx->SetOutputDim("FNOut", {num_thres});
ctx->ShareLoD("Out", /*->*/ "AUC"); ctx->ShareLoD("Predict", /*->*/ "AUC");
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()), framework::ToDataType(ctx.Input<Tensor>("Predict")->type()),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -58,18 +59,13 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -58,18 +59,13 @@ class AucOp : public framework::OperatorWithKernel {
class AucOpMaker : public framework::OpProtoAndCheckerMaker { class AucOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Out", AddInput("Predict",
"A floating point 2D tensor, values are in the range [0, 1]." "A floating point 2D tensor with shape [batch_size, 2], values "
"Each row is sorted in descending order. This input should be the" "are in the range [0, 1]."
"output of topk."
"Typically, this tensor indicates the probability of each label"); "Typically, this tensor indicates the probability of each label");
AddInput("Indices",
"An int 2D tensor, indicating the indices of original"
"tensor before sorting. Typically, this tensor indicates which "
"label the probability stands for.");
AddInput("Label", AddInput("Label",
"A 2D int tensor indicating the label of the training data." "A 2D int tensor indicating the label of the training data. "
"The height is batch size and width is always 1."); "shape: [batch_size, 1]");
AddInput("TP", "True-Positive value."); AddInput("TP", "True-Positive value.");
AddInput("FP", "False-Positive value."); AddInput("FP", "False-Positive value.");
AddInput("TN", "True-Negative value."); AddInput("TN", "True-Negative value.");
......
...@@ -31,7 +31,7 @@ template <typename DeviceContext, typename T> ...@@ -31,7 +31,7 @@ template <typename DeviceContext, typename T>
class AucKernel : public framework::OpKernel<T> { class AucKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Out"); auto* predict = ctx.Input<Tensor>("Predict");
auto* label = ctx.Input<Tensor>("Label"); auto* label = ctx.Input<Tensor>("Label");
auto* auc = ctx.Output<Tensor>("AUC"); auto* auc = ctx.Output<Tensor>("AUC");
// Only use output var for now, make sure it's persistable and // Only use output var for now, make sure it's persistable and
...@@ -41,24 +41,24 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -41,24 +41,24 @@ class AucKernel : public framework::OpKernel<T> {
auto* true_negative = ctx.Output<Tensor>("TNOut"); auto* true_negative = ctx.Output<Tensor>("TNOut");
auto* false_negative = ctx.Output<Tensor>("FNOut"); auto* false_negative = ctx.Output<Tensor>("FNOut");
float* auc_data = auc->mutable_data<float>(ctx.GetPlace()); auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
std::string curve = ctx.Attr<std::string>("curve"); std::string curve = ctx.Attr<std::string>("curve");
int num_thresholds = ctx.Attr<int>("num_thresholds"); int num_thresholds = ctx.Attr<int>("num_thresholds");
std::vector<float> thresholds_list; std::vector<double> thresholds_list;
thresholds_list.reserve(num_thresholds); thresholds_list.reserve(num_thresholds);
for (int i = 1; i < num_thresholds - 1; i++) { for (int i = 1; i < num_thresholds - 1; i++) {
thresholds_list[i] = static_cast<float>(i) / (num_thresholds - 1); thresholds_list[i] = static_cast<double>(i) / (num_thresholds - 1);
} }
const float kEpsilon = 1e-7; const double kEpsilon = 1e-7;
thresholds_list[0] = 0.0f - kEpsilon; thresholds_list[0] = 0.0f - kEpsilon;
thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
size_t batch_size = inference->dims()[0]; size_t batch_size = predict->dims()[0];
size_t inference_width = inference->dims()[1]; size_t inference_width = predict->dims()[1];
const T* inference_data = inference->data<T>(); const T* inference_data = predict->data<T>();
const int64_t* label_data = label->data<int64_t>(); const auto* label_data = label->data<int64_t>();
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace()); auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace()); auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
...@@ -66,20 +66,19 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -66,20 +66,19 @@ class AucKernel : public framework::OpKernel<T> {
auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace()); auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
// caculate TP, FN, TN, FP for current thresh // calculate TP, FN, TN, FP for current thresh
int64_t tp = 0, fn = 0, tn = 0, fp = 0; int64_t tp = 0, fn = 0, tn = 0, fp = 0;
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; i++) {
// NOTE: label_data used as bool, labels >0 will be treated as true. // NOTE: label_data used as bool, labels > 0 will be treated as true.
if (label_data[i]) { if (label_data[i]) {
// use first(max) data in each row if (inference_data[i * inference_width + 1] >=
if (inference_data[i * inference_width] >=
(thresholds_list[idx_thresh])) { (thresholds_list[idx_thresh])) {
tp++; tp++;
} else { } else {
fn++; fn++;
} }
} else { } else {
if (inference_data[i * inference_width] >= if (inference_data[i * inference_width + 1] >=
(thresholds_list[idx_thresh])) { (thresholds_list[idx_thresh])) {
fp++; fp++;
} else { } else {
...@@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -94,21 +93,21 @@ class AucKernel : public framework::OpKernel<T> {
fp_data[idx_thresh] += fp; fp_data[idx_thresh] += fp;
} }
// epsilon to avoid divide by zero. // epsilon to avoid divide by zero.
float epsilon = 1e-6; double epsilon = 1e-6;
// Riemann sum to calculate auc. // Riemann sum to calculate auc.
Tensor tp_rate, fp_rate, rec_rate; Tensor tp_rate, fp_rate, rec_rate;
tp_rate.Resize({num_thresholds}); tp_rate.Resize({num_thresholds});
fp_rate.Resize({num_thresholds}); fp_rate.Resize({num_thresholds});
rec_rate.Resize({num_thresholds}); rec_rate.Resize({num_thresholds});
float* tp_rate_data = tp_rate.mutable_data<float>(ctx.GetPlace()); auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
float* fp_rate_data = fp_rate.mutable_data<float>(ctx.GetPlace()); auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
float* rec_rate_data = rec_rate.mutable_data<float>(ctx.GetPlace()); auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
for (int i = 0; i < num_thresholds; i++) { for (int i = 0; i < num_thresholds; i++) {
tp_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) / tp_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
(tp_data[i] + fn_data[i] + epsilon); (tp_data[i] + fn_data[i] + epsilon);
fp_rate_data[i] = fp_rate_data[i] =
static_cast<float>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon); static_cast<double>(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
rec_rate_data[i] = (static_cast<float>(tp_data[i]) + epsilon) / rec_rate_data[i] = (static_cast<double>(tp_data[i]) + epsilon) /
(tp_data[i] + fp_data[i] + epsilon); (tp_data[i] + fp_data[i] + epsilon);
} }
*auc_data = 0.0f; *auc_data = 0.0f;
......
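For reference, a compact stand-alone sketch of what the AUC kernel above computes: Predict is a [batch_size, 2] row-major tensor whose column 1 holds the positive-class score (hence the inference_data[i * inference_width + 1] indexing), confusion counts are accumulated per threshold, and the curve is integrated with a trapezoidal Riemann sum. The exact integration code is elided in the hunk and the op also supports a PR curve via the curve attribute; this sketch shows only the ROC case with made-up data:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // predict: [batch_size, 2] row-major; column 1 is P(label == 1).
  std::vector<double> predict = {0.9, 0.1, 0.3, 0.7, 0.2, 0.8, 0.6, 0.4};
  std::vector<int64_t> label = {0, 1, 1, 0};
  const int batch = 4, width = 2, num_thresholds = 200;
  const double eps = 1e-6;

  std::vector<double> tpr(num_thresholds), fpr(num_thresholds);
  for (int t = 0; t < num_thresholds; t++) {
    double thresh = static_cast<double>(t) / (num_thresholds - 1);
    int64_t tp = 0, fn = 0, tn = 0, fp = 0;
    for (int i = 0; i < batch; i++) {
      bool pos = predict[i * width + 1] >= thresh;  // column 1, as in the kernel
      if (label[i] > 0) { pos ? tp++ : fn++; } else { pos ? fp++ : tn++; }
    }
    tpr[t] = (tp + eps) / (tp + fn + eps);
    fpr[t] = fp / (fp + tn + eps);
  }
  // Trapezoidal Riemann sum over the ROC curve; thresholds sweep upward,
  // so consecutive points move from high FPR/TPR toward low.
  double auc = 0.0;
  for (int t = 0; t + 1 < num_thresholds; t++) {
    auc += (fpr[t] - fpr[t + 1]) * (tpr[t] + tpr[t + 1]) / 2;
  }
  std::printf("AUC ~= %f\n", auc);  // ~1.0 for this perfectly separable toy data
}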
if(NOT WITH_DISTRIBUTE)
return()
endif()
if(WITH_GRPC)
set(cc_generic_services "false")
else()
set(cc_generic_services "true")
endif()
configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)
if(WITH_GRPC) if(WITH_GRPC)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
selected_rows memory) PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS grpc_serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(grpc_serde_test SRCS grpc_serde_test.cc
cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(grpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_grpc cc_test(grpc_server_test SRCS rpc_server_test.cc
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op SERIAL)
proto_desc lookup_table_op SERIAL)
return() return()
endif() endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc set_source_files_properties(brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
brpc_library(sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto PROTO send_recv.proto
DEPS lod_tensor selected_rows memory) DEPS lod_tensor selected_rows memory)
find_library(OPENSSL_CRYPTO_LIBRARY_STATIC NAMES libcrypto.so) set(brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy)
ADD_LIBRARY(crypto SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET crypto PROPERTY IMPORTED_LOCATION ${OPENSSL_CRYPTO_LIBRARY_STATIC})
find_library(OPENSSL_SSL_LIBRARY_STATIC NAMES libssl.so) cc_test(brpc_server_test SRCS rpc_server_test.cc
ADD_LIBRARY(ssl SHARED IMPORTED GLOBAL) DEPS ${brpc_test_depends} SERIAL)
SET_PROPERTY(TARGET ssl PROPERTY IMPORTED_LOCATION ${OPENSSL_SSL_LIBRARY_STATIC})
cc_test(brpc_server_test SRCS rpc_server_test.cc DEPS sendrecvop_brpc cc_test(brpc_serde_test SRCS brpc_serde_test.cc
brpc protobuf leveldb gflags glog DEPS ${brpc_test_depends} SERIAL)
protobuf executor proto_desc lookup_table_op snappystream snappy ssl crypto SERIAL)
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
// file and did some modifications so that we can send gRPC // file and did some modifications so that we can send gRPC
// requests without too much copying of the tensor data. // requests without too much copying of the tensor data.
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "grpc++/grpc++.h" #include "grpc++/grpc++.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace grpc { namespace grpc {
// A ZeroCopyInputStream that reads from grpc_byte_buffer // A ZeroCopyInputStream that reads from grpc_byte_buffer
...@@ -107,25 +108,6 @@ class GrpcBufferReader final ...@@ -107,25 +108,6 @@ class GrpcBufferReader final
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
// A ZeroCopyInputStream that reads from a grpc::ByteBuffer. // A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
class GrpcByteBufferSource class GrpcByteBufferSource
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "glog/logging.h" // For VLOG #include "glog/logging.h" // For VLOG
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/request_handler.h" #include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
......
...@@ -38,7 +38,10 @@ limitations under the License. */ ...@@ -38,7 +38,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_client.h" #include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
...@@ -46,23 +49,6 @@ namespace paddle { ...@@ -46,23 +49,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
struct VarHandle {
// RPC endpoint.
std::string ep;
const platform::DeviceContext* ctx;
const framework::Scope* scope;
// Variable name.
std::string name;
// RPC method name.
std::string method;
std::string String() const {
std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]";
return s.str();
}
};
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg); void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor { class BaseProcessor {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include <sys/time.h>
#include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
// The default DestroyCallback does nothing. When using GPU,
// the CPU buffer needs to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
void* payload = nullptr;
size_t payload_size;
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CUDAPinnedPlace cuda_pinned;
memory::Free(cuda_pinned, backing);
};
#endif
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
::grpc::ByteBuffer tmp(&slices, 1);
msg->Swap(&tmp);
return;
}
#endif
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
slices[3] = ::grpc::Slice(
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
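The ByteBuffer assembled by SerializeToByteBuffer() holds at most four slices: [0] the VariableMessage header written with ProtoEncodeHelper, [1] the tensor payload wrapped zero-copy through grpc_slice_new_with_user_data plus a destroy callback, and, for SelectedRows, [2] a rows-field header and [3] the rows data. A gRPC-free sketch of the underlying "wrap an existing buffer and hand ownership to a destroy callback" idea (all names below are illustrative, not real gRPC types):

#include <cstdio>
#include <cstdlib>

using DestroyCallback = void (*)(void*);

// A borrowed view over an existing buffer; no copy is made.
struct Slice {
  void* data;
  size_t size;
  DestroyCallback destroy;  // invoked exactly once when the slice is released
};

Slice WrapExisting(void* data, size_t size, DestroyCallback cb) {
  return Slice{data, size, cb};
}

int main() {
  void* payload = std::malloc(64);  // e.g. a pinned CPU copy of a GPU tensor
  Slice s = WrapExisting(payload, 64, [](void* p) { std::free(p); });
  std::printf("sending %zu bytes without copying\n", s.size);
  s.destroy(s.data);  // the transport frees the backing store when it is done
}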
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle {
namespace operators {
namespace distributed {
typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var);
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -21,8 +21,10 @@ limitations under the License. */ ...@@ -21,8 +21,10 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -84,7 +86,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2); // operators::distributed::DeserializeFromByteBuffer(msg, ctx, &var2);
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::distributed::VariableResponse resp(&scope, &ctx); operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
framework::Variable* var2 = resp.GetVar(); framework::Variable* var2 = resp.GetVar();
...@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) { ...@@ -171,7 +173,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
// deserialize zero-copy // deserialize zero-copy
framework::Scope scope; framework::Scope scope;
scope.Var("myvar"); scope.Var("myvar");
operators::distributed::VariableResponse resp(&scope, &ctx); operators::distributed::GRPCVariableResponse resp(&scope, &ctx);
if (from_type == 0) { if (from_type == 0) {
EXPECT_EQ(resp.Parse(msg), 0); EXPECT_EQ(resp.Parse(msg), 0);
} else { } else {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include <string> #include <string>
#include "paddle/fluid/operators/distributed/grpc_serde.h"
#include "paddle/fluid/operators/distributed/grpc_server.h" #include "paddle/fluid/operators/distributed/grpc_server.h"
using ::grpc::ServerAsyncResponseWriter; using ::grpc::ServerAsyncResponseWriter;
...@@ -84,7 +85,7 @@ class RequestSend final : public RequestBase { ...@@ -84,7 +85,7 @@ class RequestSend final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), request_handler->dev_ctx(),
!request_handler->sync_mode())); !request_handler->sync_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable); int method_id = static_cast<int>(distributed::GrpcMethod::kSendVariable);
...@@ -109,7 +110,7 @@ class RequestSend final : public RequestBase { ...@@ -109,7 +110,7 @@ class RequestSend final : public RequestBase {
protected: protected:
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::shared_ptr<VariableResponse> request_; std::shared_ptr<GRPCVariableResponse> request_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
...@@ -161,7 +162,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -161,7 +162,7 @@ class RequestPrefetch final : public RequestBase {
: RequestBase(service, cq, request_handler, req_id), : RequestBase(service, cq, request_handler, req_id),
responder_(&ctx_), responder_(&ctx_),
local_scope_(nullptr) { local_scope_(nullptr) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx(), true)); request_handler->dev_ctx(), true));
int method_id = int method_id =
static_cast<int>(distributed::GrpcMethod::kPrefetchVariable); static_cast<int>(distributed::GrpcMethod::kPrefetchVariable);
...@@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -194,7 +195,7 @@ class RequestPrefetch final : public RequestBase {
} }
protected: protected:
std::shared_ptr<VariableResponse> request_; std::shared_ptr<GRPCVariableResponse> request_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* local_scope_; framework::Scope* local_scope_;
...@@ -206,7 +207,7 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -206,7 +207,7 @@ class RequestCheckpointNotify final : public RequestBase {
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx())); request_handler->dev_ctx()));
int method_id = int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify); static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
...@@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase { ...@@ -234,7 +235,7 @@ class RequestCheckpointNotify final : public RequestBase {
} }
protected: protected:
std::shared_ptr<VariableResponse> request_; std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
......
...@@ -23,8 +23,7 @@ ...@@ -23,8 +23,7 @@
#include <grpc++/impl/codegen/stub_options.h> #include <grpc++/impl/codegen/stub_options.h>
#include <grpc++/impl/codegen/sync_stream.h> #include <grpc++/impl/codegen/sync_stream.h>
#include <grpc++/support/byte_buffer.h> #include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow // NOTE: This method was originally created by tensorflow
...@@ -42,17 +41,18 @@ class ServerContext; ...@@ -42,17 +41,18 @@ class ServerContext;
// Support parsing/unparsing of tensorflow::VariableResponse. // Support parsing/unparsing of tensorflow::VariableResponse.
// Wire-format is identical to RecvVariableResponse. // Wire-format is identical to RecvVariableResponse.
template <> template <>
class SerializationTraits<paddle::operators::distributed::VariableResponse> { class SerializationTraits<
paddle::operators::distributed::GRPCVariableResponse> {
public: public:
static Status Serialize( static Status Serialize(
const paddle::operators::distributed::VariableResponse& msg, const paddle::operators::distributed::GRPCVariableResponse& msg,
grpc_byte_buffer** bp, bool* own_buffer) { grpc_byte_buffer** bp, bool* own_buffer) {
PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!"); PADDLE_ENFORCE(false, "SerializationTraits::Serialize not implemented!");
return Status(); return Status();
} }
static Status Deserialize( static Status Deserialize(
grpc_byte_buffer* buffer, grpc_byte_buffer* buffer,
paddle::operators::distributed::VariableResponse* msg, paddle::operators::distributed::GRPCVariableResponse* msg,
int max_message_size = INT_MAX) { int max_message_size = INT_MAX) {
if (buffer == nullptr) { if (buffer == nullptr) {
return Status(StatusCode::INTERNAL, "No payload"); return Status(StatusCode::INTERNAL, "No payload");
......
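The SerializationTraits specialization above is the hook that lets gRPC hand an incoming grpc_byte_buffer directly to GRPCVariableResponse::Parse instead of routing it through a generated protobuf class. A minimal sketch of the traits-specialization pattern itself, with generic names (this is not the real gRPC API surface):

#include <string>

// Primary template left undefined: each payload type used with the transport
// must provide its own specialization, mirroring grpc::SerializationTraits.
template <typename T>
struct SerializationTraits;

struct MyResponse {
  std::string payload;
  int Parse(const std::string& wire) {  // returns 0 on success, as in this commit
    payload = wire;
    return 0;
  }
};

template <>
struct SerializationTraits<MyResponse> {
  static bool Deserialize(const std::string& wire, MyResponse* msg) {
    return msg->Parse(wire) == 0;  // route deserialization through Parse()
  }
};

int main() {
  MyResponse r;
  return SerializationTraits<MyResponse>::Deserialize("raw bytes", &r) ? 0 : 1;
}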
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/operators/distributed/grpc_variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
namespace distributed {
enum WireType {
WIRETYPE_VARINT = 0,
WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
int GRPCVariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
bool ParseLodData(::google::protobuf::io::CodedInputStream* input,
std::vector<int64_t>* lod) {
while (true) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return tag;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
int GRPCVariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return tag;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!ProcSerializedField(tag, &input, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
case sendrecv::VariableMessage::kProfileFieldNumber: {
uint64_t profiling = 0;
if (!input.ReadVarint64(&profiling)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow customizing the file dir?
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
}
}
return 0;
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
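Parse() above walks the protobuf wire format by hand: every key packs (field_number << 3) | wire_type, which GetTagFieldNumber and GetTagWireType split apart before dispatching on the field number. A tiny self-contained illustration of that arithmetic (the field number and wire type are made up):

#include <cassert>
#include <cstdint>

int main() {
  // Key for field 7 encoded as length-delimited data (wire type 2).
  uint32_t tag = (7u << 3) | 2u;
  assert((tag >> 3) == 7u);   // GetTagFieldNumber
  assert((tag & 0x7) == 2u);  // GetTagWireType -> WIRETYPE_LENGTH_DELIMITED
  return 0;
}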
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/grpc_bytebuffer_stream.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
namespace paddle {
namespace operators {
namespace distributed {
class GRPCVariableResponse : public VariableResponse {
public:
GRPCVariableResponse(const framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
bool create_scope = false)
: VariableResponse(scope, dev_ctx, create_scope) {}
virtual ~GRPCVariableResponse() {}
int Parse(Source* source) override;
// return:
// 0: ok.
// -1: unknown error.
// other: the field number at which parsing failed.
int Parse(const ::grpc::ByteBuffer& byte_buffer);
};
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
...@@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier"; ...@@ -51,6 +51,23 @@ constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
class RPCServer; class RPCServer;
struct VarHandle {
// RPC endpoint.
std::string ep;
const platform::DeviceContext* ctx;
const framework::Scope* scope;
// Variable name.
std::string name;
// RPC method name.
std::string method;
std::string String() const {
std::ostringstream s;
s << method << " name:[" << name << "], ep:[" << ep << "]";
return s.str();
}
};
class RequestHandler { class RequestHandler {
public: public:
explicit RequestHandler(bool sync_mode) explicit RequestHandler(bool sync_mode)
......
...@@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -53,7 +53,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Sync // Sync
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message"; VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
rpc_server_->IncreaseBatchBarrier(kRequestSend); rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == BEGIN_PASS_MESSAGE) { } else if (varname == BEGIN_PASS_MESSAGE) {
VLOG(3) << "sync: recv begin pass message"; VLOG(3) << "sync: recv begin pass message";
...@@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -65,8 +65,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG(3) << "sync: processing received var: " << varname; VLOG(3) << "sync: processing received var: " << varname;
if (invar == nullptr) { if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname; LOG(FATAL) << "sync: Can not find server side var: " << varname;
PADDLE_THROW("sync: Can not find server side var");
return false; return false;
} }
if (invar->IsType<framework::SelectedRows>()) { if (invar->IsType<framework::SelectedRows>()) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under
the Apache License, Version 2.0 (the "License"); you may not use this file the Apache License, Version 2.0 (the "License"); you may not use this file
except in compliance with the License. except in compliance with the License.
...@@ -14,7 +15,7 @@ limitations under the License. */ ...@@ -14,7 +15,7 @@ limitations under the License. */
syntax = "proto3"; syntax = "proto3";
package sendrecv; package sendrecv;
// option cc_generic_services = true; option cc_generic_services = @cc_generic_services@;
service SendRecvService { service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors. // For parameter server round-robin like hashing, do not split tensors.
......
...@@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <sys/time.h> #include <sys/time.h>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
#include "paddle/fluid/operators/distributed/proto_encoder_helper.h"
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -34,6 +28,11 @@ namespace distributed { ...@@ -34,6 +28,11 @@ namespace distributed {
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
void* GetVarPayLoad(const std::string varname, int64_t size) {
platform::CUDAPinnedPlace cuda_pinned;
return memory::Alloc(cuda_pinned, size);
}
void GetTensorPayload(framework::Variable* var, void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request, const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) { void** payload, size_t* payload_size) {
...@@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var, ...@@ -58,15 +57,17 @@ void GetTensorPayload(framework::Variable* var,
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
platform::CUDAPinnedPlace cuda_pinned; // platform::CUDAPinnedPlace cuda_pinned;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx); auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type()); auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
*payload = memory::Alloc(cuda_pinned, copy_size); *payload = GetVarPayLoad(request->varname(), copy_size);
platform::CUDAPinnedPlace cuda_pinned;
memory::Copy(cuda_pinned, *payload, memory::Copy(cuda_pinned, *payload,
boost::get<platform::CUDAPlace>(tensor.place()), boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(tensor.data<void>()), copy_size, reinterpret_cast<const void*>(tensor.data<void>()), copy_size,
gpu_dev_ctx.stream()); gpu_dev_ctx.stream());
ctx.Wait(); ctx.Wait();
#endif #endif
} else { } else {
...@@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var, ...@@ -91,10 +92,11 @@ void GetSelectedRowsPayload(framework::Variable* var,
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
platform::CUDAPinnedPlace cuda_pinned;
auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx); auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type()); auto copy_size = tensor->numel() * framework::SizeOfType(tensor->type());
*payload = memory::Alloc(cuda_pinned, copy_size); *payload = GetVarPayLoad(request->varname(), copy_size);
platform::CUDAPinnedPlace cuda_pinned;
memory::Copy(cuda_pinned, *payload, memory::Copy(cuda_pinned, *payload,
boost::get<platform::CUDAPlace>(tensor->place()), boost::get<platform::CUDAPlace>(tensor->place()),
reinterpret_cast<const void*>(tensor->data<void>()), copy_size, reinterpret_cast<const void*>(tensor->data<void>()), copy_size,
...@@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var, ...@@ -107,126 +109,6 @@ void GetSelectedRowsPayload(framework::Variable* var,
*payload_size = tensor->numel() * framework::SizeOfType(tensor->type()); *payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
} }
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
VarMsg request;
void* payload = nullptr;
size_t payload_size;
request.set_varname(name);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
if (platform::IsProfileEnabled()) {
request.set_profile(platform::kEnableProfiler);
} else {
request.set_profile(platform::kDisableProfiler);
}
}
if (!out_name.empty()) {
request.set_out_varname(out_name);
}
if (var->IsType<framework::LoDTensor>()) {
request.set_type(::sendrecv::LOD_TENSOR);
GetTensorPayload(var, ctx, &request, &payload, &payload_size);
} else if (var->IsType<framework::SelectedRows>()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
#ifdef PADDLE_WITH_CUDA
} else if (var->IsType<ncclUniqueId>()) {
request.set_type(::sendrecv::NCCL_ID);
#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
// GPU data is copied to CPU buffer when sending,
// free the buffer when possible.
destroy_callback = [](void* backing) {
platform::CUDAPinnedPlace cuda_pinned;
memory::Free(cuda_pinned, backing);
};
#endif
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
#ifdef PADDLE_WITH_CUDA
if (var->IsType<ncclUniqueId>()) {
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
NCCL_UNIQUE_ID_BYTES);
const ncclUniqueId& uid = var->Get<ncclUniqueId>();
e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
// for serialize NCCL_ID
::grpc::Slice slices(e.size());
memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
::grpc::ByteBuffer tmp(&slices, 1);
msg->Swap(&tmp);
return;
}
#endif
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer
slices[0] = ::grpc::Slice(e.size());
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
slices[1] = ::grpc::Slice(
grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
static_cast<char*>(payload)),
::grpc::Slice::STEAL_REF);
if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
size_t rows_memory_size =
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
slices[3] = ::grpc::Slice(
grpc_slice_new_with_user_data(
const_cast<void*>(
reinterpret_cast<const void*>(slr->rows().data())),
rows_memory_size, [](void* backing) {},
const_cast<char*>(
reinterpret_cast<const char*>(slr->rows().data()))),
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
operators::distributed::VariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
}
} // namespace distributed } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -25,24 +25,21 @@ limitations under the License. */ ...@@ -25,24 +25,21 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
typedef void (*DestroyCallback)(void*); using VarMsg = sendrecv::VariableMessage;
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx, VarMsg* request,
::grpc::ByteBuffer* msg, void** payload, size_t* payload_size);
const std::string& out_varname = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx, VarMsg* request,
const framework::Scope* scope, void** payload, size_t* payload_size);
framework::Variable** var);
inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) {
switch (type) { switch (type) {
......
...@@ -13,50 +13,20 @@ ...@@ -13,50 +13,20 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/distributed/variable_response.h" #include "paddle/fluid/operators/distributed/variable_response.h"
#include <string>
#include <utility>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_CUDA
#include <nccl.h>
#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h" #include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
enum WireType { bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
WIRETYPE_VARINT = 0, const platform::DeviceContext& dev_ctx,
WIRETYPE_LENGTH_DELIMITED = 2, platform::Place place, void* dest,
}; int64_t size) {
inline int GetTagFieldNumber(uint32_t tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32_t tag) {
return static_cast<WireType>(tag & 0x7);
}
bool ReadVarintSizeAsInt(::google::protobuf::io::CodedInputStream* input,
int* result) {
uint64_t v;
if (input->ReadVarint64(&v) && v <= static_cast<uint64_t>(INT_MAX)) {
*result = static_cast<int>(v);
return true;
} else {
return false;
}
}
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int size) {
const void* data = NULL; const void* data = NULL;
int size_to_write = 0; int size_to_write = 0;
int length = size; int64_t length = size;
int total_written = 0; int total_written = 0;
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
...@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData( ...@@ -194,294 +164,49 @@ bool VariableResponse::CopySelectRowsData(
return true; return true;
} }
bool ParseLodData(::google::protobuf::io::CodedInputStream* input, bool VariableResponse::ProcSerializedField(
std::vector<int64_t>* lod) { int tag, ::google::protobuf::io::CodedInputStream* input,
while (true) { int64_t num_bytes) {
auto p = input->ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
return (tag == 0);
}
switch (tag) {
case sendrecv::VariableMessage_LodData::kLodDataFieldNumber: {
uint64_t v;
if (wt == WIRETYPE_VARINT) {
if (!input->ReadVarint64(&v)) {
return false;
}
lod->push_back(v);
break;
}
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input->ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input->CurrentPosition();
while (input->CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input->ReadVarint64(&v)) {
return tag;
}
lod->push_back(v);
}
break;
}
return false;
}
default: { return false; }
}
}
return true;
}
int VariableResponse::Parse(const ::grpc::ByteBuffer& byte_buffer) {
GrpcByteBufferSource source;
source.Init(byte_buffer);
GrpcByteBufferSourceWrapper r(&source);
return Parse(&r);
}
int VariableResponse::Parse(Source* source) {
::google::protobuf::io::ZeroCopyInputStream* input_stream =
source->contents();
::google::protobuf::io::CodedInputStream input(input_stream);
input.SetTotalBytesLimit(INT_MAX, INT_MAX);
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
WireType wt = GetTagWireType(p.first);
if (!p.second) {
if (tag != 0) {
return -1;
}
return 0;
}
switch (tag) {
case sendrecv::VariableMessage::kVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_varname(temp);
break;
}
case sendrecv::VariableMessage::kTypeFieldNumber: {
uint32_t v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_type(static_cast<::sendrecv::VarType>(v));
break;
}
case sendrecv::VariableMessage::kDataTypeFieldNumber: {
uint32_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) {
return tag;
}
meta_.set_data_type(static_cast<::sendrecv::VariableMessage_Type>(v));
break;
}
case sendrecv::VariableMessage::kDimsFieldNumber: {
// not packed
if (wt == WIRETYPE_VARINT) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
break;
}
// packed
if (wt == WIRETYPE_LENGTH_DELIMITED) {
int num_bytes = 0;
if (!input.ReadVarintSizeAsInt(&num_bytes)) {
return tag;
}
int start_pos = input.CurrentPosition();
while (input.CurrentPosition() - start_pos < num_bytes) {
uint64_t v;
if (!input.ReadVarint64(&v)) {
return tag;
}
meta_.add_dims(v);
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kLodLevelFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_lod_level(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kLodFieldNumber: {
int length = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &length)) {
return tag;
}
std::pair<::google::protobuf::io::CodedInputStream::Limit, int> p =
input.IncrementRecursionDepthAndPushLimit(length);
std::vector<int64_t> lod_data;
if (p.second < 0 || !ParseLodData(&input, &lod_data)) {
return tag;
}
if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
return false;
}
if (lod_data.size() == 0) {
break;
}
auto lod = meta_.add_lod();
for (uint32_t i = 0; i < lod_data.size(); i++) {
lod->add_lod_data(lod_data[i]);
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS || PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR || meta_.type() == sendrecv::LOD_TENSOR ||
meta_.type() == sendrecv::NCCL_ID) && meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "", meta_.varname() != "",
"meta info should be got first!"); "meta info should be got first!");
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (meta_.type() == sendrecv::NCCL_ID) { if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto* var = scope_->FindVar(meta_.varname()); auto* var = scope_->FindVar(meta_.varname());
if (var != nullptr) { if (var != nullptr) {
ncclUniqueId* id = var->GetMutable<ncclUniqueId>(); ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal, if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
num_bytes)) { num_bytes)) {
return tag; return false;
} }
} }
break; return true;
#else #else
PADDLE_THROW("Not compiled with CUDA!"); PADDLE_THROW("Not compiled with CUDA!");
return false;
#endif #endif
} }
framework::DDim dims = GetDims(meta_.dims()); framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) { if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0, PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!");
"lod info should be got first!"); if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
if (!CopyLodTensorData(&input, *dev_ctx_, dims, num_bytes)) { return false;
return tag;
} }
break; return true;
} }
if (meta_.type() == sendrecv::SELECTED_ROWS) { if (meta_.type() == sendrecv::SELECTED_ROWS) {
if (!CopySelectRowsTensorData(&input, *dev_ctx_, dims, num_bytes)) { if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
return tag; return false;
}
break;
}
return tag;
}
case sendrecv::VariableMessage::kRowsFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
meta_.varname() != "",
"meta info should be got first!");
int num_bytes = 0;
if (wt != WIRETYPE_LENGTH_DELIMITED ||
!ReadVarintSizeAsInt(&input, &num_bytes)) {
return tag;
}
if (!CopySelectRowsData(&input, *dev_ctx_, num_bytes)) {
return tag;
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
case sendrecv::VariableMessage::kProfileFieldNumber: {
uint64_t profiling = 0;
if (!input.ReadVarint64(&profiling)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling == platform::kEnableProfiler &&
!platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (profiling == platform::kDisableProfiler &&
platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow to customize file dir.
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
}
} }
return true;
} }
return 0; return true;
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -22,18 +22,35 @@ ...@@ -22,18 +22,35 @@
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/distributed/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/distributed/bytebuffer_stream.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
// Source provides a way for a particular RPC implementation to provide
// received data to ParseFrom.
class Source {
public:
virtual ~Source() {}
// Return the stream that contains the data to be parsed.
// Note that this method might be invoked more than once if
// ParseFrom needs to fall back to a more expensive parsing method.
// Every call must return a stream pointing at the beginning of
// the serialized RecvTensorResponse.
//
// Note that a subsequent call to contents() invalidates previous
// results of contents().
//
// Ownership of the returned stream is retained by the Source and
// should not be deleted by the caller.
virtual ::google::protobuf::io::ZeroCopyInputStream* contents() = 0;
};
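As a side note, a minimal Source can simply wrap an in-memory buffer with protobuf's ArrayInputStream. The sketch below is illustrative only (StringSource is an assumed name, not part of the patch); it re-creates the stream on every contents() call so that each call starts at the beginning of the data, as the contract above requires.

// Illustrative sketch only: an in-memory Source (assumed helper, not in the patch).
#include <memory>
#include <string>
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"

class StringSource : public Source {
 public:
  explicit StringSource(std::string data) : data_(std::move(data)) {}

  ::google::protobuf::io::ZeroCopyInputStream* contents() override {
    // Re-create the stream so every call starts at the beginning of the data.
    stream_.reset(new ::google::protobuf::io::ArrayInputStream(
        data_.data(), static_cast<int>(data_.size())));
    return stream_.get();
  }

 private:
  std::string data_;
  std::unique_ptr<::google::protobuf::io::ArrayInputStream> stream_;
};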
class VariableResponse { class VariableResponse {
public: public:
VariableResponse(const framework::Scope* scope, VariableResponse(const framework::Scope* scope,
...@@ -51,22 +68,19 @@ class VariableResponse { ...@@ -51,22 +68,19 @@ class VariableResponse {
} }
} }
// return: int Parse(Source* source, const sendrecv::VariableMessage& meta) {
// 0:ok. meta_ = meta;
// -1: unkown error. return Parse(source);
// other: number of error field. }
int Parse(Source* source);
  // return: // return:
  //  0: ok. //  0: ok.
  //  -1: unknown error. //  -1: unknown error.
  //  other: the number of the field that failed to parse. //  other: the number of the field that failed to parse.
int Parse(const ::grpc::ByteBuffer& byte_buffer); virtual int Parse(Source* source) = 0;
const framework::Scope& GetLocalScope() const { return *local_scope_; }
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline const framework::Scope& GetLocalScope() const { return *local_scope_; }
inline framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() const { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
...@@ -78,7 +92,11 @@ class VariableResponse { ...@@ -78,7 +92,11 @@ class VariableResponse {
return scope_->FindVar(meta_.varname()); return scope_->FindVar(meta_.varname());
} }
private: protected:
bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& dev_ctx, platform::Place place,
void* dest, int64_t size);
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input, bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::DDim& dims, int length); const framework::DDim& dims, int length);
...@@ -90,12 +108,16 @@ class VariableResponse { ...@@ -90,12 +108,16 @@ class VariableResponse {
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::DDim& dims, int length); const framework::DDim& dims, int length);
private: bool ProcSerializedField(int tag,
::google::protobuf::io::CodedInputStream* input,
int64_t num_bytes);
protected:
const framework::Scope* scope_; const framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_; const platform::DeviceContext* dev_ctx_;
bool create_scope_ = false; bool create_scope_ = false;
framework::Scope* local_scope_ = nullptr; framework::Scope* local_scope_ = nullptr;
// only Skeleton
sendrecv::VariableMessage meta_; sendrecv::VariableMessage meta_;
}; };
......
...@@ -37,6 +37,7 @@ struct CBlas<float> { ...@@ -37,6 +37,7 @@ struct CBlas<float> {
libxsmm_sgemm(args...); libxsmm_sgemm(args...);
} }
#endif #endif
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
platform::dynload::cblas_saxpy(args...); platform::dynload::cblas_saxpy(args...);
...@@ -76,6 +77,7 @@ struct CBlas<double> { ...@@ -76,6 +77,7 @@ struct CBlas<double> {
libxsmm_dgemm(args...); libxsmm_dgemm(args...);
} }
#endif #endif
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
platform::dynload::cblas_daxpy(args...); platform::dynload::cblas_daxpy(args...);
...@@ -150,6 +152,7 @@ struct CBlas<double> { ...@@ -150,6 +152,7 @@ struct CBlas<double> {
} }
}; };
#endif #endif
template <> template <>
struct CBlas<platform::float16> { struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
...@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k, ...@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
return false; return false;
} }
template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, inline void GEMM_WARP(CBLAS_ORDER order, CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha,
int N, int K, T alpha, const T *A, const T *A, int lda, const T *B, int ldb, T beta, T *C,
const T *B, T beta, T *C) const { int ldc) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
#ifdef PADDLE_WITH_LIBXSMM #ifdef PADDLE_WITH_LIBXSMM
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha, if (UseXSMM<T>(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) { beta)) {
// Note: SMM use ColMajor // Note: SMM use ColMajor
const char transa = 'N'; const char transa = 'N';
const char transb = 'N'; const char transb = 'N';
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc); &beta, C, &ldc);
} else { return;
}
#endif #endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
ldb, beta, C, ldc); #ifdef PADDLE_MKL_SPLIT_GEMM
#ifdef PADDLE_WITH_LIBXSMM constexpr int bs = 2;
if (M % bs == 0 && transA == CblasNoTrans && transB == CblasNoTrans) {
for (int off = 0; off < M; off += bs) {
CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bs, N, K, alpha,
                     A + off * lda, lda, B, ldb, beta, C + off * ldc, ldc);
}
return;
} }
#endif #endif
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A,
const T *B, T beta, T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
GEMM_WARP<T>(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
} }
template <> template <>
...@@ -222,7 +243,7 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M, ...@@ -222,7 +243,7 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int N, int K, T alpha, const T *A, int N, int K, T alpha, const T *A,
int lda, const T *B, int ldb, int lda, const T *B, int ldb,
T beta, T *C, int ldc) const { T beta, T *C, int ldc) const {
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, GEMM_WARP<T>(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
lda, B, ldb, beta, C, ldc); lda, B, ldb, beta, C, ldc);
} }
......
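The PADDLE_MKL_SPLIT_GEMM branch above is safe because, in a row-major GEMM, every block of bs rows of C depends only on the matching rows of A and on all of B, so the blocks can be computed independently. The standalone sketch below (naive reference code, not part of the patch) checks that row-block splitting reproduces the unsplit result exactly.

// Standalone sketch: row-block splitting of a row-major GEMM is exact.
#include <cassert>
#include <vector>

// Naive row-major C = alpha * A(MxK) * B(KxN) + beta * C.
static void NaiveGemm(int M, int N, int K, float alpha, const float* A,
                      const float* B, float beta, float* C) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i * K + p] * B[p * N + j];
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
}

int main() {
  const int M = 4, N = 3, K = 5, bs = 2;
  std::vector<float> A(M * K), B(K * N), C1(M * N, 1.f), C2(M * N, 1.f);
  for (int i = 0; i < M * K; ++i) A[i] = 0.1f * i;
  for (int i = 0; i < K * N; ++i) B[i] = 0.2f * i;

  NaiveGemm(M, N, K, 2.f, A.data(), B.data(), 1.f, C1.data());
  for (int off = 0; off < M; off += bs)  // mirrors the split branch above
    NaiveGemm(bs, N, K, 2.f, A.data() + off * K, B.data(), 1.f,
              C2.data() + off * N);

  for (int i = 0; i < M * N; ++i) assert(C1[i] == C2[i]);
  return 0;
}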
...@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) { ...@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
} }
delete ctx; delete ctx;
} }
template <typename T>
void GemmWarpTest(int m, int n, int k, T alpha, T beta) {
paddle::framework::Tensor mat_a;
paddle::framework::Tensor mat_b;
paddle::framework::Tensor mat_c_ref;
paddle::framework::Tensor mat_c_mkl;
auto* cpu_place = new paddle::platform::CPUPlace();
T* A = mat_a.mutable_data<T>({m, k}, *cpu_place);
T* B = mat_b.mutable_data<T>({k, n}, *cpu_place);
T* CREF = mat_c_ref.mutable_data<T>({m, n}, *cpu_place);
T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, *cpu_place);
ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel());
for (int i = 0; i < mat_a.numel(); ++i) {
A[i] = static_cast<T>(i);
}
for (int i = 0; i < mat_b.numel(); ++i) {
B[i] = static_cast<T>(i + 1);
}
for (int i = 0; i < mat_c_ref.numel(); ++i) {
CREF[i] = static_cast<T>(i + 2);
CMKL[i] = CREF[i];
}
  // This call goes through the GEMM_WARP path.
paddle::platform::CPUDeviceContext context(*cpu_place);
GetBlas<T>(context).GEMM(CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B,
beta, CREF);
// lda,ldb,ldc follow RowMajor
int lda = k;
int ldb = n;
int ldc = n;
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
CblasNoTrans, m, n, k, alpha, A, lda,
B, ldb, beta, CMKL, ldc);
for (int i = 0; i < mat_c_mkl.numel(); ++i) {
EXPECT_FLOAT_EQ(CREF[i], CMKL[i]);
}
}
TEST(math_function, gemm_warp) {
GemmWarpTest<float>(3, 2, 5, 1.f, 0.f);
GemmWarpTest<float>(3, 2, 5, 2.f, 1.f);
GemmWarpTest<float>(8, 5, 6, 1.f, 0.f);
GemmWarpTest<float>(8, 5, 6, 2.f, 1.f);
GemmWarpTest<double>(3, 2, 5, 1.0, 0.0);
GemmWarpTest<double>(3, 2, 5, 2.0, 1.0);
GemmWarpTest<double>(8, 5, 6, 1.0, 0.0);
GemmWarpTest<double>(8, 5, 6, 2.0, 1.0);
}
...@@ -98,7 +98,7 @@ The update equations are as follows: ...@@ -98,7 +98,7 @@ The update equations are as follows:
$$ $$
velocity = mu * velocity + gradient \\ velocity = mu * velocity + gradient \\
if (use\_nesterov): \\ if (use\_nesterov): \\
param = param - gradient * learning\_rate + mu * velocity * learning\_rate \\ param = param - (gradient + mu * velocity) * learning\_rate \\
else: \\ else: \\
param = param - learning\_rate * velocity. \\ param = param - learning\_rate * velocity. \\
$$ $$
......
...@@ -30,7 +30,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v, ...@@ -30,7 +30,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
T g_val = g[i]; T g_val = g[i];
T v_new = v[i] * mu + g_val; T v_new = v[i] * mu + g_val;
v_out[i] = v_new; v_out[i] = v_new;
p_out[i] = p[i] - (g_val - v_new * mu) * lr; p_out[i] = p[i] - (g_val + v_new * mu) * lr;
} }
} else { } else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
......
...@@ -46,7 +46,7 @@ class MomentumOpKernel : public framework::OpKernel<T> { ...@@ -46,7 +46,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
v_out = v * mu + g; v_out = v * mu + g;
if (use_nesterov) { if (use_nesterov) {
p_out = p - (g - v_out * mu) * lr[0]; p_out = p - (g + v_out * mu) * lr[0];
} else { } else {
p_out = p - lr[0] * v_out; p_out = p - lr[0] * v_out;
} }
......
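The corrected Nesterov branch used by both kernels above can be checked against a plain scalar reference. The snippet below is only an illustrative sketch of the updated equations for a single parameter; it is not code from this patch.

// Illustrative scalar reference of the momentum update (sketch only).
#include <cstdio>

int main() {
  float param = 1.0f, velocity = 0.5f, grad = 0.2f;
  const float mu = 0.9f, lr = 0.01f;
  const bool use_nesterov = true;

  velocity = mu * velocity + grad;          // velocity_out = mu * velocity + gradient
  if (use_nesterov) {
    param -= (grad + mu * velocity) * lr;   // param_out = param - (g + mu * v_out) * lr
  } else {
    param -= lr * velocity;                 // param_out = param - lr * v_out
  }
  std::printf("param=%f velocity=%f\n", param, velocity);
  return 0;
}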
...@@ -15,12 +15,13 @@ function(reader_library TARGET_NAME) ...@@ -15,12 +15,13 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE) PARENT_SCOPE)
endfunction() endfunction()
reader_library(open_files_op SRCS open_files_op.cc) cc_library(buffered_reader SRCS buffered_reader.cc DEPS reader simple_threadpool)
reader_library(open_files_op SRCS open_files_op.cc DEPS buffered_reader)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc) reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc DEPS buffered_reader)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc) reader_library(create_custom_reader_op SRCS create_custom_reader_op.cc)
reader_library(create_py_reader_op SRCS create_py_reader_op.cc) reader_library(create_py_reader_op SRCS create_py_reader_op.cc)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include <vector>
namespace paddle {
namespace operators {
namespace reader {
BufferedReader::~BufferedReader() { reader_->Shutdown(); }
BufferedReader::BufferedReader(
const std::shared_ptr<framework::ReaderBase> &reader,
const platform::Place &place, size_t buffer_size)
: framework::DecoratedReader(reader),
thread_pool_(1),
place_(place),
buffer_size_(buffer_size) {
cpu_buffer_.resize(buffer_size);
gpu_buffer_.resize(buffer_size);
ReadTillBufferFullAsync();
}
void BufferedReader::ReadTillBufferFullAsync() {
PADDLE_ENFORCE_EQ(position_.size(), 0U);
for (size_t i = 0; i < buffer_size_; ++i) {
ReadAsync(i);
}
}
void BufferedReader::ReadAsync(size_t i) {
position_.emplace(thread_pool_.enqueue([this, i]() -> size_t {
TensorVec &cpu = cpu_buffer_[i];
reader_->ReadNext(&cpu);
if (cpu.empty()) {
return -1UL;
}
if (platform::is_gpu_place(place_)) {
TensorVec &gpu = gpu_buffer_[i];
gpu.resize(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) {
framework::TensorCopySync(cpu[i], place_, &gpu[i]);
gpu[i].set_lod(cpu[i].lod());
}
}
return i;
}));
}
void BufferedReader::ShutdownImpl() {
reader_->Shutdown();
while (!position_.empty()) {
position_.pop();
}
prev_pos_ = -1UL;
}
void BufferedReader::StartImpl() {
reader_->Start();
ReadTillBufferFullAsync();
}
void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
if (position_.empty()) {
out->clear();
return;
}
size_t i = position_.front().get();
position_.pop();
if (i == -1UL) {
ReadNextImpl(out);
return;
}
*out = platform::is_gpu_place(place_) ? gpu_buffer_[i] : cpu_buffer_[i];
  // Do not push the current position into ReadAsync; push the previous
  // position instead. Since all computation in Fluid is asynchronous,
  // overwriting the data at the current position may corrupt a batch that is
  // still being consumed.
  if (prev_pos_ != -1UL) {
ReadAsync(prev_pos_);
}
prev_pos_ = i;
}
} // namespace reader
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <list>
#include <memory>
#include <queue>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/reader.h"
namespace paddle {
namespace operators {
namespace reader {
class BufferedReader : public framework::DecoratedReader {
using TensorVec = std::vector<framework::LoDTensor>;
using VecFuture = std::future<TensorVec>;
public:
BufferedReader(const std::shared_ptr<framework::ReaderBase>& reader,
const platform::Place& place, size_t buffer_size);
~BufferedReader() override;
private:
void ReadTillBufferFullAsync();
void ReadAsync(size_t i);
protected:
void ShutdownImpl() override;
void StartImpl() override;
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
private:
ThreadPool thread_pool_;
platform::Place place_;
const size_t buffer_size_;
std::queue<std::future<size_t>> position_;
// The buffer for reading data.
  // NOTE: the simplest way to implement a buffered reader is to use no buffer
  // at all and just keep `buffer_size` asynchronous reads in flight as
  // futures. However, allocating tensors on every read is extremely slow, so
  // we keep all data in pre-allocated buffers and avoid repeated allocation.
std::vector<TensorVec> cpu_buffer_;
std::vector<TensorVec> gpu_buffer_;
size_t prev_pos_{-1UL};
};
} // namespace reader
} // namespace operators
} // namespace paddle
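For context, the standalone sketch below (made-up names, not part of the patch) shows the same prefetching idea with std::async: the slots are pre-filled asynchronously, the consumer takes the oldest pending read, and only the slot consumed in the previous step is re-issued, which is what the prev_pos_ handling above protects.

// Standalone sketch of the BufferedReader idea (illustrative names only).
#include <cstddef>
#include <future>
#include <iostream>
#include <queue>
#include <vector>

int main() {
  const size_t buffer_size = 3;
  std::vector<std::vector<int>> buffer(buffer_size);
  std::queue<std::future<size_t>> position;  // futures in issue order
  int next_batch = 0;

  auto read_async = [&](size_t i) {
    int batch = next_batch++;  // decided on the issuing thread, so order is kept
    position.push(std::async(std::launch::async, [&buffer, i, batch] {
      buffer[i].assign(4, batch);  // simulate reading one batch into slot i
      return i;
    }));
  };
  for (size_t i = 0; i < buffer_size; ++i) {
    read_async(i);  // fill the whole buffer, like ReadTillBufferFullAsync()
  }

  size_t prev_pos = static_cast<size_t>(-1);
  for (int step = 0; step < 6; ++step) {  // like ReadNextImpl()
    size_t i = position.front().get();    // wait for the oldest pending read
    position.pop();
    std::cout << "got batch " << buffer[i][0] << " from slot " << i << "\n";
    if (prev_pos != static_cast<size_t>(-1)) {
      read_async(prev_pos);  // refill only the slot we finished with last time
    }
    prev_pos = i;
  }
  return 0;
}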
...@@ -12,83 +12,12 @@ ...@@ -12,83 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <thread> // NOLINT #include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
// 'Double buffer' means we shall maintain two batches of input data at the same
// time. So the kCacheSize shoul be at least 2.
static constexpr size_t kCacheSize = 3;
// There will be two bacthes out of the channel during training:
// 1. the one waiting to be sent to the channel
// 2. the one just be received from the channel, which is also being used by
// subsequent operators.
// So the channel size should be kChacheSize - 2
static constexpr size_t kChannelSize = 1; // kCacheSize - 2
class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(
const std::shared_ptr<ReaderBase>& reader,
platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
cpu_tensor_cache_.resize(kCacheSize);
gpu_tensor_cache_.resize(kCacheSize);
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
for (size_t i = 0; i < kCacheSize; ++i) {
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
}
}
#endif
StartPrefetcher();
}
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override;
~DoubleBufferReader() { EndPrefetcher(); }
private:
void ShutdownImpl() override {
EndPrefetcher();
reader_->Shutdown();
}
void StartImpl() override {
reader_->Start();
StartPrefetcher();
}
void StartPrefetcher() {
channel_ = new reader::BlockingQueue<size_t>(kChannelSize);
prefetcher_ = std::thread([this] { PrefetchThreadFunc(); });
}
void EndPrefetcher() {
channel_->Close();
if (prefetcher_.joinable()) {
prefetcher_.join();
}
delete channel_;
channel_ = nullptr;
}
void PrefetchThreadFunc();
std::thread prefetcher_;
reader::BlockingQueue<size_t>* channel_;
platform::Place place_;
std::vector<std::vector<framework::LoDTensor>> cpu_tensor_cache_;
std::vector<std::vector<framework::LoDTensor>> gpu_tensor_cache_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
};
class CreateDoubleBufferReaderOp : public framework::OperatorBase { class CreateDoubleBufferReaderOp : public framework::OperatorBase {
public: public:
using framework::OperatorBase::OperatorBase; using framework::OperatorBase::OperatorBase;
...@@ -118,8 +47,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { ...@@ -118,8 +47,8 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
place = platform::CUDAPlace(static_cast<int>(num)); place = platform::CUDAPlace(static_cast<int>(num));
} }
out->Reset(framework::MakeDecoratedReader<DoubleBufferReader>( out->Reset(framework::MakeDecoratedReader<BufferedReader>(underlying_reader,
underlying_reader, place)); place, 2));
} }
}; };
...@@ -146,51 +75,6 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { ...@@ -146,51 +75,6 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
} }
}; };
void DoubleBufferReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
size_t cached_tensor_id;
if (channel_->Receive(&cached_tensor_id)) {
if (platform::is_gpu_place(place_)) {
*out = gpu_tensor_cache_[cached_tensor_id];
} else {
// CPU place
*out = cpu_tensor_cache_[cached_tensor_id];
}
} else {
out->clear();
}
}
void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
size_t cached_tensor_id = 0;
while (true) {
auto& cpu_batch = cpu_tensor_cache_[cached_tensor_id];
reader_->ReadNext(&cpu_batch);
if (cpu_batch.empty()) {
// The underlying reader have no next data.
break;
}
if (platform::is_gpu_place(place_)) {
auto& gpu_batch = gpu_tensor_cache_[cached_tensor_id];
gpu_batch.resize(cpu_batch.size());
for (size_t i = 0; i < cpu_batch.size(); ++i) {
// TODO(fengjiayi): Use asynchronous TensorCopy instead
framework::TensorCopySync(cpu_batch[i], place_, &gpu_batch[i]);
gpu_batch[i].set_lod(cpu_batch[i].lod());
}
}
if (!channel_->Send(cached_tensor_id)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread will terminate.";
break;
}
++cached_tensor_id;
cached_tensor_id %= kCacheSize;
}
channel_->Close();
VLOG(5) << "Prefetch thread terminates.";
}
} // namespace reader } // namespace reader
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -33,6 +33,8 @@ class PyReader : public framework::FileReader { ...@@ -33,6 +33,8 @@ class PyReader : public framework::FileReader {
if (!success) out->clear(); if (!success) out->clear();
} }
~PyReader() { queue_->Close(); }
void Shutdown() override { queue_->Close(); } void Shutdown() override { queue_->Close(); }
void Start() override { queue_->ReOpen(); } void Start() override { queue_->ReOpen(); }
......
...@@ -33,11 +33,14 @@ class RecordIOFileReader : public framework::FileReader { ...@@ -33,11 +33,14 @@ class RecordIOFileReader : public framework::FileReader {
protected: protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override { void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
std::unique_ptr<std::lock_guard<std::mutex>> guard;
if (ThreadSafe) { if (ThreadSafe) {
std::lock_guard<std::mutex> guard(*mutex_); guard.reset(new std::lock_guard<std::mutex>(*mutex_));
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_); }
} else {
*out = framework::ReadFromRecordIO(&scanner_, dev_ctx_); bool ok = framework::ReadFromRecordIO(&scanner_, dev_ctx_, out);
if (!ok) {
out->clear();
} }
} }
......
...@@ -48,9 +48,9 @@ class ShuffleReader : public framework::DecoratedReader { ...@@ -48,9 +48,9 @@ class ShuffleReader : public framework::DecoratedReader {
private: private:
void ShutdownImpl() override { void ShutdownImpl() override {
reader_->Shutdown();
buffer_.clear(); buffer_.clear();
iteration_pos_ = 0; iteration_pos_ = 0;
reader_->Shutdown();
} }
void StartImpl() override { void StartImpl() override {
......
...@@ -12,150 +12,200 @@ ...@@ -12,150 +12,200 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cmath>
#include <stdexcept>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/operators/reader/buffered_reader.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace reader { namespace reader {
class MultiFileReader : public framework::ReaderBase { class IReaderContainer {
public:
virtual ~IReaderContainer() {}
virtual void AppendReader(
std::unique_ptr<framework::ReaderBase>&& readers) = 0;
virtual void Stop() = 0;
virtual void Start() = 0;
virtual void ReadNext(std::vector<framework::LoDTensor>* out) = 0;
};
class OrderedReaderContainer : public IReaderContainer {
public: public:
MultiFileReader(const std::vector<std::string>& file_names, size_t thread_num, void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
size_t buffer_size) pending_.emplace(std::move(reader));
: buffer_size_(buffer_size) { }
readers_.reserve(file_names.size());
for (const std::string& f_name : file_names) { void Stop() override {
readers_.emplace_back(CreateReaderByFileName(f_name)); while (!pending_.empty()) {
MoveFrontPendingToDone();
} }
prefetchers_.resize(thread_num);
StartNewScheduler();
} }
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override; void Start() override { std::swap(done_, pending_); }
~MultiFileReader() { EndScheduler(); } void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!pending_.empty()) {
pending_.front()->ReadNext(out);
if (out->empty()) {
MoveFrontPendingToDone();
ReadNext(out);
}
} else {
out->clear();
}
}
private: private:
void ShutdownImpl() override { EndScheduler(); } void MoveFrontPendingToDone() {
pending_.front()->Shutdown();
void StartImpl() override { StartNewScheduler(); } pending_.front()->Start();
done_.emplace(move(pending_.front()));
void StartNewScheduler(); pending_.pop();
void EndScheduler(); }
void ScheduleThreadFunc();
void PrefetchThreadFunc(size_t reader_idx, size_t thread_idx); std::queue<std::unique_ptr<framework::ReaderBase>> pending_;
std::queue<std::unique_ptr<framework::ReaderBase>> done_;
std::vector<std::unique_ptr<framework::ReaderBase>> readers_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
size_t buffer_size_;
reader::BlockingQueue<size_t>* waiting_reader_idx_;
reader::BlockingQueue<size_t>* available_thread_idx_;
reader::BlockingQueue<std::vector<framework::LoDTensor>>* buffer_;
}; };
void MultiFileReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) { class PreemptiveReaderContainer : public IReaderContainer {
if (!buffer_->Receive(out)) { using ReaderList = std::list<std::unique_ptr<framework::ReaderBase>>;
out->clear();
struct FutureItem {
std::vector<framework::LoDTensor> data_;
ReaderList::iterator reader_it_;
std::exception_ptr exception_;
};
using FutureList = std::list<std::future<FutureItem>>;
public:
explicit PreemptiveReaderContainer(size_t thread_num) : pool_(thread_num) {}
void Stop() override {
if (!pending_.empty()) {
for (auto& reader : pending_) {
reader->Shutdown();
} }
} for (auto& fu : futures_) {
fu.wait();
void MultiFileReader::StartNewScheduler() { }
size_t thread_num = prefetchers_.size(); futures_.clear();
waiting_reader_idx_ = new reader::BlockingQueue<size_t>(readers_.size()); for (auto& reader : pending_) {
available_thread_idx_ = new reader::BlockingQueue<size_t>(thread_num); reader->Start();
buffer_ = new reader::BlockingQueue<std::vector<framework::LoDTensor>>( done_.emplace_back(std::move(reader));
buffer_size_);
for (size_t i = 0; i < readers_.size(); ++i) {
waiting_reader_idx_->Send(i);
}
waiting_reader_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(i);
}
scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}
void MultiFileReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_reader_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_reader_idx_;
}
void MultiFileReader::ScheduleThreadFunc() {
VLOG(5) << "MultiFileReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t reader_idx;
if (waiting_reader_idx_->Receive(&reader_idx)) {
// Still have files to read. Start a new prefetch thread.
prefetcher = std::thread([this, reader_idx, thread_idx] {
PrefetchThreadFunc(reader_idx, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
} }
pending_.clear();
bool timeout;
complete_queue_.PopAll(1000, &timeout);
PADDLE_ENFORCE(!timeout);
} }
} }
// If users invoke Shutdown() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler void Start() override {
// to release their resource. So a check is needed before scheduler ends. for (auto& reader : done_) {
for (auto& p : prefetchers_) { AppendReader(std::move(reader));
if (p.joinable()) {
p.join();
} }
done_.clear();
} }
VLOG(5) << "MultiFileReader schedule thread terminates.";
}
void MultiFileReader::PrefetchThreadFunc(size_t reader_idx, size_t thread_idx) { void ReadNext(std::vector<framework::LoDTensor>* out) override {
VLOG(5) << "The prefetch thread of file idx '" << reader_idx << "' starts."; if (!pending_.empty()) {
std::unique_ptr<framework::ReaderBase>& reader = readers_[reader_idx]; auto future_it = complete_queue_.Pop();
while (true) { FutureItem item = future_it->get();
std::vector<framework::LoDTensor> ins; if (item.exception_) {
reader->ReadNext(&ins); for (auto it = futures_.begin(); it != futures_.end(); ++it) {
if (ins.empty()) { if (it != future_it) {
reader->Shutdown(); it->wait(); // Wait all other threads complete.
reader->Start(); }
break; }
std::rethrow_exception(item.exception_);
} else if (item.data_.empty()) { // reader done.
done_.emplace_back(std::move(*item.reader_it_));
pending_.erase(item.reader_it_);
futures_.erase(future_it);
ReadNext(out);
} else {
*out = item.data_;
// continue read async
ReadAsync(item.reader_it_, &future_it);
} }
} else {
out->clear();
}
}
private:
void AppendReader(std::unique_ptr<framework::ReaderBase>&& reader) override {
pending_.emplace_back(std::move(reader));
auto reader_it = pending_.end();
--reader_it;
futures_.emplace_back();
auto future_it = futures_.end();
--future_it;
ReadAsync(reader_it, &future_it);
}
void ReadAsync(const ReaderList::iterator& reader_it,
FutureList::iterator* future_it_ptr) {
auto& future_it = *future_it_ptr;
*future_it = pool_.enqueue([reader_it, future_it, this] {
try { try {
buffer_->Send(std::move(ins)); FutureItem item;
} catch (paddle::platform::EnforceNotMet e) { item.reader_it_ = reader_it;
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch " (*reader_it)->ReadNext(&item.data_);
"thread of file idx '" if (item.data_.empty()) {
<< reader_idx << "' will terminate."; (*reader_it)->Shutdown();
break; (*reader_it)->Start();
} }
complete_queue_.Push(future_it);
return item;
} catch (...) {
FutureItem item;
item.exception_ = std::current_exception();
complete_queue_.Push(future_it);
return item;
}
});
} }
if (!available_thread_idx_->Send(thread_idx)) { FutureList futures_;
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. " ThreadPool pool_;
"Fail to send thread_idx."; framework::BlockingQueue<FutureList::iterator> complete_queue_;
std::list<std::unique_ptr<framework::ReaderBase>> pending_;
std::list<std::unique_ptr<framework::ReaderBase>> done_;
};
class MultiFileReader : public framework::ReaderBase {
public:
MultiFileReader(const std::vector<std::string>& file_names,
std::unique_ptr<IReaderContainer>&& container)
: container_(std::move(container)) {
for (auto& fn : file_names) {
container_->AppendReader(CreateReaderByFileName(fn));
} }
VLOG(5) << "The prefetch thread of file idx '" << reader_idx }
<< "' terminates.";
} ~MultiFileReader() { container_->Stop(); }
protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
container_->ReadNext(out);
}
void ShutdownImpl() override { container_->Stop(); }
void StartImpl() override { container_->Start(); }
private:
std::unique_ptr<IReaderContainer> container_;
};
class OpenFilesOp : public framework::OperatorBase { class OpenFilesOp : public framework::OperatorBase {
public: public:
...@@ -173,13 +223,27 @@ class OpenFilesOp : public framework::OperatorBase { ...@@ -173,13 +223,27 @@ class OpenFilesOp : public framework::OperatorBase {
"shape concat's length."); "shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names"); const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!"); PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num"); bool is_test = Attr<bool>("is_test");
const size_t buffer_size = Attr<int>("buffer_size");
auto* out = scope.FindVar(Output("Out")) auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>(); ->template GetMutable<framework::ReaderHolder>();
out->Reset( std::unique_ptr<IReaderContainer> container;
std::make_shared<MultiFileReader>(file_names, thread_num, buffer_size));
if (is_test) {
container.reset(new OrderedReaderContainer());
} else {
container.reset(new PreemptiveReaderContainer(
static_cast<size_t>(Attr<int>("thread_num"))));
}
std::shared_ptr<framework::ReaderBase> reader(
new MultiFileReader(file_names, std::move(container)));
auto buffer_size = Attr<int>("buffer_size");
if (buffer_size > 1) {
reader = framework::MakeDecoratedReader<BufferedReader>(
reader, platform::CPUPlace(), buffer_size);
}
out->Reset(reader);
} }
}; };
...@@ -187,9 +251,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -187,9 +251,7 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
protected: protected:
void Apply() override { void Apply() override {
AddAttr<std::vector<std::string>>("file_names", "Files to be read."); AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.") AddAttr<bool>("is_test", "Used for testing data.").SetDefault(false);
.GreaterThan(0);
AddAttr<int>("buffer_size", "The size of prefetch buffer.").GreaterThan(0);
AddComment(R"DOC( AddComment(R"DOC(
OpenFiles Operator OpenFiles Operator
...@@ -197,6 +259,11 @@ class OpenFilesOpMaker : public FileReaderMakerBase { ...@@ -197,6 +259,11 @@ class OpenFilesOpMaker : public FileReaderMakerBase {
An OpenFilesOp creates a MultiFileReader, which is able to An OpenFilesOp creates a MultiFileReader, which is able to
read data multi-threaded from multiple files. read data multi-threaded from multiple files.
)DOC"); )DOC");
AddAttr<int>("thread_num",
"The maximal concurrent prefetch thread number. Used only "
"when is_test = False");
    AddAttr<int>("buffer_size", "The size of the prefetch buffer for these files.")
.GreaterThan(0);
} }
}; };
......
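The PreemptiveReaderContainer above keeps one outstanding asynchronous read per file and consumes whichever file delivers first through a completion queue. The standalone sketch below (illustrative names only, not part of the patch) shows the same pattern with std::async and a small blocking queue instead of the ThreadPool and framework::BlockingQueue used in the real code.

// Standalone sketch of the preemptive multi-file reading pattern.
#include <condition_variable>
#include <cstddef>
#include <future>
#include <iostream>
#include <mutex>
#include <queue>
#include <vector>

struct CompletionQueue {  // minimal blocking queue of file indices
  std::mutex mu;
  std::condition_variable cv;
  std::queue<size_t> q;
  void Push(size_t v) {
    {
      std::lock_guard<std::mutex> l(mu);
      q.push(v);
    }
    cv.notify_one();
  }
  size_t Pop() {
    std::unique_lock<std::mutex> l(mu);
    cv.wait(l, [this] { return !q.empty(); });
    size_t v = q.front();
    q.pop();
    return v;
  }
};

int main() {
  std::vector<std::vector<int>> files = {{1, 2, 3}, {10, 20}, {100}};
  std::vector<size_t> cursor(files.size(), 0);
  std::vector<std::future<bool>> pending(files.size());  // has-data flag per file
  CompletionQueue done_reads;

  auto read_async = [&](size_t f) {
    pending[f] = std::async(std::launch::async, [&, f] {
      bool has_data = cursor[f] < files[f].size();  // the asynchronous "read"
      done_reads.Push(f);
      return has_data;
    });
  };
  for (size_t f = 0; f < files.size(); ++f) read_async(f);

  size_t finished = 0;
  while (finished < files.size()) {
    size_t f = done_reads.Pop();        // whichever file finished its read first
    bool has_data = pending[f].get();
    if (!has_data) {                    // this file is exhausted; park it
      ++finished;
      continue;
    }
    std::cout << "file " << f << " -> " << files[f][cursor[f]++] << "\n";
    read_async(f);                      // immediately keep this file busy again
  }
  return 0;
}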
...@@ -24,6 +24,9 @@ ...@@ -24,6 +24,9 @@
#include "paddle/fluid/operators/tensorrt_engine_op.h" #include "paddle/fluid/operators/tensorrt_engine_op.h"
namespace paddle { namespace paddle {
DEFINE_int32(tensorrt_engine_batch_size, 1, "the batch_size of TensorRT");
namespace operators { namespace operators {
using inference::Singleton; using inference::Singleton;
...@@ -52,7 +55,6 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) { ...@@ -52,7 +55,6 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
"TensorRT' tensor input requires at least 2 dimensions"); "TensorRT' tensor input requires at least 2 dimensions");
PADDLE_ENFORCE_LE(shape.size(), 4UL, PADDLE_ENFORCE_LE(shape.size(), 4UL,
"TensorRT' tensor input requires at most 4 dimensions"); "TensorRT' tensor input requires at most 4 dimensions");
switch (shape.size()) { switch (shape.size()) {
case 2: case 2:
return nvinfer1::Dims2(shape[0], shape[1]); return nvinfer1::Dims2(shape[0], shape[1]);
...@@ -90,27 +92,36 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare( ...@@ -90,27 +92,36 @@ void TensorRTEngineKernel<DeviceContext, T>::Prepare(
engine->InitNetwork(); engine->InitNetwork();
framework::BlockDesc block(nullptr /*programdesc*/, &block_desc); framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
VLOG(4) << "parsed var size " << block.AllVars().size();
// Add inputs // Add inputs
VLOG(4) << "declare inputs"; VLOG(4) << "declare inputs";
for (auto &input : context.Inputs("Xs")) { for (auto &input : context.Inputs("Xs")) {
VLOG(4) << "declare input " << input; VLOG(4) << "declare input " << input;
auto *var = block.FindVar(input); auto *var = block.FindVar(input);
// TensorRT engine need to create parameters. The parameter's description
// should be set in
PADDLE_ENFORCE(var, "no variable called %s", input);
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input"); "TensorRT engine only takes LoDTensor as input");
auto shape = var->GetShape(); auto shape = var->GetShape();
// For the special batch_size placeholder -1, drop it and pass the real
// shape of data.
// TODO(Superjomn) fix this with batch broadcast, or it can't handle
// variational batch size.
if (shape[0] == -1) {
shape[0] = FLAGS_tensorrt_engine_batch_size;
}
engine->DeclareInput( engine->DeclareInput(
input, FluidDataType2TRT( input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()), var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var->GetShape())); Vec2TRT_Dims(shape));
} }
inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock( inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
block_desc, parameters, context.scope(), engine); block_desc, parameters, context.scope(), engine);
// Add outputs // Add outputs
VLOG(4) << "declare outputs";
for (auto &output : context.Outputs("Ys")) { for (auto &output : context.Outputs("Ys")) {
VLOG(4) << "declare output " << output;
engine->DeclareOutput(output); engine->DeclareOutput(output);
} }
...@@ -151,4 +162,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -151,4 +162,7 @@ REGISTER_OP_CPU_KERNEL(
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int>, ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int>,
ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::TensorRTEngineKernel<paddle::platform::CPUDeviceContext, int64_t>);
// A trick to compile with the needed TensorRT op converter.
USE_TRT_CONVERTER(mul)
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
...@@ -24,6 +24,9 @@ ...@@ -24,6 +24,9 @@
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle { namespace paddle {
DECLARE_int32(tensorrt_engine_batch_size);
namespace operators { namespace operators {
using inference::Singleton; using inference::Singleton;
...@@ -53,7 +56,6 @@ template <typename DeviceContext, typename T> ...@@ -53,7 +56,6 @@ template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> { class TensorRTEngineKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
VLOG(4) << "TensorRTEngineKernel executing";
auto engine_name = context.Attr<std::string>("engine_uniq_key"); auto engine_name = context.Attr<std::string>("engine_uniq_key");
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) { if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
Prepare(context); Prepare(context);
...@@ -61,11 +63,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -61,11 +63,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name); auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
auto input_names = context.op().Inputs("Xs"); auto input_names = context.op().Inputs("Xs");
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs"); PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
// Try to determine a batch_size PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
auto& tensor0 = inference::analysis::GetFromScope<framework::LoDTensor>( context.Attr<int>("max_batch"));
context.scope(), input_names.front());
int batch_size = tensor0.dims()[0];
PADDLE_ENFORCE_LE(batch_size, context.Attr<int>("max_batch"));
// Convert input tensor from fluid to engine. // Convert input tensor from fluid to engine.
for (const auto& x : context.Inputs("Xs")) { for (const auto& x : context.Inputs("Xs")) {
...@@ -81,8 +80,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -81,8 +80,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
} }
} }
// Execute the engine. // Execute the engine.
PADDLE_ENFORCE_GT(batch_size, 0); PADDLE_ENFORCE_GT(FLAGS_tensorrt_engine_batch_size, 0);
engine->Execute(batch_size); engine->Execute(FLAGS_tensorrt_engine_batch_size);
// Convert output tensor from engine to fluid // Convert output tensor from engine to fluid
for (const auto& y : context.Outputs("Ys")) { for (const auto& y : context.Outputs("Ys")) {
// convert output and copy to fluid. // convert output and copy to fluid.
...@@ -94,18 +93,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> { ...@@ -94,18 +93,21 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto* fluid_v = context.scope().FindVar(y); auto* fluid_v = context.scope().FindVar(y);
PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y); PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>(); auto* fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
fluid_t->Resize(framework::make_ddim(ddim));
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims); auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
if (platform::is_cpu_place(fluid_t->place())) { fluid_t->Resize(framework::make_ddim(ddim));
// TODO(Superjomn) find some way to determine which device to output the
// tensor.
// if (platform::is_cpu_place(fluid_t->place())) {
// TODO(Superjomn) change this float to dtype size. // TODO(Superjomn) change this float to dtype size.
engine->GetOutputInCPU( engine->GetOutputInCPU(y,
y, fluid_t->mutable_data<float>(platform::CPUPlace()), fluid_t->mutable_data<float>(platform::CPUPlace()),
size * sizeof(float));
} else {
engine->GetOutputInGPU(
y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
size * sizeof(float)); size * sizeof(float));
} //} else {
// engine->GetOutputInGPU(
// y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
// size * sizeof(float));
//}
} }
cudaStreamSynchronize(*engine->stream()); cudaStreamSynchronize(*engine->stream());
......
...@@ -28,6 +28,7 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream) ...@@ -28,6 +28,7 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
Scanner::Scanner(const std::string &filename) Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) { : stream_(new std::ifstream(filename)), parser_(*stream_) {
PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename);
Reset(); Reset();
} }
......
...@@ -333,8 +333,7 @@ function assert_api_not_changed() { ...@@ -333,8 +333,7 @@ function assert_api_not_changed() {
python ${PADDLE_ROOT}/tools/diff_api.py ${PADDLE_ROOT}/paddle/fluid/API.spec new.spec python ${PADDLE_ROOT}/tools/diff_api.py ${PADDLE_ROOT}/paddle/fluid/API.spec new.spec
deactivate deactivate
# Use git diff --name-only HEAD^ may not get file changes for update commits in one PR API_CHANGE=`git diff --name-only upstream/develop | grep "paddle/fluid/API.spec" || true`
API_CHANGE=`echo $CHANGED_FILES | grep "paddle/fluid/API.spec" || true`
echo "checking API.spec change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}" echo "checking API.spec change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}"
if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then
# TODO: curl -H 'Authorization: token ${TOKEN}' # TODO: curl -H 'Authorization: token ${TOKEN}'
...@@ -600,11 +599,11 @@ function main() { ...@@ -600,11 +599,11 @@ function main() {
cicheck) cicheck)
cmake_gen ${PYTHON_ABI:-""} cmake_gen ${PYTHON_ABI:-""}
build build
assert_api_not_changed
run_test run_test
gen_capi_package gen_capi_package
gen_fluid_inference_lib gen_fluid_inference_lib
test_fluid_inference_lib test_fluid_inference_lib
assert_api_not_changed
;; ;;
*) *)
print_usage print_usage
......
...@@ -25,9 +25,6 @@ import numpy ...@@ -25,9 +25,6 @@ import numpy
__all__ = [ __all__ = [
'split_lod_tensor', 'split_lod_tensor',
'merge_lod_tensor', 'merge_lod_tensor',
'BlockGuard',
'BlockGuardWithCompletion',
'WhileGuard',
'While', 'While',
'Switch', 'Switch',
'lod_rank_table', 'lod_rank_table',
......
...@@ -12,14 +12,18 @@ ...@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import multiprocessing
import threading
from .. import core from ..data_feeder import DataFeeder
from ..framework import convert_np_dtype_to_dtype_, default_main_program, default_startup_program, Program
from ..unique_name import generate as unique_name
from control_flow import BlockGuard from control_flow import BlockGuard
from ..layer_helper import LayerHelper from layer_function_generator import templatedoc
from .. import core
from ..executor import global_scope from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
default_startup_program, program_guard, Program
from ..layer_helper import LayerHelper
from ..unique_name import generate as unique_name
__all__ = [ __all__ = [
'data', 'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch', 'data', 'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
...@@ -445,7 +449,12 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True): ...@@ -445,7 +449,12 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
return monkey_patch_reader_methods(main_prog_var) return monkey_patch_reader_methods(main_prog_var)
def py_reader(capacity, shapes, dtypes, lod_levels=None): def py_reader(capacity,
shapes,
dtypes,
lod_levels=None,
name=None,
use_double_buffer=True):
""" """
Create a reader and blocking queue for data feeding in Python Create a reader and blocking queue for data feeding in Python
...@@ -458,10 +467,13 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None): ...@@ -458,10 +467,13 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None):
using `close()` method when unused. using `close()` method when unused.
Args: Args:
use_double_buffer(bool): Whether use double buffer or not.
capacity(int): The maximum capacity of the BlockingQueue. capacity(int): The maximum capacity of the BlockingQueue.
shapes(list): List of tuples which declaring data shapes. shapes(list|tuple): List of tuples which declaring data shapes.
dtypes(list): List of strs which declaring data type. dtypes(list|tuple): List of strs which declaring data type.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list|tuple): List of ints which declaring data lod_level.
name(basestring): The prefix Python queue name and Reader name. None will
be generated automatically.
Returns: Returns:
tuple(Variable, BlockingQueue): tuple(Variable, BlockingQueue):
...@@ -502,15 +514,23 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None): ...@@ -502,15 +514,23 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None):
if lod_levels is None: if lod_levels is None:
lod_levels = [0] * len(shapes) lod_levels = [0] * len(shapes)
if name is None:
queue_name = unique_name('lod_tensor_blocking_queue') queue_name = unique_name('lod_tensor_blocking_queue')
reader_name = unique_name('create_py_reader')
double_buffer_name = unique_name('double_buffer')
else:
queue_name = "_".join([name, "queue"])
reader_name = "_".join([name, "reader"])
double_buffer_name = "_".join([name, "double_buffer"])
var = global_scope().var(queue_name) var = global_scope().var(queue_name)
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes) feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=unique_name('create_py_reader')) startup_var = startup_blk.create_var(name=reader_name)
startup_blk.append_op( startup_blk.append_op(
type='create_py_reader', type='create_py_reader',
inputs={'blocking_queue': queue_name}, inputs={'blocking_queue': [queue_name]},
outputs={'Out': [startup_var]}, outputs={'Out': [startup_var]},
attrs={ attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
...@@ -524,17 +544,96 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None): ...@@ -524,17 +544,96 @@ def py_reader(capacity, shapes, dtypes, lod_levels=None):
main_prog_var = _copy_reader_var_(default_main_program().current_block(), main_prog_var = _copy_reader_var_(default_main_program().current_block(),
startup_var) startup_var)
return monkey_patch_reader_methods(main_prog_var), feed_queue reader = monkey_patch_reader_methods(main_prog_var)
if use_double_buffer:
double_buffer_reader = double_buffer(reader, name=double_buffer_name)
# we return a double buffer reader. However, the reset method comes from
# py_reader.
double_buffer_reader.reset = reader.reset
reader = double_buffer_reader
# monkey patch py_reader special methods
reader.queue = feed_queue
current_reset_method = reader.reset
reader.thread = None
reader.tensor_provider = None
reader.exited = False
def start_provide_thread(func):
def __provider_thread__():
for tensors in func():
array = core.LoDTensorArray()
for item in tensors:
if not isinstance(item, core.LoDTensor):
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
array.append(item)
if reader.exited:
break
feed_queue.push(array)
if reader.exited:
break
feed_queue.close()
reader.thread = threading.Thread(target=__provider_thread__)
reader.thread.start()
def __set_tensor_provider__(func):
reader.tensor_provider = func
def __set_paddle_reader__(paddle_reader):
with program_guard(Program(), Program()):
feed_list = []
counter = 0
for dtype, shape, lod_level in zip(dtypes, shapes, lod_levels):
name = str(counter)
feed_list.append(
data(
name=name,
dtype=dtype,
shape=shape,
lod_level=lod_level))
counter += 1
feeder = DataFeeder(feed_list=feed_list, place=core.CPUPlace())
paddle_reader = feeder.decorate_reader(
paddle_reader, multi_devices=False)
def __tensor_provider__():
for slots in paddle_reader():
yield [slots[str(idx)] for idx in xrange(counter)]
__set_tensor_provider__(__tensor_provider__)
def __reset__():
current_reset_method()
if reader.thread is not None and reader.tensor_provider is not None:
reader.exited = True
reader.thread.join()
reader.exited = False
def __start__():
start_provide_thread(reader.tensor_provider)
reader.reset = __reset__
reader.decorate_tensor_provider = __set_tensor_provider__
reader.decorate_paddle_reader = __set_paddle_reader__
reader.start = __start__
return reader
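As a minimal usage sketch of the reworked interface (illustrative only, not part of this patch): `py_reader` now returns a single reader whose feed queue is reachable as `reader.queue`, and which can be fed through `decorate_paddle_reader`; `start()` launches the provider thread and `reset()` rewinds after an `EOFException`. The dataset and sizes below are placeholders.

import paddle
import paddle.fluid as fluid

# Build a named py_reader and a tiny network on top of it.
reader = fluid.layers.py_reader(
    capacity=64,
    shapes=((-1, 784), (-1, 1)),
    dtypes=('float32', 'int64'),
    name='demo_reader')
img, label = fluid.layers.read_file(reader)
prediction = fluid.layers.fc(input=img, size=10, act='softmax')
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=label))

# Feed the reader from an ordinary batched paddle reader.
reader.decorate_paddle_reader(
    paddle.batch(paddle.dataset.mnist.train(), batch_size=32))

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
reader.start()                      # spawn the provider thread
try:
    while True:
        exe.run(fetch_list=[loss.name])
except fluid.core.EOFException:
    reader.reset()                  # rewind for the next epoch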
def open_files(filenames, def open_files(filenames,
shapes, shapes,
lod_levels, lod_levels,
dtypes, dtypes,
thread_num=1, thread_num=None,
buffer_size=None, buffer_size=None,
pass_num=1, pass_num=1,
for_parallel=True): is_test=None):
""" """
Open files Open files
...@@ -547,14 +646,14 @@ def open_files(filenames, ...@@ -547,14 +646,14 @@ def open_files(filenames,
shapes(list): List of tuples which declaring data shapes. shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level. lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type. dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number. thread_num(None): The number of thread to read files.
buffer_size(int|None): The size of prefetch buffer. If it is setted None, Default: min(len(filenames), cpu_number).
buffer size will be thread_num * 3. buffer_size(None): The buffer size of reader. Default: 3 * thread_num
Default: None
pass_num(int): Number of passes to run. pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run is_test(bool|None): Whether `open_files` used for testing or not. If it
subsequent operators in parallel. is used for testing, the order of data generated is same as the file
Default: True order. Otherwise, it is not guaranteed the order of data is same
between every epoch. [Default: False].
Returns: Returns:
Variable: A Reader Variable via which we can get file data. Variable: A Reader Variable via which we can get file data.
...@@ -566,15 +665,21 @@ def open_files(filenames, ...@@ -566,15 +665,21 @@ def open_files(filenames,
'./data2.recordio'], './data2.recordio'],
shapes=[(3,224,224), (1)], shapes=[(3,224,224), (1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'])
thread_num=2,
buffer_size=2)
# Via the reader, we can use 'read_file' layer to get data: # Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader) image, label = fluid.layers.io.read_file(reader)
""" """
if thread_num is None:
thread_num = min(len(filenames), multiprocessing.cpu_count())
else:
thread_num = int(thread_num)
if buffer_size is None: if buffer_size is None:
buffer_size = thread_num * 3 buffer_size = 3 * thread_num
else:
buffer_size = int(buffer_size)
if isinstance(filenames, basestring): if isinstance(filenames, basestring):
filenames = [filenames] filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes] dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
...@@ -588,17 +693,18 @@ def open_files(filenames, ...@@ -588,17 +693,18 @@ def open_files(filenames,
multi_file_reader_name = unique_name('multi_file_reader') multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block() startup_blk = default_startup_program().current_block()
startup_reader = startup_blk.create_var(name=multi_file_reader_name) startup_reader = startup_blk.create_var(name=multi_file_reader_name)
startup_blk.append_op( attrs = {
type='open_files',
outputs={'Out': [startup_reader]},
attrs={
'shape_concat': shape_concat, 'shape_concat': shape_concat,
'lod_levels': lod_levels, 'lod_levels': lod_levels,
'ranks': ranks, 'ranks': ranks,
'file_names': filenames, 'file_names': filenames,
'thread_num': thread_num, 'thread_num': thread_num,
'buffer_size': buffer_size 'buffer_size': buffer_size
}) }
if is_test is not None:
attrs['is_test'] = is_test
startup_blk.append_op(
type='open_files', outputs={'Out': [startup_reader]}, attrs=attrs)
startup_reader.desc.set_dtypes(dtypes) startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True startup_reader.persistable = True
......
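A rough illustration of the new `open_files` signature (sketch only, paths are placeholders): `thread_num` and `buffer_size` may now be omitted and default to `min(len(filenames), cpu_count())` and `3 * thread_num`, while `is_test=True` selects the ordered container so data comes back in file order for evaluation.

import paddle.fluid as fluid

# Ordered reading for testing: is_test=True keeps the file order.
test_reader = fluid.layers.open_files(
    filenames=['./data1.recordio', './data2.recordio'],
    shapes=[[-1, 3, 224, 224], [-1, 1]],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'],
    is_test=True)

# Training: thread_num and buffer_size fall back to their new defaults.
train_reader = fluid.layers.open_files(
    filenames=['./data1.recordio', './data2.recordio'],
    shapes=[[-1, 3, 224, 224], [-1, 1]],
    lod_levels=[0, 0],
    dtypes=['float32', 'int64'])

image, label = fluid.layers.read_file(train_reader)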
...@@ -114,23 +114,13 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -114,23 +114,13 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
prediction = network(image, is_infer=True) prediction = network(image, is_infer=True)
auc_out=fluid.layers.auc(input=prediction, label=label) auc_out=fluid.layers.auc(input=prediction, label=label)
""" """
warnings.warn(
"This interface is not recommended, fluid.layers.auc compute the auc at every minibatch, \
but can not aggregate them and get the pass AUC, because pass \
auc can not be averaged with weighted from the minibatch auc value. \
Please use fluid.metrics.Auc, it can compute the auc value via Python natively, \
which can get every minibatch and every pass auc value.", Warning)
helper = LayerHelper("auc", **locals()) helper = LayerHelper("auc", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype) auc_out = helper.create_tmp_variable(dtype="float64")
topk_indices = helper.create_tmp_variable(dtype="int64")
topk_out, topk_indices = nn.topk(input, k=k)
auc_out = helper.create_tmp_variable(dtype="float32")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
tp = helper.create_global_variable(persistable=True) tp = helper.create_global_variable(persistable=True, dtype='int64')
tn = helper.create_global_variable(persistable=True) tn = helper.create_global_variable(persistable=True, dtype='int64')
fp = helper.create_global_variable(persistable=True) fp = helper.create_global_variable(persistable=True, dtype='int64')
fn = helper.create_global_variable(persistable=True) fn = helper.create_global_variable(persistable=True, dtype='int64')
for var in [tp, tn, fp, fn]: for var in [tp, tn, fp, fn]:
helper.set_variable_initializer( helper.set_variable_initializer(
var, Constant( var, Constant(
...@@ -139,8 +129,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -139,8 +129,7 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
helper.append_op( helper.append_op(
type="auc", type="auc",
inputs={ inputs={
"Out": [topk_out], "Predict": [input],
"Indices": [topk_indices],
"Label": [label], "Label": [label],
"TP": [tp], "TP": [tp],
"TN": [tn], "TN": [tn],
...@@ -156,4 +145,4 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -156,4 +145,4 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
"FPOut": [fp], "FPOut": [fp],
"FNOut": [fn] "FNOut": [fn]
}) })
return auc_out return auc_out, [tp, tn, fp, fn]
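A brief usage sketch of the changed interface (illustrative only): `fluid.layers.auc` now consumes the prediction tensor directly instead of computing top-k internally, and returns both the AUC value and the persistable batch statistics, so callers that previously unpacked a single value need updating.

import paddle.fluid as fluid

image = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=image, size=2, act='softmax')

# Returns (auc_value, [tp, tn, fp, fn]); the counters are persistable,
# so they accumulate across mini-batches.
auc_out, batch_states = fluid.layers.auc(
    input=prediction, label=label, curve='ROC', num_thresholds=200)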
...@@ -166,7 +166,8 @@ def fc(input, ...@@ -166,7 +166,8 @@ def fc(input,
param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
parameters/weights of this layer. parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units. of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str, default None): Activation to be applied to the output of this layer. act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase. is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
......
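A short illustration of the clarified `bias_attr` semantics for `fc` (sketch only):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[32], dtype='float32')

# bias_attr=False: no bias parameter is created at all.
y_no_bias = fluid.layers.fc(input=x, size=64, bias_attr=False)

# bias_attr=None (the default): a bias is created and initialized to zero.
y_zero_bias = fluid.layers.fc(input=x, size=64)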
...@@ -591,7 +591,7 @@ class Auc(MetricBase): ...@@ -591,7 +591,7 @@ class Auc(MetricBase):
for i in range(self._num_thresholds - 2)] for i in range(self._num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# caculate TP, FN, TN, FP count # calculate TP, FN, TN, FP count
for idx_thresh, thresh in enumerate(thresholds): for idx_thresh, thresh in enumerate(thresholds):
tp, fn, tn, fp = 0, 0, 0, 0 tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels): for i, lbl in enumerate(labels):
......
...@@ -324,7 +324,7 @@ class MomentumOptimizer(Optimizer): ...@@ -324,7 +324,7 @@ class MomentumOptimizer(Optimizer):
& if (use\_nesterov): & if (use\_nesterov):
&\quad param = param - gradient * learning\_rate + mu * velocity * learning\_rate &\quad param = param - (gradient + mu * velocity) * learning\_rate
& else: & else:
......
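To check that the corrected Nesterov formula in the docstring and the updated expression in test_momentum_op.py agree, here is a small numerical sketch (not part of the diff; it assumes `velocity` in the docstring denotes the updated velocity):

import numpy as np

np.random.seed(0)
param = np.random.random((4, )).astype('float32')
grad = np.random.random((4, )).astype('float32')
velocity = np.zeros((4, )).astype('float32')
mu, lr = 0.9, 0.01

velocity_out = mu * velocity + grad
# Docstring form: param = param - (gradient + mu * velocity) * learning_rate
doc_form = param - (grad + mu * velocity_out) * lr
# Test form:      param = param - grad * lr - velocity_out * mu * lr
test_form = param - grad * lr - velocity_out * mu * lr

assert np.allclose(doc_form, test_form)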
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import paddle
import paddle.dataset.mnist as mnist
import paddle.fluid as fluid
import paddle.v2
def network(is_train):
reader = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader")
img, label = fluid.layers.read_file(reader)
hidden = img
for i in xrange(2):
hidden = fluid.layers.fc(input=hidden, size=100, act='tanh')
hidden = fluid.layers.dropout(
hidden, dropout_prob=0.5, is_test=not is_train)
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
return fluid.layers.mean(loss), reader
def main():
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
loss, train_reader = network(True)
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(loss)
test_prog = fluid.Program()
test_startup = fluid.Program()
with fluid.program_guard(test_prog, test_startup):
with fluid.unique_name.guard():
test_loss, test_reader = network(False)
fluid.Executor(fluid.CUDAPlace(0)).run(startup_prog)
fluid.Executor(fluid.CUDAPlace(0)).run(test_startup)
trainer = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=train_prog)
tester = fluid.ParallelExecutor(
use_cuda=True, share_vars_from=trainer, main_program=test_prog)
train_reader.decorate_paddle_reader(
paddle.v2.reader.shuffle(
paddle.batch(mnist.train(), 512), buf_size=8192))
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
for epoch_id in xrange(10):
train_reader.start()
try:
while True:
print 'train_loss', numpy.array(
trainer.run(fetch_list=[loss.name]))
except fluid.core.EOFException:
print 'End of epoch', epoch_id
train_reader.reset()
test_reader.start()
try:
while True:
print 'test loss', numpy.array(
tester.run(fetch_list=[test_loss.name]))
except fluid.core.EOFException:
print 'End of testing'
test_reader.reset()
if __name__ == '__main__':
main()
...@@ -31,7 +31,10 @@ def load_vocab(filename): ...@@ -31,7 +31,10 @@ def load_vocab(filename):
# load word dict with paddle inner function # load word dict with paddle inner function
word_dict = load_vocab(sys.argv[1]) if len(sys.argv) == 1:
word_dict = paddle.dataset.imdb.word_dict()
else:
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict) word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict) print "Dict dim = ", len(word_dict)
......
...@@ -41,16 +41,14 @@ def network_cfg(is_train, pass_num=100): ...@@ -41,16 +41,14 @@ def network_cfg(is_train, pass_num=100):
pass_num=pass_num, pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]], shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0], lod_levels=[1, 0],
dtypes=['int64', 'int64'], dtypes=['int64', 'int64'])
thread_num=1)
test_file_obj = fluid.layers.open_files( test_file_obj = fluid.layers.open_files(
filenames=TEST_FILES, filenames=TEST_FILES,
pass_num=1, pass_num=1,
shapes=[[-1, 1], [-1, 1]], shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0], lod_levels=[1, 0],
dtypes=['int64', 'int64'], dtypes=['int64', 'int64'])
thread_num=1)
if is_train: if is_train:
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000) file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)
......
...@@ -48,6 +48,7 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op) ...@@ -48,6 +48,7 @@ list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_dist_train) list(REMOVE_ITEM TEST_OPS test_dist_train)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf) list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed) list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
list(REMOVE_ITEM TEST_OPS test_dist_se_resnext)
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
...@@ -60,3 +61,4 @@ if(WITH_DISTRIBUTE) ...@@ -60,3 +61,4 @@ if(WITH_DISTRIBUTE)
endif() endif()
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL) py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext SERIAL)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import time
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import sys
import signal
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class SE_ResNeXt():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
cardinality = 32
reduction_ratio = 16
depth = [3, 4, 6, 3]
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
elif layers == 101:
cardinality = 32
reduction_ratio = 16
depth = [3, 4, 23, 3]
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
elif layers == 152:
cardinality = 64
reduction_ratio = 16
depth = [3, 8, 36, 3]
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=3,
stride=2,
act='relu')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
conv = self.conv_bn_layer(
input=conv,
num_filters=128,
filter_size=3,
stride=1,
act='relu')
conv = fluid.layers.pool2d(
input=conv, pool_size=3, pool_stride=2, pool_padding=1, \
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
cardinality=cardinality,
reduction_ratio=reduction_ratio)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
drop = fluid.layers.dropout(x=pool, dropout_prob=0.2)
stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)
out = fluid.layers.fc(input=drop, size=class_dim, act='softmax')
return out
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
filter_size = 1
return self.conv_bn_layer(input, ch_out, filter_size, stride)
else:
return input
def bottleneck_block(self, input, num_filters, stride, cardinality,
reduction_ratio):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
groups=cardinality,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 2, filter_size=1, act=None)
scale = self.squeeze_excitation(
input=conv2,
num_channels=num_filters * 2,
reduction_ratio=reduction_ratio)
short = self.shortcut(input, num_filters * 2, stride)
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) / 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def squeeze_excitation(self, input, num_channels, reduction_ratio):
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(input=pool,
size=num_channels / reduction_ratio,
act='relu')
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(input=squeeze,
size=num_channels,
act='sigmoid')
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def get_model(batch_size):
# Input data
image = fluid.layers.fill_constant(
shape=[batch_size, 3, 224, 224], dtype='float32', value=0.0)
label = fluid.layers.fill_constant(
shape=[batch_size, 1], dtype='int64', value=0.0)
# Train program
model = SE_ResNeXt(layers=50)
out = model.net(input=image, class_dim=102)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
# Evaluator
test_program = fluid.default_main_program().clone(for_test=True)
# Optimization
total_images = 6149 # flowers
epochs = [30, 60, 90]
step = int(total_images / batch_size + 1)
bd = [step * e for e in epochs]
base_lr = 0.1
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(avg_cost)
# Reader
train_reader = paddle.batch(
paddle.dataset.flowers.train(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.flowers.test(), batch_size=batch_size)
return test_program, avg_cost, train_reader, test_reader, acc_top1, out
def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id=trainer_id,
program=main_program,
pservers=pserver_endpoints,
trainers=trainers)
return t
class DistSeResneXt2x2:
def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
trainer_id):
get_model(batch_size=2)
t = get_transpiler(trainer_id,
fluid.default_main_program(), pserver_endpoints,
trainers)
pserver_prog = t.get_pserver_program(current_endpoint)
startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
exe.run(pserver_prog)
def _wait_ps_ready(self, pid):
retry_times = 20
while True:
assert retry_times >= 0, "wait ps ready failed"
time.sleep(3)
print("waiting ps ready: ", pid)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
retry_times -= 1
def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
batch_size=20)
if is_dist:
t = get_transpiler(trainer_id,
fluid.default_main_program(), endpoints,
trainers)
trainer_prog = t.get_trainer_program()
else:
trainer_prog = fluid.default_main_program()
startup_exe = fluid.Executor(place)
startup_exe.run(fluid.default_startup_program())
strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1
strategy.allow_op_delay = False
exe = fluid.ParallelExecutor(
True,
loss_name=avg_cost.name,
exec_strategy=strategy,
num_trainers=trainers,
trainer_id=trainer_id)
feed_var_list = [
var for var in trainer_prog.global_block().vars.itervalues()
if var.is_data
]
feeder = fluid.DataFeeder(feed_var_list, place)
reader_generator = train_reader()
first_loss, = exe.run(fetch_list=[avg_cost.name])
print(first_loss)
for i in xrange(5):
loss, = exe.run(fetch_list=[avg_cost.name])
last_loss, = exe.run(fetch_list=[avg_cost.name])
print(last_loss)
def main(role="pserver",
endpoints="127.0.0.1:9123",
trainer_id=0,
current_endpoint="127.0.0.1:9123",
trainers=1,
is_dist=True):
model = DistSeResneXt2x2()
if role == "pserver":
model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
else:
p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
if __name__ == "__main__":
if len(sys.argv) != 7:
print(
"Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]"
)
role = sys.argv[1]
endpoints = sys.argv[2]
trainer_id = int(sys.argv[3])
current_endpoint = sys.argv[4]
trainers = int(sys.argv[5])
is_dist = True if sys.argv[6] == "TRUE" else False
main(
role=role,
endpoints=endpoints,
trainer_id=trainer_id,
current_endpoint=current_endpoint,
trainers=trainers,
is_dist=is_dist)
...@@ -15,13 +15,13 @@ ...@@ -15,13 +15,13 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from paddle.fluid import metrics
class TestAucOp(OpTest): class TestAucOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "auc" self.op_type = "auc"
pred = np.random.random((128, 2)).astype("float32") pred = np.random.random((128, 2)).astype("float32")
indices = np.random.randint(0, 2, (128, 2))
labels = np.random.randint(0, 2, (128, 1)) labels = np.random.randint(0, 2, (128, 1))
num_thresholds = 200 num_thresholds = 200
tp = np.zeros((num_thresholds, )).astype("int64") tp = np.zeros((num_thresholds, )).astype("int64")
...@@ -30,8 +30,7 @@ class TestAucOp(OpTest): ...@@ -30,8 +30,7 @@ class TestAucOp(OpTest):
fn = np.zeros((num_thresholds, )).astype("int64") fn = np.zeros((num_thresholds, )).astype("int64")
self.inputs = { self.inputs = {
'Out': pred, 'Predict': pred,
'Indices': indices,
'Label': labels, 'Label': labels,
'TP': tp, 'TP': tp,
'TN': tn, 'TN': tn,
...@@ -39,57 +38,18 @@ class TestAucOp(OpTest): ...@@ -39,57 +38,18 @@ class TestAucOp(OpTest):
'FN': fn 'FN': fn
} }
self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds}
# NOTE: sklearn use a different way to generate thresholds
# which will cause the result differs slightly:
# from sklearn.metrics import roc_curve, auc
# fpr, tpr, thresholds = roc_curve(labels, pred)
# auc_value = auc(fpr, tpr)
# we caculate AUC again using numpy for testing
kepsilon = 1e-7 # to account for floating point imprecisions
thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
for i in range(num_thresholds - 2)]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
# caculate TP, FN, TN, FP count python_auc = metrics.Auc(name="auc",
tp_list = np.ndarray((num_thresholds, )) curve='ROC',
fn_list = np.ndarray((num_thresholds, )) num_thresholds=num_thresholds)
tn_list = np.ndarray((num_thresholds, )) python_auc.update(pred, labels)
fp_list = np.ndarray((num_thresholds, ))
for idx_thresh, thresh in enumerate(thresholds):
tp, fn, tn, fp = 0, 0, 0, 0
for i, lbl in enumerate(labels):
if lbl:
if pred[i, 0] >= thresh:
tp += 1
else:
fn += 1
else:
if pred[i, 0] >= thresh:
fp += 1
else:
tn += 1
tp_list[idx_thresh] = tp
fn_list[idx_thresh] = fn
tn_list[idx_thresh] = tn
fp_list[idx_thresh] = fp
epsilon = 1e-6
tpr = (tp_list.astype("float32") + epsilon) / (
tp_list + fn_list + epsilon)
fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon)
rec = (tp_list.astype("float32") + epsilon) / (
tp_list + fp_list + epsilon)
x = fpr[:num_thresholds - 1] - fpr[1:]
y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0
auc_value = np.sum(x * y)
self.outputs = { self.outputs = {
'AUC': auc_value, 'AUC': python_auc.eval(),
'TPOut': tp_list, 'TPOut': python_auc.tp_list,
'FNOut': fn_list, 'FNOut': python_auc.fn_list,
'TNOut': tn_list, 'TNOut': python_auc.tn_list,
'FPOut': fp_list 'FPOut': python_auc.fp_list
} }
def test_check_output(self): def test_check_output(self):
......
...@@ -142,8 +142,7 @@ class TestDataBalance(unittest.TestCase): ...@@ -142,8 +142,7 @@ class TestDataBalance(unittest.TestCase):
filenames=[self.lod_data_file_name], filenames=[self.lod_data_file_name],
shapes=[[-1, 3], [-1, 1]], shapes=[[-1, 3], [-1, 1]],
lod_levels=[1, 0], lod_levels=[1, 0],
dtypes=['float32', 'int32'], dtypes=['float32', 'int32'])
thread_num=1)
ins, label = fluid.layers.read_file(data_reader) ins, label = fluid.layers.read_file(data_reader)
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
...@@ -156,7 +155,7 @@ class TestDataBalance(unittest.TestCase): ...@@ -156,7 +155,7 @@ class TestDataBalance(unittest.TestCase):
main_program=main_prog, main_program=main_prog,
build_strategy=build_strategy) build_strategy=build_strategy)
if (parallel_exe.device_count > self.batch_size): if parallel_exe.device_count > self.batch_size:
print("WARNING: Unittest TestDataBalance skipped. \ print("WARNING: Unittest TestDataBalance skipped. \
For the result is not correct when device count \ For the result is not correct when device count \
is larger than batch size.") is larger than batch size.")
...@@ -190,3 +189,7 @@ class TestDataBalance(unittest.TestCase): ...@@ -190,3 +189,7 @@ class TestDataBalance(unittest.TestCase):
def test_all(self): def test_all(self):
self.main() self.main()
self.main_lod() self.main_lod()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import time
import math
import unittest
import os
import signal
import subprocess
class TestDistSeResneXt2x2(unittest.TestCase):
def setUp(self):
self._trainers = 2
self._pservers = 2
self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
self._python_interp = "python"
def start_pserver(self):
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps0_cmd = "%s dist_se_resnext.py pserver %s 0 %s %d TRUE" % \
(self._python_interp, self._ps_endpoints, ps0_ep, self._trainers)
ps1_cmd = "%s dist_se_resnext.py pserver %s 0 %s %d TRUE" % \
(self._python_interp, self._ps_endpoints, ps1_ep, self._trainers)
ps0_proc = subprocess.Popen(
ps0_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
ps1_proc = subprocess.Popen(
ps1_cmd.split(" "), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return ps0_proc, ps1_proc
def _wait_ps_ready(self, pid):
retry_times = 20
while True:
assert retry_times >= 0, "wait ps ready failed"
time.sleep(3)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
retry_times -= 1
def non_test_with_place(self):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs = {
"PATH": os.getenv("PATH"),
"PYTHONPATH": os.getenv("PYTHONPATH"),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH"),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15"
}
# Run local to get a base line
env_local = {"CUDA_VISIBLE_DEVICES": "0"}
env_local.update(required_envs)
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
(self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
local_proc = subprocess.Popen(
local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local)
local_proc.wait()
local_ret = local_proc.stdout.read()
# Run dist train to compare with local results
ps0, ps1 = self.start_pserver()
self._wait_ps_ready(ps0.pid)
self._wait_ps_ready(ps1.pid)
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr0_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d TRUE" % \
(self._python_interp, self._ps_endpoints, ps0_ep, self._trainers)
tr1_cmd = "%s dist_se_resnext.py trainer %s 1 %s %d TRUE" % \
(self._python_interp, self._ps_endpoints, ps1_ep, self._trainers)
env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"}
env0.update(required_envs)
env1.update(required_envs)
FNULL = open(os.devnull, 'w')
tr0_proc = subprocess.Popen(
tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0)
tr1_proc = subprocess.Popen(
tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1)
tr0_proc.wait()
tr1_proc.wait()
loss_data0 = tr0_proc.stdout.read()
lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
dist_last_loss = eval(lines[1].replace(" ", ","))[0]
local_lines = local_ret.split("\n")
local_first_loss = eval(local_lines[0])[0]
local_last_loss = eval(local_lines[1])[0]
self.assertAlmostEqual(local_first_loss, dist_first_loss)
self.assertAlmostEqual(local_last_loss, dist_last_loss)
# check tr0_out
# FIXME: ensure the server process is killed
# replace with ps0.terminate()
os.kill(ps0.pid, signal.SIGKILL)
os.kill(ps1.pid, signal.SIGKILL)
FNULL.close()
if __name__ == "__main__":
unittest.main()
...@@ -102,7 +102,7 @@ class TestLearningRateDecay(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestLearningRateDecay(unittest.TestCase):
exe.run(startup_prog) exe.run(startup_prog)
# fluid.memory_optimize(main_prog) fluid.memory_optimize(main_prog)
for step in range(10): for step in range(10):
lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr]) lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
......
...@@ -39,7 +39,7 @@ class TestMomentumOp1(OpTest): ...@@ -39,7 +39,7 @@ class TestMomentumOp1(OpTest):
velocity_out = mu * velocity + grad velocity_out = mu * velocity + grad
if use_nesterov: if use_nesterov:
param_out = param - grad * learning_rate + \ param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate velocity_out * mu * learning_rate
else: else:
param_out = param - learning_rate * velocity_out param_out = param - learning_rate * velocity_out
...@@ -75,7 +75,7 @@ class TestMomentumOp2(OpTest): ...@@ -75,7 +75,7 @@ class TestMomentumOp2(OpTest):
velocity_out = mu * velocity + grad velocity_out = mu * velocity + grad
if use_nesterov: if use_nesterov:
param_out = param - grad * learning_rate + \ param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate velocity_out * mu * learning_rate
else: else:
param_out = param - learning_rate * velocity_out param_out = param - learning_rate * velocity_out
......
...@@ -39,17 +39,17 @@ class TestMultipleReader(unittest.TestCase): ...@@ -39,17 +39,17 @@ class TestMultipleReader(unittest.TestCase):
copyfile('./mnist_0.recordio', './mnist_1.recordio') copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio') copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num): def main(self, is_test=False):
file_list = [ file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio' './mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
] ]
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files( data_files = fluid.layers.open_files(
filenames=file_list, filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)], shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'],
is_test=is_test)
img, label = fluid.layers.read_file(data_files) img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
...@@ -71,6 +71,9 @@ class TestMultipleReader(unittest.TestCase): ...@@ -71,6 +71,9 @@ class TestMultipleReader(unittest.TestCase):
self.assertEqual(batch_count, self.num_batch * 3) self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self): def test_main(self):
self.main(thread_num=3) # thread number equals to file number self.main(is_test=False)
self.main(thread_num=10) # thread number is larger than file number self.main(is_test=True)
self.main(thread_num=2) # thread number is less than file number
if __name__ == '__main__':
unittest.main()
...@@ -33,9 +33,7 @@ def simple_fc_net(use_feed): ...@@ -33,9 +33,7 @@ def simple_fc_net(use_feed):
filenames=[MNIST_RECORDIO_FILE], filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'])
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader) reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
hidden = img hidden = img
...@@ -61,9 +59,7 @@ def fc_with_batchnorm(use_feed): ...@@ -61,9 +59,7 @@ def fc_with_batchnorm(use_feed):
filenames=[MNIST_RECORDIO_FILE], filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]], shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64'], dtypes=['float32', 'int64'])
thread_num=1,
for_parallel=True)
reader = fluid.layers.io.double_buffer(reader) reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader) img, label = fluid.layers.read_file(reader)
...@@ -102,6 +98,16 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -102,6 +98,16 @@ class TestMNIST(TestParallelExecutorBase):
fluid.recordio_writer.convert_reader_to_recordio_file( fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder) MNIST_RECORDIO_FILE, reader, feeder)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
# simple_fc
def check_simple_fc_convergence(self, use_cuda, use_reduce=False): def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
...@@ -109,8 +115,8 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -109,8 +115,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, allow_op_delay=True) simple_fc_net, use_cuda=use_cuda, allow_op_delay=True)
img = np.zeros(shape=[32, 784], dtype='float32') img, label = self._init_data()
label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence( self.check_network_convergence(
simple_fc_net, simple_fc_net,
feed_dict={"image": img, feed_dict={"image": img,
...@@ -118,6 +124,37 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -118,6 +124,37 @@ class TestMNIST(TestParallelExecutorBase):
use_cuda=use_cuda, use_cuda=use_cuda,
use_reduce=use_reduce) use_reduce=use_reduce)
def check_simple_fc_convergence_with_Reduce(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence(
simple_fc_net, use_cuda=use_cuda, use_reduce=True)
self.check_network_convergence(
simple_fc_net,
use_cuda=use_cuda,
allow_op_delay=True,
use_reduce=True)
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
simple_fc_net,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence(
simple_fc_net,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=True)
for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
def test_simple_fc(self): def test_simple_fc(self):
# use_cuda # use_cuda
self.check_simple_fc_convergence(True) self.check_simple_fc_convergence(True)
...@@ -125,14 +162,15 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -125,14 +162,15 @@ class TestMNIST(TestParallelExecutorBase):
def test_simple_fc_with_new_strategy(self): def test_simple_fc_with_new_strategy(self):
# use_cuda, use_reduce # use_cuda, use_reduce
self.check_simple_fc_convergence(True, True) self.check_simple_fc_convergence_with_Reduce(True)
self.check_simple_fc_convergence(False, True) self.check_simple_fc_convergence_with_Reduce(False)
def check_simple_fc_parallel_accuracy(self, use_cuda, use_reduce=False): def check_simple_fc_parallel_accuracy(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') img, label = self._init_data(random=False)
single_first_loss, single_last_loss = self.check_network_convergence( single_first_loss, single_last_loss = self.check_network_convergence(
method=simple_fc_net, method=simple_fc_net,
seed=1000, seed=1000,
...@@ -146,8 +184,7 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -146,8 +184,7 @@ class TestMNIST(TestParallelExecutorBase):
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
use_parallel_executor=True, use_parallel_executor=True)
use_reduce=use_reduce)
for p_f in parallel_first_loss: for p_f in parallel_first_loss:
self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6) self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
...@@ -158,32 +195,53 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -158,32 +195,53 @@ class TestMNIST(TestParallelExecutorBase):
self.check_simple_fc_parallel_accuracy(True) self.check_simple_fc_parallel_accuracy(True)
self.check_simple_fc_parallel_accuracy(False) self.check_simple_fc_parallel_accuracy(False)
def test_simple_fc_parallel_accuracy_with_new_strategy(self): def check_batchnorm_fc_convergence(self, use_cuda):
# use_cuda, use_reduce
self.check_simple_fc_parallel_accuracy(True, True)
self.check_simple_fc_parallel_accuracy(False, True)
def check_batchnorm_fc_convergence(self, use_cuda, use_reduce=False):
if use_cuda and not core.is_compiled_with_cuda(): if use_cuda and not core.is_compiled_with_cuda():
return return
self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda) self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda)
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64') img, label = self._init_data()
self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda)
def check_batchnorm_fc_convergence_use_reduce(self, use_cuda):
if use_cuda and not core.is_compiled_with_cuda():
return
self.check_network_convergence( self.check_network_convergence(
fc_with_batchnorm, use_cuda=use_cuda, use_reduce=True)
img, label = self._init_data()
all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm, fc_with_batchnorm,
feed_dict={"image": img, feed_dict={"image": img,
"label": label}, "label": label},
use_cuda=use_cuda, use_cuda=use_cuda,
use_reduce=use_reduce) use_reduce=False)
reduce_first_loss, reduce_last_loss = self.check_network_convergence(
fc_with_batchnorm,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
use_reduce=True)
for loss in zip(all_reduce_first_loss, reduce_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-4)
def test_batchnorm_fc(self): def test_batchnorm_fc(self):
self.check_batchnorm_fc_convergence(True) self.check_batchnorm_fc_convergence(True)
self.check_batchnorm_fc_convergence(False) self.check_batchnorm_fc_convergence(False)
def test_batchnorm_fc_with_new_strategy(self): def test_batchnorm_fc_with_new_strategy(self):
# use_cuda, use_reduce self.check_batchnorm_fc_convergence_use_reduce(True)
self.check_batchnorm_fc_convergence(True, True) self.check_batchnorm_fc_convergence_use_reduce(False)
self.check_batchnorm_fc_convergence(False, True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -45,12 +45,12 @@ class TestPyReader(unittest.TestCase): ...@@ -45,12 +45,12 @@ class TestPyReader(unittest.TestCase):
) else fluid.CPUPlace() ) else fluid.CPUPlace()
executor = fluid.Executor(place) executor = fluid.Executor(place)
data_file, feed_queue = fluid.layers.py_reader( data_file = fluid.layers.py_reader(
capacity=self.capacity, capacity=self.capacity,
dtypes=self.dtypes, dtypes=self.dtypes,
lod_levels=self.lod_levels, lod_levels=self.lod_levels,
shapes=self.shapes) shapes=self.shapes)
feed_queue = data_file.queue
read_out_data = fluid.layers.read_file(data_file) read_out_data = fluid.layers.read_file(data_file)
self.inputs = [] self.inputs = []
......
...@@ -52,11 +52,13 @@ def simple_fc_net(in_size, ...@@ -52,11 +52,13 @@ def simple_fc_net(in_size,
batch_size, batch_size,
queue_capacity, queue_capacity,
use_double_buffer=False): use_double_buffer=False):
reader, feed_queue = fluid.layers.py_reader( reader = fluid.layers.py_reader(
capacity=queue_capacity, capacity=queue_capacity,
shapes=[[-1, in_size], [-1, 1]], shapes=[[-1, in_size], [-1, 1]],
lod_levels=[0, 0], lod_levels=[0, 0],
dtypes=['float32', 'int64']) dtypes=['float32', 'int64'],
use_double_buffer=False)
feed_queue = reader.queue
reader = fluid.layers.batch(reader, batch_size=batch_size) reader = fluid.layers.batch(reader, batch_size=batch_size)
if use_double_buffer: if use_double_buffer:
reader = fluid.layers.double_buffer(reader) reader = fluid.layers.double_buffer(reader)
......