提交 821cb9c0 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into update_version_via_git_branch

......@@ -173,6 +173,7 @@ are transformed into offsets of elements/words as follows:
## Slicing of LoD Tensors
When we use the above 2-level LoD Tensor as the input to a nested-RNN, we need to retrieve certain sequences. Here we define the sequence identified by branch <i,j,...> as the **<i,j,...>-slice**.
For example, the <2>-slice of above example is
......@@ -189,3 +190,22 @@ and the <2,0>-slice of above slice is
10 12
||
```
## Length Representation vs Offset Representation
The offset representation is an implementation-oriented decision and it makes understanding the idea behind LoDTensor difficult.
Hence, we encapsulate this implementation detail in C++ and expose the original length representation in our Python API.
Specifically, we call this length representation `recursive_sequence_lengths` and users can use the following code to set or get the `recursive_sequence_lengths` of a LoDTensor in Python:
```Python
# length representation of lod called recursive_sequence_lengths
recursive_seq_lens = [[3, 1, 2], [2, 2, 1, 3, 1, 2]]
# Create a LoDTensor that has the above recursive_sequence_lengths info.
# This recursive_sequence_lengths will be converted to an offset representation of LoD in the C++ implementation under the hood.
tensor = fluid.LoDTensor(lod)
# Set/Change the recursive_sequence_lengths info of LoDTensor
tensor.set_recursive_sequence_lengths([[3, 1, 2]])
# Get the recursive_sequence_lengths info of a LoDTensor (the offset-based LoD representation stored in C++ will be converted
# back to length-based recursive_sequence_lengths), new_recursive_seq_lens = [[3, 1, 2]]
new_recursive_seq_lens = tensor.recursive_sequence_lengths()
```
## 堆内存分析和优化
# 堆内存分析和优化
计算机程序都可能有内存泄漏的风险。**内存泄漏**一般是由于程序在堆(heap)上分配了内存而没有释放,随着程序的运行占用的内存越来越大,一方面会影响程序的稳定性,可能让运行速度越来越慢,或者造成oom,甚至会影响运行程序的机器的稳定性,造成宕机。
......@@ -20,11 +20,11 @@ Paddle也提供了基于gperftool的[CPU性能分析教程](https://github.com/P
对于堆内存的分析,主要用到thread-caching malloc和heap-profiling using tcmalloc。
## 使用流程
#### 环境
## 环境
本教程基于paddle提供的Docker开发环境paddlepaddle/paddle:latest-dev,基于Ubuntu 16.04.4 LTS环境。
#### 使用流程
## 使用流程
- 安装google-perftools
......
# 如何使用timeline工具做性能分析
1. 在训练的主循环外加上`with profiler.profiler(...)`。运行之后,代码会在`/tmp/profile`目录下生成一个profile的记录文件。
**提示:**
请不要在timeline记录信息时运行太多次迭代,因为timeline中的记录数量和迭代次数是成正比的。
```python
with profiler.profiler('All', 'total', '/tmp/profile') as prof:
for pass_id in range(pass_num):
for batch_id, data in enumerate(train_reader()):
exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[])
...
```
1. 运行`python paddle/tools/timeline.py`来处理`/tmp/profile`,这个程序默认会生成一个`/tmp/timeline`文件,你也可以用命令行参数来修改这个路径,请参考[timeline.py](https://github.com/PaddlePaddle/Paddle/blob/develop/tools/timeline.py)
1. 打开chrome浏览器,访问<chrome://tracing/>,用`load`按钮来加载生成的`timeline`文件。
![chrome tracing](./tracing.jpeg)
1. 结果如下图所示,可以放到来查看timetime的细节信息。
![chrome timeline](./timeline.jpeg)
......@@ -19,6 +19,9 @@ endif(APPLE)
set(inference_deps paddle_inference_api paddle_fluid_api)
if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine)
endif()
function(inference_api_test TARGET_NAME)
if (WITH_TESTING)
......@@ -50,7 +53,15 @@ cc_test(test_paddle_inference_api
inference_api_test(test_paddle_inference_api_impl
ARGS test_word2vec test_image_classification)
if (WITH_ANAKIN AND WITH_TESTING) # only needed in CI
if(WITH_GPU AND TENSORRT_FOUND)
cc_library(paddle_inference_tensorrt_subgraph_engine
SRCS paddle_inference_api_tensorrt_subgraph_engine.cc
DEPS paddle_inference_api analysis tensorrt_engine paddle_inference_api paddle_fluid_api)
inference_api_test(test_paddle_inference_api_tensorrt_subgraph_engine ARGS test_word2vec)
endif()
if (WITH_ANAKIN) # only needed in CI
# Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's,
# so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to
# compile the libinference_anakin_api.a and compile with anakin.so.
......@@ -60,10 +71,12 @@ if (WITH_ANAKIN AND WITH_TESTING) # only needed in CI
target_compile_options(inference_anakin_api_shared BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
target_link_libraries(inference_anakin_api anakin anakin_saber_common)
target_link_libraries(inference_anakin_api_shared anakin anakin_saber_common)
cc_test(inference_anakin_test SRCS paddle_inference_api_anakin_engine_tester.cc
if (WITH_TESTING)
cc_test(inference_anakin_test SRCS paddle_inference_api_anakin_engine_tester.cc
ARGS --model=${ANAKIN_INSTALL_DIR}/mobilenet_v2.anakin.bin
DEPS inference_anakin_api)
target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS})
endif(WITH_TESTING)
endif()
if(WITH_TESTING)
......
......@@ -15,6 +15,11 @@
inference_api_test(simple_on_word2vec ARGS test_word2vec)
option(WITH_INFERENCE_DEMO "Compile with Inference demo" OFF)
if(NOT WITH_INFERENCE_DEMO)
return()
endif()
set(DEMO_INSTALL_DIR "${PADDLE_BINARY_DIR}/inference_demo")
set(URL_ROOT http://paddlemodels.bj.bcebos.com/inference-vis-demos%2F)
......
......@@ -73,12 +73,12 @@ struct PaddleTensor {
};
enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
// TODO(Superjomn) support following engines latter.
// kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
// kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
};
/*
......@@ -130,6 +130,11 @@ struct AnakinConfig : public PaddlePredictor::Config {
int max_batch_size{-1};
};
struct TensorRTConfig : public NativeConfig {
// Determine whether a subgraph will be executed by TRT.
int min_subgraph_size{1};
};
// A factory to help create different predictors.
//
// FOR EXTENSION DEVELOPER:
......
......@@ -89,6 +89,7 @@ bool NativePaddlePredictor::Init(
LOG(ERROR) << "fail to load inference model.";
return false;
}
ctx_ = executor_->Prepare(*inference_program_, 0);
executor_->CreateVariables(
*inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
......@@ -119,6 +120,7 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
return false;
}
for (size_t i = 0; i < feed_target_names_.size(); ++i) {
VLOG(4) << "setting " << i << "-th target";
feed_targets[feed_target_names_[i]] = &feeds[i];
}
// get fetch variable
......@@ -130,14 +132,16 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
}
// Run the inference program
// if share variables, we need not create variables
VLOG(4) << "Run prepared context";
executor_->RunPreparedContext(
ctx_.get(),
sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
&feed_targets,
&fetch_targets,
false /* don't create variable eatch time */);
VLOG(4) << "Finish prepared context";
if (!GetFetch(fetchs, output_data)) {
LOG(ERROR) << "fail to get fetchs";
LOG(ERROR) << "fail to get fetches";
return false;
}
VLOG(3) << "predict cost: " << timer.toc() << "ms";
......
......@@ -44,7 +44,7 @@ class NativePaddlePredictor : public PaddlePredictor {
~NativePaddlePredictor() override;
private:
protected:
bool SetFeed(const std::vector<PaddleTensor> &input_datas,
std::vector<framework::LoDTensor> *feeds);
bool GetFetch(const std::vector<framework::LoDTensor> &fetchs,
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/contrib/inference/paddle_inference_api.h"
#include "paddle/contrib/inference/paddle_inference_api_impl.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
class TensorRTSubgraphPredictor : public NativePaddlePredictor {
public:
explicit TensorRTSubgraphPredictor(const TensorRTConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
// Analyze inference_program
Argument argument;
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc);
VLOG(5) << "transformed program:\n"
<< argument.transformed_program_desc->SerializeAsString();
VLOG(5) << "to prepare executor";
*inference_program_->Proto() = *argument.transformed_program_desc;
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
executor_->CreateVariables(
*inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames();
fetch_target_names_ = inference_program_->GetFetchTargetNames();
return true;
}
private:
TensorRTConfig config_;
};
template <>
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<TensorRTConfig, PaddleEngineKind::kAutoMixedTensorRT>(
const TensorRTConfig& config) {
VLOG(3) << "create TensorRTSubgraphPredictor";
if (config.use_gpu) {
// 1. GPU memeroy
PADDLE_ENFORCE_GT(
config.fraction_of_gpu_memory,
0.f,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]");
PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
}
std::unique_ptr<PaddlePredictor> predictor(
new TensorRTSubgraphPredictor(config));
if (!dynamic_cast<TensorRTSubgraphPredictor*>(predictor.get())
->Init(nullptr)) {
return nullptr;
}
return std::move(predictor);
}
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/contrib/inference/paddle_inference_api.h"
namespace paddle {
DEFINE_string(dirname, "", "Directory of the inference model.");
void Main(bool use_gpu) {
//# 1. Create PaddlePredictor with a config.
TensorRTConfig config;
config.model_dir = FLAGS_dirname + "word2vec.inference.model";
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
auto predictor =
CreatePaddlePredictor<TensorRTConfig,
PaddleEngineKind::kAutoMixedTensorRT>(config);
for (int batch_id = 0; batch_id < 3; batch_id++) {
//# 2. Prepare input.
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor{.name = "",
.shape = std::vector<int>({4, 1}),
.data = PaddleBuf(data, sizeof(data)),
.dtype = PaddleDType::INT64};
// For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> slots(4, tensor);
//# 3. Run
std::vector<PaddleTensor> outputs;
CHECK(predictor->Run(slots, &outputs));
//# 4. Get output.
ASSERT_EQ(outputs.size(), 1UL);
LOG(INFO) << "output buffer size: " << outputs.front().data.length();
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << static_cast<float*>(outputs.front().data.data())[i];
}
}
}
TEST(paddle_inference_api_tensorrt_subgraph_engine, main) { Main(true); }
} // namespace paddle
\ No newline at end of file
......@@ -713,6 +713,10 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
} else if (var->IsType<LoDTensorArray>()) {
const LoDTensorArray& arr = var->Get<LoDTensorArray>();
PADDLE_ENFORCE(arr.size() > 0);
t = &(arr[0]);
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
......
......@@ -253,6 +253,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
t->set_lod(lod_tensors[j].lod());
}
}
for (auto &p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
ParallelExecutor::~ParallelExecutor() {
......
......@@ -28,9 +28,10 @@ endif()
if(WITH_TESTING)
# both tests/book and analysis depends the models that generated by python/paddle/fluid/tests/book
add_subdirectory(tests/book)
add_subdirectory(analysis)
endif()
add_subdirectory(analysis)
if (TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor init)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc
tensorrt_subgraph_pass.cc
dfg_graphviz_draw_pass.cc
DEPS framework_proto)
tensorrt_subgraph_pass.cc
tensorrt_subgraph_node_mark_pass.cc
analyzer.cc
helper.cc
DEPS framework_proto proto_desc)
cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
function (inference_analysis_test TARGET)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(WITH_TESTING)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
DEPS analysis
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --fraction_of_gpu_memory_to_use=0.5)
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING)
endfunction(inference_analysis_test)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
......@@ -28,5 +32,7 @@ inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
#inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
"Enable subgraph to TensorRT engine for acceleration");
DEFINE_string(inference_analysis_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs.");
class DfgPassManagerImpl final : public DfgPassManager {
public:
DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs.
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) {
auto trt_teller = [](const Node* node) {
if (!node->IsFunction()) return false;
return static_cast<const Function*>(node)->func_type() == "mul";
};
AddPass("tensorrt-subgraph-marker",
new TensorRTSubgraphNodeMarkPass(trt_teller));
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
}
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
LOG(INFO) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
}
// Add the graphviz debuger pass if the parent pass has one.
void AddGraphvizDebugerPass(Pass* pass) {
auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) {
LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
Register(debuger_pass->repr(), debuger_pass);
}
}
};
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll();
PADDLE_ENFORCE(x->Finalize());
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file contains Analyzer, an class that exposed as a library that analyze
* and optimize
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to
* control whether
* an process is applied on the program.
*
* The processes are called Passes in analysis, the Passes are placed in a
* pipeline, the first
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to
* a data flow
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow
* graph to a
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes
* which take a
* node or data flow graph as input.
*
* The Analyzer can be used in two methods, the first is a executable file which
* can be used to
* pre-process the inference model and can be controlled by passing difference
* command flags;
* the other way is to compose inside the inference API as a runtime pre-process
* phase in the
* inference service.
*/
#include <gflags/gflags.h>
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
namespace paddle {
namespace inference {
namespace analysis {
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root);
class Analyzer : public OrderedRegistry<PassManager> {
public:
// Register all the pass-managers.
Analyzer();
void Run(Argument* argument);
DISABLE_COPY_AND_ASSIGN(Analyzer);
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, main) {
Analyzer analyser;
analyser.Run(&argument);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -41,6 +41,9 @@ struct Argument {
// The original program desc.
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc;
// The processed program desc.
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc;
};
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace inference {
namespace analysis {
// It is a better idea that the inputs and outputs of this graph is set manully
// It is a better idea that the inputs and outputs of this graph is set manually
// before, but there must be a Pass that helps to prune the unnecessary ops that
// do not contribute to the given targets, so in this pass, analysis and get the
// inputs and outputs is OK.
......@@ -50,6 +50,25 @@ void DataFlowGraph::Build() {
outputs.push_back(out);
}
}
Clean();
}
void DataFlowGraph::Clean() {
for (auto &node : nodes.nodes()) {
std::unordered_set<Node *> inlinks_set(node->inlinks.begin(),
node->inlinks.end());
std::unordered_set<Node *> outlinks_set(node->outlinks.begin(),
node->outlinks.end());
if (inlinks_set.size() < node->inlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->inlinks.assign(inlinks_set.begin(), inlinks_set.end());
}
if (outlinks_set.size() < node->outlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->outlinks.assign(outlinks_set.begin(), outlinks_set.end());
}
}
}
std::string DataFlowGraph::DotString() const {
......
......@@ -47,6 +47,10 @@ struct DataFlowGraph {
// Output a DOT graph file for debug.
std::string DotString() const;
private:
// Remove duplicate edges and so on.
void Clean();
};
/*
......@@ -133,17 +137,24 @@ struct GraphTraits<DataFlowGraph> {
// Extract the inputs and outputs of a graph. The inputs and outputs of a
// sub-graph is the inputs nodes and output nodes that doesn't inside the
// sub-graph.
std::pair<
std::vector<Node *>,
std::vector<
Node *>> static ExtractInputAndOutputOfSubGraph(std::vector<Node *>
&graph) {
static std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) {
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
std::unordered_set<Node *> inputs;
std::unordered_set<Node *> outputs;
// Input a Value, check whether its inlink is in the subgraph.
auto inlink_in_subgraph = [&](Node *n) {
for (auto *in : n->inlinks) {
if (nodes.count(in)) return true;
}
return false;
};
for (auto &node : graph) {
for (auto *in : node->inlinks) {
if (!nodes.count(in) && in->type() == Node::Type::kValue) {
// The Value that is written by nodes inside a sub-graph shouldn't be the
// input of the sub-graph.
if (!nodes.count(in) && in->type() == Node::Type::kValue &&
!inlink_in_subgraph(in)) {
inputs.insert(in);
}
}
......
......@@ -13,21 +13,34 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
using framework::proto::ProgramDesc;
std::vector<std::string> ExtractParameters(
const std::vector<std::unique_ptr<Node>>& nodes);
bool DataFlowGraphToFluidPass::Initialize(Argument* argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
desc_ = argument->origin_program_desc.get();
// Here some logic from program_desc.cc and will not add new interfaces into
// framework::ProgramDesc class, use some UT to assure the correctness.
auto* block = desc_->mutable_blocks()->Add();
block->set_idx(framework::kRootBlockIndex);
block->set_parent_idx(framework::kNoneBlockIndex);
PADDLE_ENFORCE(!argument->transformed_program_desc);
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// from the original program desc. The operators of the main block(the first
// block) should rewritten by data flow graph.
argument->transformed_program_desc.reset(
new ProgramDesc(*argument->origin_program_desc));
argument->transformed_program_desc->mutable_blocks(framework::kRootBlockIndex)
->clear_ops();
desc_ = argument->transformed_program_desc.get();
argument_ = argument;
return true;
}
......@@ -37,14 +50,17 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
auto traits = GraphTraits<DataFlowGraph>(graph);
for (auto it = traits.nodes().begin(); it != traits.nodes().end(); ++it) {
if (it->deleted()) continue;
switch (it->type()) {
case Node::Type::kFunction:
LOG(INFO) << "add function " << it->name();
case Node::Type::kFunction: {
LOG(INFO) << "add function " << it->repr();
AddFluidOp(&(*it));
break;
case Node::Type::kFunctionBlock:
} break;
case Node::Type::kFunctionBlock: {
LOG(INFO) << "add engine op " << it->repr() << " , "
<< static_cast<FunctionBlock*>(&(*it))->subgraph.size();
AddEngineOp(&(*it));
break;
} break;
default:
continue;
}
......@@ -52,12 +68,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph* graph) {
}
void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
LOG(INFO) << "processing func " << node->name();
auto* ori_op = static_cast<framework::proto::OpDesc*>(node->pb_desc());
// currently only the main block is analyzed.
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops();
LOG(INFO) << "to copy the op";
*op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt
......@@ -65,11 +79,89 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node* node) {
// NOTE It might be changed by other passes in the long run.
}
void CreateTrtEngineOp(Node* node, const DataFlowGraph& graph,
const framework::proto::BlockDesc& block) {
static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc;
auto* func = static_cast<FunctionBlock*>(node);
// collect inputs
std::vector<std::string> io;
for (auto* x : func->inlinks) {
io.push_back(x->name());
}
desc.SetInput("Xs", io);
// collect outputs
io.clear();
for (auto* x : func->outlinks) {
io.push_back(x->name());
}
desc.SetOutput("Ys", io);
desc.SetType("tensorrt_engine");
// Set attrs
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
SetAttr(desc.Proto(), "engine_unique_key",
"trt-" + std::to_string(counter++));
SetAttr(desc.Proto(), "max_batch", 100); // TODO(Superjomn) add config latter
SetAttr(desc.Proto(), "max_workspace",
1024); // TODO(Superjomn) add config latter
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
node->SetPbMsg(desc.Proto()->SerializeAsString());
}
std::vector<std::string> ExtractParameters(
const std::vector<std::unique_ptr<Node>>& nodes) {
std::vector<std::string> parameters;
for (const auto& node : nodes) {
if (!node->IsValue()) continue;
PADDLE_ENFORCE(!node->pb_msg().empty(), "pb_msg should be set first");
framework::proto::VarDesc var;
var.ParseFromString(node->pb_msg());
if (var.persistable()) {
parameters.push_back(var.name());
}
}
return parameters;
}
void DataFlowGraphToFluidPass::AddEngineOp(Node* node) {
// auto* ori_op = static_cast<framework::proto::OpDesc*>(node->extra_info());
// auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
// auto* op = main_block->add_ops();
// TODO(Superjomn) Here need to expose some arguments for default setting.
PADDLE_ENFORCE(node->IsFunctionBlock());
auto* block_node = static_cast<FunctionBlock*>(node);
framework::proto::BlockDesc proto;
framework::BlockDesc block_desc(nullptr, &proto);
// copy ops.
for (auto* node : block_node->subgraph) {
auto* op = block_desc.AppendOp();
PADDLE_ENFORCE(!node->pb_msg().empty());
op->Proto()->ParseFromString(node->pb_msg());
}
CreateTrtEngineOp(node, *argument_->main_dfg, *block_desc.Proto());
auto* main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto* op = main_block->add_ops();
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
op->ParseFromString(node->pb_msg());
}
namespace {
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
public:
using Config = DFG_GraphvizDrawPass::Config;
DFG_DebuggerPass(const Config& config) : DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "dfg-to-fluid-debuger-pass"; }
bool Finalize() override { return true; }
};
}
Pass* DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger"));
}
} // namespace analysis
......
......@@ -40,10 +40,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
return "Transform a DFG to a Fluid ProgramDesc";
}
Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override {
return nullptr;
}
Pass *CreateGraphvizDebugerPass() const override;
protected:
// Add a Fluid Op into the ProgramDesc.
......@@ -53,6 +50,7 @@ class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
private:
framework::proto::ProgramDesc *desc_;
Argument *argument_;
};
} // namespace analysis
} // namespace inference
......
......@@ -18,12 +18,19 @@ namespace paddle {
namespace inference {
namespace analysis {
int DFG_GraphvizDrawPass::counter_{0};
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto content = Draw(graph);
std::ofstream file(GenDotPath());
auto dot_path = GenDotPath();
std::ofstream file(dot_path);
file.write(content.c_str(), content.size());
file.close();
LOG(INFO) << "draw dot to " << GenDotPath();
auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png";
std::string message;
LOG(INFO) << "draw to " << png_path;
ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message);
}
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
......@@ -41,9 +48,7 @@ std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue;
for (auto &in : node.inlinks) {
dot.AddEdge(in->repr(), node.repr(), {});
}
dot.AddEdge(in->repr(), node.repr(), {});
}
}
return dot.Build();
......
......@@ -50,20 +50,25 @@ class DFG_GraphvizDrawPass : public DataFlowGraphPass {
bool Initialize(Argument *argument) override { return true; }
void Run(DataFlowGraph *graph) override;
bool Finalize() override { return Pass::Finalize(); }
bool Finalize() override { return true; }
std::string repr() const override { return "DFG graphviz drawer"; }
std::string description() const override {
return "Debug a DFG by draw with graphviz";
}
private:
protected:
// A counter to add a number prefix to the debugger image output so that they
// will sort in the triggered order.
static int counter_;
// Path of the dot file to output.
std::string GenDotPath() const {
return config_.dir + "/" + "graph_" + config_.id + ".dot";
return config_.dir + "/" + std::to_string(counter_++) + "-graph_" +
config_.id + ".dot";
}
std::string Draw(DataFlowGraph *graph);
virtual std::string Draw(DataFlowGraph *graph);
Config config_;
};
......
......@@ -31,7 +31,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
pass.Run(&dfg);
// test content
std::ifstream file("./graph_test.dot");
std::ifstream file("./0-graph_test.dot");
ASSERT_TRUE(file.is_open());
std::string line;
......@@ -40,7 +40,7 @@ TEST_F(DFG_Tester, dfg_graphviz_draw_pass_tester) {
no++;
}
// DFG is sensitive to ProgramDesc, be careful to change the existing models.
ASSERT_EQ(no, 112);
ASSERT_EQ(no, 82);
}
} // namespace analysis
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle {
......@@ -33,7 +35,7 @@ bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
return true;
}
bool FluidToDataFlowGraphPass::Finalize() { return Pass::Finalize(); }
bool FluidToDataFlowGraphPass::Finalize() { return true; }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph);
......@@ -46,6 +48,7 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
auto *v = graph->nodes.Create(Node::Type::kValue);
v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
for (int i = 0; i < main_block.ops_size(); i++) {
......@@ -56,6 +59,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
// Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
// TODO(Superjomn) make sure the InputNames is the real variable name.
for (int j = 0; j < op.inputs_size(); j++) {
......@@ -79,9 +84,19 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
graph->Build();
}
Pass *FluidToDataFlowGraphPass::CreatePrinterPass(
std::ostream &os, const std::string &banner) const {
return nullptr;
namespace {
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
public:
using Config = DFG_GraphvizDrawPass::Config;
DFG_DebuggerPass(const Config &config) : DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "fluid-to-dfg-debuger-pass"; }
bool Finalize() override { return true; }
};
}
Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root, "fluid-to-dfg-debuger"));
}
} // namespace analysis
......
......@@ -46,8 +46,7 @@ class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
return "transform a fluid ProgramDesc to a data flow graph.";
}
Pass *CreatePrinterPass(std::ostream &os,
const std::string &banner) const override;
Pass *CreateGraphvizDebugerPass() const override;
private:
framework::proto::ProgramDesc const *desc_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace inference {
namespace analysis {
template <>
void SetAttr<std::string>(framework::proto::OpDesc *op, const std::string &name,
const std::string &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(data);
}
template <>
void SetAttr<int>(framework::proto::OpDesc *op, const std::string &name,
const int &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(data);
}
template <>
void SetAttr<int64_t>(framework::proto::OpDesc *op, const std::string &name,
const int64_t &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::LONG);
attr->set_l(data);
}
template <>
void SetAttr<std::vector<std::string>>(framework::proto::OpDesc *op,
const std::string &name,
const std::vector<std::string> &data) {
auto *attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
for (const auto &s : data) {
attr->add_strings(s.c_str());
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -14,10 +14,12 @@ limitations under the License. */
#pragma once
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -26,6 +28,10 @@ namespace paddle {
namespace inference {
namespace analysis {
template <typename T>
void SetAttr(framework::proto::OpDesc *op, const std::string &name,
const T &data);
template <typename Vec>
int AccuDims(Vec &&vec, int size) {
int res = 1;
......@@ -93,7 +99,7 @@ template <typename T>
class OrderedRegistry {
public:
T *Register(const std::string &name, T *x) {
PADDLE_ENFORCE(!dic_.count(name));
PADDLE_ENFORCE(!dic_.count(name), "duplicate key [%s]", name);
dic_[name] = data_.size();
data_.emplace_back(std::unique_ptr<T>(x));
return data_.back().get();
......@@ -117,6 +123,20 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
return *var->GetMutable<T>();
}
static void ExecShellCommand(const std::string &cmd, std::string *message) {
char buffer[128];
std::shared_ptr<FILE> pipe(popen(cmd.c_str(), "r"), pclose);
if (!pipe) {
LOG(ERROR) << "error running command: " << cmd;
return;
}
while (!feof(pipe.get())) {
if (fgets(buffer, 128, pipe.get()) != nullptr) {
*message += buffer;
}
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......
......@@ -20,6 +20,17 @@ namespace paddle {
namespace inference {
namespace analysis {
template <>
std::string &NodeAttr::As<std::string>() {
if (data_.empty()) {
type_hash_ = typeid(std::string).hash_code();
}
PADDLE_ENFORCE_EQ(type_hash_, typeid(std::string).hash_code());
return data_;
}
std::string &NodeAttr::String() { return As<std::string>(); }
std::vector<Dot::Attr> Value::dot_attrs() const {
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
Dot::Attr("shape", "box"),
......
......@@ -35,6 +35,44 @@ namespace analysis {
class NodeMap;
// A helper class to maintain the status from Pass.
struct NodeAttr {
// NOTE T should be a primary type or a struct combined by several primary
// types.
// NOTE the STL containers should not use here.
// Some usages
// Attr attr;
// attr.Bool() = true;
bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
std::string &String();
private:
template <typename T>
T &As() {
// init storage in the first usage.
if (data_.empty()) {
VLOG(4) << "resize data to " << sizeof(T);
type_hash_ = typeid(T).hash_code();
data_.resize(sizeof(T));
}
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
"type not matched, origin is %s, want %s",
DataTypeNamer::Global().repr(type_hash_),
DataTypeNamer::Global().repr<T>());
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
return *reinterpret_cast<T *>(&data_[0]);
}
private:
std::string data_;
size_t type_hash_{std::numeric_limits<size_t>::max()};
};
/*
* Node Representation.
*
......@@ -50,8 +88,6 @@ class Node {
Node() = default;
struct Attr;
// Cast to a subclass type, Function for example.
template <typename Subclass>
Subclass &As() {
......@@ -71,7 +107,7 @@ class Node {
// Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists.
Attr &attr(const std::string &name) const { return attrs_[name]; }
NodeAttr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; }
......@@ -80,6 +116,9 @@ class Node {
void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; }
void *pb_desc() const { return attr("pb_desc").Pointer(); }
void SetPbMsg(const std::string &s) { attr("pb_msg").String() = s; }
const std::string &pb_msg() const { return attr("pb_msg").String(); }
void SetDeleted() { deleted_ = true; }
bool deleted() const { return deleted_; }
......@@ -94,43 +133,6 @@ class Node {
// Output links.
std::vector<Node *> outlinks;
// A helper class to maintain the status from Pass.
struct Attr {
// NOTE T should be a primary type or a struct combined by several primary
// types.
// NOTE the STL containers should not use here.
// Some usages
// Attr attr;
// attr.Bool() = true;
bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
private:
template <typename T>
T &As() {
// init storage in the first usage.
if (data_.empty()) {
VLOG(4) << "resize data to " << sizeof(T);
type_hash_ = typeid(T).hash_code();
data_.resize(sizeof(T));
}
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
"type not matched, origin is %s, want %s",
DataTypeNamer::Global().repr(type_hash_),
DataTypeNamer::Global().repr<T>());
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
return *reinterpret_cast<T *>(&data_[0]);
}
private:
std::string data_;
size_t type_hash_{std::numeric_limits<size_t>::max()};
};
// Type checks.
bool IsFunction() const { return type_ == Node::Type::kFunction; }
bool IsValue() const { return type_ == Node::Type::kValue; }
......@@ -150,7 +152,7 @@ class Node {
Type type_{Type::kNone};
// Mark this node is deleted by some pass.
bool deleted_{false};
mutable std::unordered_map<std::string, Attr> attrs_;
mutable std::unordered_map<std::string, NodeAttr> attrs_;
};
class Function;
......@@ -213,6 +215,10 @@ class Function : public Node {
struct FunctionBlock : public Node {
std::string repr() const override { return "block-" + std::to_string(id()); }
std::vector<Node *> subgraph;
protected:
FunctionBlock() { SetType(Node::Type::kFunctionBlock); }
friend class NodeMap;
};
class NodeMap {
......@@ -227,7 +233,7 @@ class NodeMap {
void Delete(size_t id);
const std::vector<std::unique_ptr<Node>> &nodes() { return nodes_; }
const std::vector<std::unique_ptr<Node>> &nodes() const { return nodes_; }
size_t size() const { return nodes_.size(); }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file contains all the flags that declared in Node::Attr.
*
* The Node::Attr is designed to share information between different passes, one
* can get other's attributes in a Node by the flags in this file.
*/
#pragma once
namespace paddle {
namespace inference {
namespace analysis {
#define DECLARE_NODE_ATTR(flag__) const char ATTR_##flag__[] = #flag__;
DECLARE_NODE_ATTR(supported_by_tensorrt) // bool
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -60,6 +60,9 @@ class Pass {
return nullptr;
}
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual Pass *CreateGraphvizDebugerPass() const { return nullptr; }
// Run on a single Node.
virtual void Run(Node *x) { LOG(FATAL) << "not valid"; }
// Run on a single Function.
......
......@@ -19,6 +19,18 @@ namespace paddle {
namespace inference {
namespace analysis {
bool PassManager::Initialize(Argument* argument) {
argument_ = argument;
for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr();
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_);
for (auto& pass : data_) {
......
......@@ -50,17 +50,7 @@ class PassManager : public OrderedRegistry<Pass> {
// globally shared, so pass them as the arguemnts for all the pass managers.
virtual bool Initialize(const Argument& argument) { return false; }
virtual bool Initialize(Argument* argument) {
argument_ = argument;
for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr();
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
virtual bool Initialize(Argument* argument);
// Call all the passes' Finalize methods.
virtual bool Finalize() {
......
......@@ -64,6 +64,7 @@ TEST_F(DFG_Tester, DFG_pass_manager) {
manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
ASSERT_TRUE(&argument);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
......
......@@ -119,10 +119,12 @@ void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); }
void SubGraphFuse::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) {
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph
// as deleted. 3. Replace the deleted node with the new Block Node.
auto *block_node = graph_->nodes.Create(Node::Type::kFunctionBlock);
auto *block_node = static_cast<FunctionBlock *>(
graph_->nodes.Create(Node::Type::kFunctionBlock));
auto io = ExtractInputAndOutputOfSubGraph(subgraph);
block_node->inlinks = std::move(io.first);
block_node->outlinks = std::move(io.second);
......@@ -130,21 +132,25 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
// TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass.
node->SetDeleted();
block_node->subgraph.push_back(node);
}
std::unordered_map<Node *, Node *>
delelte_node_map; // deleted node to BlockNode
for (auto *n : block_node->inlinks) {
n->inlinks.clear();
}
for (auto *n : block_node->outlinks) {
n->outlinks.clear();
}
for (auto *n : block_node->inlinks) {
n->outlinks.push_back(block_node);
// Change all the sub-graph's inputs and outputs corresponding inlink and
// outlink to this sub-graph node.
auto inlink_or_outlink_cleaner = [&](std::vector<Node *> &nodes) {
for (auto *&n : nodes) {
if (subgraph_uniq.count(n)) {
n = block_node;
}
}
std::unordered_set<Node *> uniq(nodes.begin(), nodes.end());
nodes.assign(uniq.begin(), uniq.end());
};
for (auto *i : block_node->inlinks) {
inlink_or_outlink_cleaner(i->outlinks);
}
for (auto *n : block_node->outlinks) {
n->inlinks.push_back(n);
for (auto *&o : block_node->outlinks) {
inlink_or_outlink_cleaner(o->inlinks);
}
}
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/node_attr_flags.h"
namespace paddle {
namespace inference {
namespace analysis {
void TensorRTSubgraphNodeMarkPass::Run(DataFlowGraph *graph) {
for (auto &node : graph->nodes.nodes()) {
node->attr(ATTR_supported_by_tensorrt).Bool() = teller_(node.get());
}
}
class DfgDebuggerPass : public DFG_GraphvizDrawPass {
public:
DfgDebuggerPass(const DFG_GraphvizDrawPass::Config &config)
: DFG_GraphvizDrawPass(config) {}
std::string repr() const override {
return "tensorrt-subgraph-node-mark-debugger";
}
bool Finalize() override { return true; }
protected:
std::string Draw(DataFlowGraph *graph) override {
Dot dot;
// Add nodes
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (config_.display_deleted_node || !node.deleted()) {
auto dot_attr = node.dot_attrs();
if (node.attr(ATTR_supported_by_tensorrt).Bool()) {
dot_attr.assign(
{Dot::Attr{"color", "green"}, Dot::Attr{"style", "filled"}});
}
dot.AddNode(node.repr(), dot_attr);
}
}
// Add edges
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &in : node.inlinks) {
if (!config_.display_deleted_node && in->deleted()) continue;
dot.AddEdge(in->repr(), node.repr(), {});
}
}
return dot.Build();
}
};
Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
DFG_GraphvizDrawPass::Config config(
FLAGS_inference_analysis_graphviz_log_root, "tensorrt_marked_node");
return new DfgDebuggerPass(config);
}
bool TensorRTSubgraphNodeMarkPass::Finalize() { return true; }
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines TensorRTSubgraphNodeMarkPass which helps to mark the ops
* that supported by TensorRT engine.
*/
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Mark the operators that TensorRT engine supports.
*/
class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass {
public:
using teller_t = SubGraphSplitter::NodeInsideSubgraphTeller;
TensorRTSubgraphNodeMarkPass(const teller_t& teller) : teller_(teller) {}
bool Initialize(Argument* argument) override { return true; }
// This class get a sub-graph as input and determine whether to transform this
// sub-graph into TensorRT.
void Run(DataFlowGraph* graph) override;
std::string repr() const { return "tensorrt-sub-subgraph-mark"; }
std::string description() const { return "tensorrt sub-graph mark pass"; }
Pass* CreateGraphvizDebugerPass() const override;
bool Finalize() override;
private:
teller_t teller_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/node_attr_flags.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST_F(DFG_Tester, tensorrt_subgraph_node_mark_pass) {
// init
FluidToDataFlowGraphPass pass;
ASSERT_TRUE(pass.Initialize(&argument));
argument.main_dfg.reset(new DataFlowGraph);
pass.Run(argument.main_dfg.get());
TensorRTSubgraphNodeMarkPass::teller_t teller = [](const Node* node) {
return node->IsFunction() &&
static_cast<const Function*>(node)->func_type() == "mul";
};
TensorRTSubgraphNodeMarkPass pass1(teller);
ASSERT_TRUE(pass1.Initialize(&argument));
pass1.Run(argument.main_dfg.get());
int counter{0};
for (auto& node : argument.main_dfg->nodes.nodes()) {
counter += node->attr(ATTR_supported_by_tensorrt).Bool();
}
LOG(INFO) << counter << " nodes marked";
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -24,7 +24,7 @@ TensorRTSubGraphPass::TensorRTSubGraphPass(
: node_inside_subgraph_teller_(teller) {}
void TensorRTSubGraphPass::Run(DataFlowGraph *graph) {
SubGraphFuse(graph, node_inside_subgraph_teller_);
SubGraphFuse(graph, node_inside_subgraph_teller_)();
}
} // namespace analysis
......
......@@ -38,6 +38,11 @@ class TensorRTSubGraphPass : public DataFlowGraphPass {
// sub-graph into TensorRT.
void Run(DataFlowGraph* graph) override;
bool Finalize() override { return true; }
std::string repr() const { return "tensorrt-sub-graph"; }
std::string description() const { return "tensorrt sub graph pass"; }
private:
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
};
......
......@@ -23,49 +23,48 @@ namespace paddle {
namespace inference {
namespace analysis {
DEFINE_string(model_dir, "", "inference test model dir");
DEFINE_string(dot_dir, "./", "");
TEST(TensorRTSubGraph, single_pass) {
auto desc = LoadProgramDesc();
auto dfg = ProgramDescToDFG(desc);
SubGraphSplitter::NodeInsideSubgraphTeller teller = [](const Node* node) {
TEST_F(DFG_Tester, tensorrt_single_pass) {
std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "sigmoid"});
SubGraphSplitter::NodeInsideSubgraphTeller teller = [&](const Node* node) {
if (node->type() != Node::Type::kFunction) return false;
const auto* func = static_cast<const Function*>(node);
if (func->func_type() == "elementwise_add" || func->func_type() == "relu" ||
func->func_type() == "conv2d" || func->func_type() == "mul" ||
func->func_type() == "sigmoid" || func->func_type() == "softmax") {
LOG(INFO) << "sub-graph marked " << node->repr();
return true;
}
if (teller_set.count(func->func_type())) return true;
return false;
};
DFG_GraphvizDrawPass::Config config{"./", "test"};
DFG_GraphvizDrawPass dfg_pass(config);
dfg_pass.Initialize();
DFG_GraphvizDrawPass dfg_pass1(config);
dfg_pass1.Initialize();
dfg_pass.Run(&dfg);
LOG(INFO) << "init";
DFG_GraphvizDrawPass::Config config{FLAGS_dot_dir, "origin"};
DFG_GraphvizDrawPass::Config config1{FLAGS_dot_dir, "fusion"};
DFG_GraphvizDrawPass dfg_pass(config);
DFG_GraphvizDrawPass dfg_pass1(config1);
FluidToDataFlowGraphPass pass0;
TensorRTSubGraphPass trt_pass(std::move(teller));
trt_pass.Initialize();
trt_pass.Run(&dfg);
LOG(INFO) << "Initialize";
dfg_pass.Initialize(&argument);
dfg_pass1.Initialize(&argument);
pass0.Initialize(&argument);
trt_pass.Initialize(&argument);
dfg_pass1.Run(&dfg);
LOG(INFO) << "Run";
argument.main_dfg.reset(new DataFlowGraph);
pass0.Run(argument.main_dfg.get());
dfg_pass.Run(argument.main_dfg.get());
trt_pass.Run(argument.main_dfg.get());
dfg_pass1.Run(argument.main_dfg.get());
// Check the TRT op's block desc
for (auto node : dfg.nodes.nodes()) {
for (auto& node : argument.main_dfg->nodes.nodes()) {
if (node->IsFunctionBlock()) {
LOG(INFO) << "get function block";
}
}
}
TEST(TensorRTSubGraph, pass_manager) {}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -226,7 +226,8 @@ op_library(sequence_softmax_op DEPS softmax)
if (WITH_GPU AND TENSORRT_FOUND)
op_library(tensorrt_engine_op DEPS tensorrt_engine)
nv_test(test_tensorrt_engine_op SRCS tensorrt_engine_op_test.cc
DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter)
DEPS tensorrt_engine_op tensorrt_engine tensorrt_converter
analysis)
else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif()
......
......@@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
"Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
}
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension");
......
......@@ -282,6 +282,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
if (grad.rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
// merge duplicated rows if any.
scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge =
......
......@@ -19,28 +19,28 @@ namespace operators {
template <>
void GetAccumulators<paddle::platform::CPUDeviceContext>(
const framework::ExecutionContext& ctx, int64_t* num_updates_,
int64_t* num_accumulates_, int64_t* old_num_accumulates_) {
const framework::ExecutionContext& ctx, int64_t* num_updates,
int64_t* num_accumulates, int64_t* old_num_accumulates) {
auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
*old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
*num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
*num_updates_ = in_num_updates->data<int64_t>()[0];
*old_num_accumulates = in_old_num_accumulates->data<int64_t>()[0];
*num_accumulates = in_num_accumulates->data<int64_t>()[0];
*num_updates = in_num_updates->data<int64_t>()[0];
}
template <>
void SetAccumulators<paddle::platform::CPUDeviceContext>(
const framework::ExecutionContext& ctx, int64_t num_updates_,
int64_t num_accumulates_, int64_t old_num_accumulates_) {
const framework::ExecutionContext& ctx, int64_t num_updates,
int64_t num_accumulates, int64_t old_num_accumulates) {
auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
out_num_updates->data<int64_t>()[0] = num_updates_;
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates;
out_num_accumulates->data<int64_t>()[0] = num_accumulates;
out_num_updates->data<int64_t>()[0] = num_updates;
}
class AverageAccumulatesOp : public framework::OperatorWithKernel {
......@@ -177,7 +177,7 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
AverageAccumulates Operator.
Accumulate the sum of parameter whtin sliding window. The size of sliding window is
Accumulate the sum of parameter within sliding window. The size of sliding window is
determined by 'average_window', 'max_average_window' and 'min_average_window'.
Memory was shared by Input(in_sum_1) and Output(out_sum_1) which acts as an accumulator 'sum_1'.
'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' were the same as 'sum_1'.
......
......@@ -54,8 +54,9 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
float average_window = ctx.Attr<float>("average_window");
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
min_average_window =
std::min<int64_t>(min_average_window, max_average_window);
PADDLE_ENFORCE_LE(min_average_window, max_average_window,
"min_average_window shouldn't be larger than "
"max_average_window");
// Get inputs
auto* param = ctx.Input<Tensor>("param");
......
......@@ -26,8 +26,12 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
"Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FillZerosLikeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
if (ctx->IsRuntime() &&
ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR_ARRAY) {
return; // skip runtime infershape when is tensor array;
}
}
};
......@@ -39,7 +43,7 @@ class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
FillZerosLike Operator.
Fill up a variable with zeros.
Fill up a variable with zeros, supporting both LoDTensor and LoDTensorArray.
The output will have the same size as the input.
)DOC");
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -23,12 +24,29 @@ template <typename DeviceContext, typename T>
class FillZerosLikeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), out,
static_cast<T>(0));
auto var = context.InputVar("X");
if (var->IsType<framework::LoDTensor>()) {
auto& input = *context.Input<framework::LoDTensor>("X");
auto& output = *context.Output<framework::LoDTensor>("Out");
output.Resize(input.dims());
output.set_lod(input.lod());
output.mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), &(output),
static_cast<T>(0));
} else if (var->IsType<framework::LoDTensorArray>()) {
auto& input = *context.Input<framework::LoDTensorArray>("X");
auto& output = *context.Output<framework::LoDTensorArray>("Out");
output.resize(input.size());
for (auto i = 0; i < input.size(); i++) {
output[i].Resize(input[i].dims());
output[i].set_lod(input[i].lod());
output[i].mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> setter;
setter(context.template device_context<DeviceContext>(), &(output[i]),
static_cast<T>(0));
}
}
}
};
......
......@@ -53,6 +53,7 @@ template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
VLOG(4) << "TensorRTEngineKernel executing";
auto engine_name = context.Attr<std::string>("engine_uniq_key");
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
Prepare(context);
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
......@@ -51,48 +52,10 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
*var = *desc.Proto();
}
template <typename T>
void SetAttr(framework::proto::OpDesc* op, const std::string& name,
const T& data);
template <>
void SetAttr<std::string>(framework::proto::OpDesc* op, const std::string& name,
const std::string& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(data);
}
template <>
void SetAttr<int>(framework::proto::OpDesc* op, const std::string& name,
const int& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(data);
}
template <>
void SetAttr<int64_t>(framework::proto::OpDesc* op, const std::string& name,
const int64_t& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::LONG);
attr->set_l(data);
}
template <>
void SetAttr<std::vector<std::string>>(framework::proto::OpDesc* op,
const std::string& name,
const std::vector<std::string>& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
for (const auto& s : data) {
attr->add_strings(s.c_str());
}
}
} // namespace
using inference::analysis::SetAttr;
TEST(TensorRTEngineOp, manual) {
framework::ProgramDesc program;
auto* block_ = program.Proto()->add_blocks();
......
......@@ -106,6 +106,8 @@ function cmake_gen() {
-DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF}
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
-DWITH_CONTRIB=${WITH_CONTRIB:-ON}
-DWITH_ANAKIN=${WITH_ANAKIN:-ON}
-DWITH_INFERENCE_DEMO=${WITH_INFERENCE_DEMO:-ON}
========================================
EOF
# Disable UNITTEST_USE_VIRTUALENV in docker because
......@@ -133,7 +135,8 @@ EOF
-DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF} \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DWITH_CONTRIB=${WITH_CONTRIB:-ON} \
-DWITH_ANAKIN=${WITH_ANAKIN:-ON}
-DWITH_ANAKIN=${WITH_ANAKIN:-ON} \
-DWITH_INFERENCE_DEMO=${WITH_INFERENCE_DEMO:-ON}
}
function abort(){
......
......@@ -111,7 +111,7 @@ def fetch():
paddle.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
paddle.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)
def convert(path):
......
......@@ -27,6 +27,7 @@ __all__ = [
'Variable',
'Program',
'Operator',
'Parameter',
'default_startup_program',
'default_main_program',
'program_guard',
......@@ -1922,7 +1923,7 @@ def program_guard(main_program, startup_program=None):
def get_var(name, program=None):
"""
Get a variable by name from the global block of a program.
Args:
name(str): name of the variable
program(Program|None): program object.
......
......@@ -95,6 +95,7 @@ __all__ = [
'relu',
'log',
'crop',
'fill_zeros_like',
]
......@@ -5077,12 +5078,12 @@ def mean_iou(input, label, num_classes):
out_correct = helper.create_tmp_variable(dtype='int32')
helper.append_op(
type="mean_iou",
inputs={"predictions": input,
"labels": label},
inputs={"Predictions": input,
"Labels": label},
outputs={
"out_mean_iou": out_mean_iou,
"out_wrong": out_wrong,
"out_correct": out_correct
"OutMeanIou": out_mean_iou,
"OutWrong": out_wrong,
"OutCorrect": out_correct
},
attrs={"num_classes": num_classes})
return out_mean_iou, out_wrong, out_correct
......@@ -5184,3 +5185,40 @@ def crop(x, shape=None, offsets=None, name=None):
outputs={'Out': out},
attrs=None if len(attrs) == 0 else attrs)
return out
def fill_zeros_like(x):
"""
This layer takes an input and outputs a variable that has the same structure as
the input and with all the element values as zero. The variable can be a Tensor
or TensorArray.
.. code-block:: text
Given
X = [[0, 1, 2, 0],
[0, 3, 4, 0],
[0, 0, 0, 0]],
output is:
Out = [[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]].
Args:
x (Variable): The input variable, which could be a tensor or tensor array
Returns:
Variable: The zero-filled variable, which has the same type and shape as
the input variable.
Examples:
.. code-block:: python
y = fluid.layers.fill_zeros_like(x)
"""
helper = LayerHelper('fill_zeros_like', **locals())
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='fill_zeros_like', inputs={'X': [x]}, outputs={'Out': [out]})
return out
......@@ -18,15 +18,16 @@ import numpy as np
__all__ = ['create_lod_tensor', 'create_random_int_lodtensor']
def create_lod_tensor(data, lod, place):
def create_lod_tensor(data, recursive_seq_lens, place):
"""
Create a lod tensor from a numpy array, a list, or an existing lod tensor.
Create a lod tensor by doing the following:
1. Check that the length-based input lod is valid.
1. Check that the length-based level of detail (LoD) also known as
recursive_sequence_lengths of the input is valid.
2. Convert the length-based lod to a offset-based LoD.
2. Convert recursive_sequence_lengths to a offset-based LoD.
3. Copy the data from a numpy array, a list or a existing lod tensor to
CPU or GPU device (based on input place).
......@@ -37,45 +38,47 @@ def create_lod_tensor(data, lod, place):
Suppose we want LoDTensor to hold data for sequences of word, where each
word is represented by an integer. If we want to create a LoDTensor to
represent two sentences, one of 2 words, and one of 3 words.
represent two sentences, one of 2 words, and one of 3 words.
Then :code:`data` can be a numpy array of integers with shape (5, 1).
:code:`lod` will be [[2, 3]], indicating the length(# of words) in each
sentence. This length-based input lod [[2, 3]] will be converted to
offset-based lod [[0, 2, 5]] inside the function call.
:code:`recursive_seq_lens` will be [[2, 3]], indicating the length(# of words) in each
sentence. This length-based :code:`recursive_seq_lens` [[2, 3]] will be converted to
offset-based LoD [[0, 2, 5]] inside the function call.
Please reference :ref:`api_guide_low_level_lod_tensor` for more details
regarding LoD.
Args:
data(numpy.ndarray|list|LoDTensor): a numpy array or a LoDTensor or a
list holding the data to be copied.
lod(list): a list of lists indicating the length-based LoD info
specified by the user.
list holding the data to be copied.
recursive_seq_lens(list): a list of lists indicating the length-based level of detail
info specified by the user.
place(Place): CPU or GPU place indicating where the data in the new
LoDTensor will be stored.
Returns:
A fluid LoDTensor object with tensor data and lod info.
A fluid LoDTensor object with tensor data and recursive_seq_lens info.
"""
if isinstance(data, core.LoDTensor):
return create_lod_tensor(np.array(data), lod, place)
return create_lod_tensor(np.array(data), recursive_seq_lens, place)
elif isinstance(data, list):
# When input data is a list, it only deal with the case where the base element
# is an index of shape [1] and dtype int64 (e.g., word id). Hence, the generated
# LoDTensor will be of shape [n, 1] and dtype int64, where `n` is the total number
# of words or other indexes in the sequence.
new_lod = []
new_recursive_seq_lens = []
for seq in data:
new_lod.append(len(seq))
assert [new_lod] == lod, "data and lod do not match"
new_recursive_seq_lens.append(len(seq))
assert [
new_recursive_seq_lens
] == recursive_seq_lens, "data and recursive_seq_lens do not match"
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
return create_lod_tensor(flattened_data, lod, place)
return create_lod_tensor(flattened_data, recursive_seq_lens, place)
elif isinstance(data, np.ndarray):
tensor = core.LoDTensor()
tensor.set(data, place)
tensor.set_recursive_sequence_lengths(lod)
tensor.set_recursive_sequence_lengths(recursive_seq_lens)
assert tensor.has_valid_recursive_sequence_lengths(
), "the provided lod info is invalid"
return tensor
......@@ -84,7 +87,8 @@ def create_lod_tensor(data, lod, place):
"data should be either a LoDTensor, a Numpy array or a list")
def create_random_int_lodtensor(lod, base_shape, place, low, high):
def create_random_int_lodtensor(recursive_seq_lens, base_shape, place, low,
high):
"""
Create a LoDTensor containing random integers.
......@@ -95,7 +99,7 @@ def create_random_int_lodtensor(lod, base_shape, place, low, high):
The function does the following:
1. Calculate the overall shape of the LoDTensor based on the length-based
:code:`lod` input and the shape of the basic element in
:code:`recursive_seq_lens` input and the shape of the basic element in
:code:`base_shape`.
2. Create a numpy array of this shape.
......@@ -105,12 +109,13 @@ def create_random_int_lodtensor(lod, base_shape, place, low, high):
Suppose we want LoDTensor to hold data for sequences of word, where each
word is represented by an integer. If we want to create a LoDTensor to
represent two sentences, one of 2 words, and one of 3 words. Then
'base_shape' is [1], input length-based 'lod' is [[2, 3]]. Then the overall
shape of the LoDTensor would be [5, 1], holding 5 words for two sentences.
'base_shape' is [1], input length-based 'recursive_seq_lens' is [[2, 3]].
Then the overall shape of the LoDTensor would be [5, 1], holding 5 words
for two sentences.
Args:
lod(list): a list of lists indicating the length-based LoD info
specified by the user.
recursive_seq_lens(list): a list of lists indicating the length-based
level of detail info specified by the user.
base_shape(list): the shape of the basic element to be held by the
LoDTensor.
place(Place): CPU or GPU place indicating where the data in the new
......@@ -119,11 +124,11 @@ def create_random_int_lodtensor(lod, base_shape, place, low, high):
high(int): the upper bound of the random integers.
Returns:
A fluid LoDTensor object with tensor data and lod info.
A fluid LoDTensor object with tensor data and recursive_seq_lens info.
"""
assert isinstance(base_shape, list), "base_shape should be a list"
# append the total number of basic elements to the front of its shape
overall_shape = [sum(lod[-1])] + base_shape
overall_shape = [sum(recursive_seq_lens[-1])] + base_shape
# the range of integer data elements is [low, high]
data = np.random.random_integers(low, high, overall_shape).astype("int64")
return create_lod_tensor(data, lod, place)
return create_lod_tensor(data, recursive_seq_lens, place)
......@@ -1113,7 +1113,6 @@ class ModelAverage(Optimizer):
Args:
average_window_rate: The rate of average window.
params_grads: A list of parameter-grad variable pairs.
min_average_window: The minimum size of average window.
max_average_window: The maximum size of average window.
......@@ -1122,8 +1121,8 @@ class ModelAverage(Optimizer):
.. code-block:: python
optimizer = fluid.optimizer.Momentum()
_, params_grads = optimizer.minimize(cost)
model_average = fluid.optimizer.ModelAverage(params_grads, 0.15,
optimizer.minimize(cost)
model_average = fluid.optimizer.ModelAverage(0.15,
min_average_window=10000,
max_average_window=20000)
for pass_id in range(args.pass_num):
......@@ -1137,7 +1136,6 @@ class ModelAverage(Optimizer):
def __init__(self,
average_window_rate,
params_grads=None,
min_average_window=10000,
max_average_window=10000,
**kwargs):
......@@ -1146,21 +1144,16 @@ class ModelAverage(Optimizer):
self.min_average_window = min_average_window
self.max_average_window = max_average_window
self.params_grads = [] if params_grads is None else params_grads
params = {}
for param, grad in self.params_grads:
if param.do_model_average != False:
params[param.name] = (param, grad)
self.params_grads = []
for param in framework.default_main_program().global_block(
).all_parameters():
if param.name not in params and param.do_model_average != False:
if param.do_model_average != False:
grad = param.block.create_var(
name=unique_name.generate(".".join([param.name, 'tmp'])),
dtype=param.dtype,
persistable=False,
stop_gradient=True)
params[param.name] = (param, grad)
self.params_grads = params.values()
self.params_grads.append((param, grad))
for param, grad in self.params_grads:
self._append_average_accumulate_op(param)
......
......@@ -206,35 +206,35 @@ def infer(use_cuda, inference_program, params_dirname):
inferencer = fluid.Inferencer(
inference_program, param_path=params_dirname, place=place)
# Setup inputs by creating LoDTensors to represent sequences of words.
# Here each word is the basic element of these LoDTensors and the shape of
# Setup input by creating LoDTensor to represent sequence of words.
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[3, 4, 2]],
# which has only one lod level. Then the created LoDTensors will have only
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that lod info should be a list of lists.
lod = [[3, 4, 2]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_n2 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_n1 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_0 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_p1 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
ctx_p2 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=WORD_DICT_LEN - 1)
pred = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=PRED_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=PRED_DICT_LEN - 1)
mark = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=MARK_DICT_LEN - 1)
recursive_seq_lens, base_shape, place, low=0, high=MARK_DICT_LEN - 1)
results = inferencer.infer(
{
......
......@@ -229,11 +229,13 @@ def decode_main(use_cuda, is_sparse):
[1. for _ in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_lod = [1] * batch_size
init_lod = [init_lod, init_lod]
init_recursive_seq_lens = [1] * batch_size
init_recursive_seq_lens = [init_recursive_seq_lens, init_recursive_seq_lens]
init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)
init_ids = fluid.create_lod_tensor(init_ids_data, init_recursive_seq_lens,
place)
init_scores = fluid.create_lod_tensor(init_scores_data,
init_recursive_seq_lens, place)
train_data = paddle.batch(
paddle.reader.shuffle(
......@@ -257,7 +259,7 @@ def decode_main(use_cuda, is_sparse):
feed=feed_dict,
fetch_list=[translation_ids, translation_scores],
return_numpy=False)
print result_ids.lod()
print result_ids.recursive_sequence_lengths()
break
......
......@@ -209,13 +209,15 @@ def infer(use_cuda, inference_program, params_dirname):
inference_program, param_path=params_dirname, place=place)
# Use the first data from paddle.dataset.movielens.test() as input.
# Use create_lod_tensor(data, lod, place) API to generate LoD Tensor,
# where `data` is a list of sequences of index numbers, `lod` is
# the level of detail (lod) info associated with `data`.
# Use create_lod_tensor(data, recursive_sequence_lengths, place) API
# to generate LoD Tensor where `data` is a list of sequences of index
# numbers, `recursive_sequence_lengths` is the length-based level of detail
# (lod) info associated with `data`.
# For example, data = [[10, 2, 3], [2, 3]] means that it contains
# two sequences of indexes, of length 3 and 2, respectively.
# Correspondingly, lod = [[3, 2]] contains one level of detail info,
# indicating that `data` consists of two sequences of length 3 and 2.
# Correspondingly, recursive_sequence_lengths = [[3, 2]] contains one
# level of detail info, indicating that `data` consists of two sequences
# of length 3 and 2, respectively.
user_id = fluid.create_lod_tensor([[1]], [[1]], place)
gender_id = fluid.create_lod_tensor([[1]], [[1]], place)
age_id = fluid.create_lod_tensor([[0]], [[1]], place)
......
......@@ -128,17 +128,17 @@ def infer(use_cuda, inference_program, params_dirname=None):
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[3, 4, 2]],
# which has only one lod level. Then the created LoDTensor will have only
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that lod info should be a list of lists.
lod = [[3, 4, 2]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
tensor_words = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=len(word_dict) - 1)
recursive_seq_lens, base_shape, place, low=0, high=len(word_dict) - 1)
results = inferencer.infer({'words': tensor_words})
print("infer results: ", results)
......
......@@ -143,17 +143,17 @@ def infer(use_cuda, inference_program, params_dirname=None):
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[3, 4, 2]],
# which has only one lod level. Then the created LoDTensor will have only
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that lod info should be a list of lists.
lod = [[3, 4, 2]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
tensor_words = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=len(word_dict) - 1)
recursive_seq_lens, base_shape, place, low=0, high=len(word_dict) - 1)
results = inferencer.infer({'words': tensor_words})
print("infer results: ", results)
......
......@@ -138,17 +138,17 @@ def infer(use_cuda, inference_program, params_dirname=None):
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[3, 4, 2]],
# which has only one lod level. Then the created LoDTensor will have only
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that lod info should be a list of lists.
lod = [[3, 4, 2]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
tensor_words = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=len(word_dict) - 1)
recursive_seq_lens, base_shape, place, low=0, high=len(word_dict) - 1)
results = inferencer.infer({'words': tensor_words})
print("infer results: ", results)
......
......@@ -124,21 +124,22 @@ def infer(use_cuda, inference_program, params_dirname=None):
# Setup inputs by creating 4 LoDTensors representing 4 words. Here each word
# is simply an index to look up for the corresponding word vector and hence
# the shape of word (base_shape) should be [1]. The length-based level of
# detail (lod) info of each LoDtensor should be [[1]] meaning there is only
# one lod_level and there is only one sequence of one word on this level.
# Note that lod info should be a list of lists.
lod = [[1]]
# the shape of word (base_shape) should be [1]. The recursive_sequence_lengths,
# which is length-based level of detail (lod) of each LoDTensor, should be [[1]]
# meaning there is only one level of detail and there is only one sequence of
# one word on this level.
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[1]]
base_shape = [1]
# The range of random integers is [low, high]
first_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
second_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
third_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
fourth_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
result = inferencer.infer(
{
......
......@@ -238,17 +238,21 @@ def infer(word_dict, use_cuda, save_dirname=None):
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[3, 4, 2]],
# which has only one lod level. Then the created LoDTensor will have only
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that lod info should be a list of lists.
lod = [[3, 4, 2]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
tensor_words = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
......@@ -257,7 +261,7 @@ def infer(word_dict, use_cuda, save_dirname=None):
feed={feed_target_names[0]: tensor_words},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].lod())
print(results[0].recursive_sequence_lengths())
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
print("Inference results: ", np_data)
......
......@@ -247,35 +247,67 @@ def infer(use_cuda, save_dirname=None):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
# Setup inputs by creating LoDTensors to represent sequences of words.
# Here each word is the basic element of these LoDTensors and the shape of
# Setup input by creating LoDTensor to represent sequence of words.
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[3, 4, 2]],
# which has only one lod level. Then the created LoDTensors will have only
# Suppose the recursive_sequence_lengths info is set to [[3, 4, 2]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for three sentences of
# length 3, 4 and 2, respectively.
# Note that lod info should be a list of lists.
lod = [[3, 4, 2]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[3, 4, 2]]
base_shape = [1]
# The range of random integers is [low, high]
word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
pred = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=pred_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=pred_dict_len - 1)
ctx_n2 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
ctx_n1 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
ctx_0 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
ctx_p1 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
ctx_p2 = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=word_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=word_dict_len - 1)
mark = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=mark_dict_len - 1)
recursive_seq_lens,
base_shape,
place,
low=0,
high=mark_dict_len - 1)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
......@@ -301,7 +333,7 @@ def infer(use_cuda, save_dirname=None):
},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].lod())
print(results[0].recursive_sequence_lengths())
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
......
......@@ -108,7 +108,7 @@ def decoder_decode(context, is_sparse):
pre_state = pd.array_read(array=state_array, i=counter)
pre_score = pd.array_read(array=scores_array, i=counter)
# expand the lod of pre_state to be the same with pre_score
# expand the recursive_sequence_lengths of pre_state to be the same with pre_score
pre_state_expanded = pd.sequence_expand(pre_state, pre_score)
pre_ids_emb = pd.embedding(
......@@ -252,11 +252,13 @@ def decode_main(use_cuda, is_sparse):
[1. for _ in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_lod = [1] * batch_size
init_lod = [init_lod, init_lod]
init_recursive_seq_lens = [1] * batch_size
init_recursive_seq_lens = [init_recursive_seq_lens, init_recursive_seq_lens]
init_ids = fluid.create_lod_tensor(init_ids_data, init_lod, place)
init_scores = fluid.create_lod_tensor(init_scores_data, init_lod, place)
init_ids = fluid.create_lod_tensor(init_ids_data, init_recursive_seq_lens,
place)
init_scores = fluid.create_lod_tensor(init_scores_data,
init_recursive_seq_lens, place)
train_data = paddle.batch(
paddle.reader.shuffle(
......@@ -280,7 +282,7 @@ def decode_main(use_cuda, is_sparse):
feed=feed_dict,
fetch_list=[translation_ids, translation_scores],
return_numpy=False)
print result_ids.lod()
print result_ids.recursive_sequence_lengths()
break
......
......@@ -260,13 +260,15 @@ def infer(use_cuda, save_dirname=None):
# Use the first data from paddle.dataset.movielens.test() as input
assert feed_target_names[0] == "user_id"
# Use create_lod_tensor(data, lod, place) API to generate LoD Tensor
# where `data` is a list of sequences of index numbers, `lod` is
# the level of detail (lod) info associated with `data`.
# Use create_lod_tensor(data, recursive_sequence_lengths, place) API
# to generate LoD Tensor where `data` is a list of sequences of index
# numbers, `recursive_sequence_lengths` is the length-based level of detail
# (lod) info associated with `data`.
# For example, data = [[10, 2, 3], [2, 3]] means that it contains
# two sequences of indexes, of length 3 and 2, respectively.
# Correspondingly, lod = [[3, 2]] contains one level of detail info,
# indicating that `data` consists of two sequences of length 3 and 2.
# Correspondingly, recursive_sequence_lengths = [[3, 2]] contains one
# level of detail info, indicating that `data` consists of two sequences
# of length 3 and 2, respectively.
user_id = fluid.create_lod_tensor([[1]], [[1]], place)
assert feed_target_names[1] == "gender_id"
......
......@@ -216,19 +216,19 @@ def infer(use_cuda, save_dirname=None):
# Here each word is the basic element of the LoDTensor and the shape of
# each word (base_shape) should be [1] since it is simply an index to
# look up for the corresponding word vector.
# Suppose the length_based level of detail (lod) info is set to [[4, 6]],
# which has only one lod level. Then the created LoDTensor will have only
# Suppose the recursive_sequence_lengths info is set to [[4, 6]],
# which has only one level of detail. Then the created LoDTensor will have only
# one higher level structure (sequence of words, or sentence) than the basic
# element (word). Hence the LoDTensor will hold data for two sentences of
# length 4 and 6, respectively.
# Note that lod info should be a list of lists.
lod = [[4, 6]]
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[4, 6]]
base_shape = [1]
# The range of random integers is [low, high]
word_data = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=1)
recursive_seq_lens, base_shape, place, low=0, high=1)
trg_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=1)
recursive_seq_lens, base_shape, place, low=0, high=1)
# Construct feed as a dictionary of {feed_target_name: feed_target_data}
# and results will contain a list of data corresponding to fetch_targets.
......@@ -241,7 +241,7 @@ def infer(use_cuda, save_dirname=None):
},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].lod())
print(results[0].recursive_sequence_lengths())
np_data = np.array(results[0])
print("Inference shape: ", np_data.shape)
print("Inference results: ", np_data)
......
......@@ -168,21 +168,22 @@ def infer(use_cuda, save_dirname=None):
# Setup inputs by creating 4 LoDTensors representing 4 words. Here each word
# is simply an index to look up for the corresponding word vector and hence
# the shape of word (base_shape) should be [1]. The length-based level of
# detail (lod) info of each LoDtensor should be [[1]] meaning there is only
# one lod_level and there is only one sequence of one word on this level.
# Note that lod info should be a list of lists.
lod = [[1]]
# the shape of word (base_shape) should be [1]. The recursive_sequence_lengths,
# which is length-based level of detail (lod) of each LoDTensor, should be [[1]]
# meaning there is only one level of detail and there is only one sequence of
# one word on this level.
# Note that recursive_sequence_lengths should be a list of lists.
recursive_seq_lens = [[1]]
base_shape = [1]
# The range of random integers is [low, high]
first_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
second_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
third_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
fourth_word = fluid.create_random_int_lodtensor(
lod, base_shape, place, low=0, high=dict_size - 1)
recursive_seq_lens, base_shape, place, low=0, high=dict_size - 1)
assert feed_target_names[0] == 'firstw'
assert feed_target_names[1] == 'secondw'
......@@ -200,7 +201,7 @@ def infer(use_cuda, save_dirname=None):
},
fetch_list=fetch_targets,
return_numpy=False)
print(results[0].lod())
print(results[0].recursive_sequence_lengths())
np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape)
......
......@@ -19,18 +19,21 @@ import unittest
class TestLoDTensor(unittest.TestCase):
def test_pybind_lod(self):
def test_pybind_recursive_seq_lens(self):
tensor = fluid.LoDTensor()
lod = []
tensor.set_recursive_sequence_lengths(lod)
lod = [[], [1], [3]]
self.assertRaises(Exception, tensor.set_recursive_sequence_lengths, lod)
lod = [[0], [2], [3]]
self.assertRaises(Exception, tensor.set_recursive_sequence_lengths, lod)
recursive_seq_lens = []
tensor.set_recursive_sequence_lengths(recursive_seq_lens)
recursive_seq_lens = [[], [1], [3]]
self.assertRaises(Exception, tensor.set_recursive_sequence_lengths,
recursive_seq_lens)
recursive_seq_lens = [[0], [2], [3]]
self.assertRaises(Exception, tensor.set_recursive_sequence_lengths,
recursive_seq_lens)
lod = [[1, 2, 3]]
tensor.set_recursive_sequence_lengths(lod)
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
recursive_seq_lens = [[1, 2, 3]]
tensor.set_recursive_sequence_lengths(recursive_seq_lens)
self.assertEqual(tensor.recursive_sequence_lengths(),
recursive_seq_lens)
tensor.set(np.random.random([6, 1]), fluid.CPUPlace())
self.assertTrue(tensor.has_valid_recursive_sequence_lengths())
tensor.set(np.random.random([9, 1]), fluid.CPUPlace())
......@@ -38,13 +41,14 @@ class TestLoDTensor(unittest.TestCase):
# Each level's sum should be equal to the number of items in the next level
# Moreover, last level's sum should be equal to the tensor height
lod = [[2, 3], [1, 3, 1, 2, 2]]
tensor.set_recursive_sequence_lengths(lod)
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
recursive_seq_lens = [[2, 3], [1, 3, 1, 2, 2]]
tensor.set_recursive_sequence_lengths(recursive_seq_lens)
self.assertEqual(tensor.recursive_sequence_lengths(),
recursive_seq_lens)
tensor.set(np.random.random([8, 1]), fluid.CPUPlace())
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())
lod = [[2, 3], [1, 3, 1, 2, 1]]
tensor.set_recursive_sequence_lengths(lod)
recursive_seq_lens = [[2, 3], [1, 3, 1, 2, 1]]
tensor.set_recursive_sequence_lengths(recursive_seq_lens)
self.assertTrue(tensor.has_valid_recursive_sequence_lengths())
tensor.set(np.random.random([9, 1]), fluid.CPUPlace())
self.assertFalse(tensor.has_valid_recursive_sequence_lengths())
......@@ -52,35 +56,42 @@ class TestLoDTensor(unittest.TestCase):
def test_create_lod_tensor(self):
# Create LoDTensor from a list
data = [[1, 2, 3], [3, 4]]
wrong_lod = [[2, 2]]
correct_lod = [[3, 2]]
self.assertRaises(AssertionError, create_lod_tensor, data, wrong_lod,
fluid.CPUPlace())
tensor = create_lod_tensor(data, correct_lod, fluid.CPUPlace())
self.assertEqual(tensor.recursive_sequence_lengths(), correct_lod)
wrong_recursive_seq_lens = [[2, 2]]
correct_recursive_seq_lens = [[3, 2]]
self.assertRaises(AssertionError, create_lod_tensor, data,
wrong_recursive_seq_lens, fluid.CPUPlace())
tensor = create_lod_tensor(data, correct_recursive_seq_lens,
fluid.CPUPlace())
self.assertEqual(tensor.recursive_sequence_lengths(),
correct_recursive_seq_lens)
# Create LoDTensor from numpy array
data = np.random.random([10, 1])
lod = [[2, 1], [3, 3, 4]]
tensor = create_lod_tensor(data, lod, fluid.CPUPlace())
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
recursive_seq_lens = [[2, 1], [3, 3, 4]]
tensor = create_lod_tensor(data, recursive_seq_lens, fluid.CPUPlace())
self.assertEqual(tensor.recursive_sequence_lengths(),
recursive_seq_lens)
# Create LoDTensor from another LoDTensor, they are differnt instances
new_lod = [[2, 2, 1], [1, 2, 2, 3, 2]]
new_tensor = create_lod_tensor(tensor, new_lod, fluid.CPUPlace())
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
self.assertEqual(new_tensor.recursive_sequence_lengths(), new_lod)
new_recursive_seq_lens = [[2, 2, 1], [1, 2, 2, 3, 2]]
new_tensor = create_lod_tensor(tensor, new_recursive_seq_lens,
fluid.CPUPlace())
self.assertEqual(tensor.recursive_sequence_lengths(),
recursive_seq_lens)
self.assertEqual(new_tensor.recursive_sequence_lengths(),
new_recursive_seq_lens)
def test_create_random_int_lodtensor(self):
# The shape of a word, commonly used in speech and NLP problem, is [1]
shape = [1]
lod = [[2, 3, 5]]
recursive_seq_lens = [[2, 3, 5]]
dict_size = 10000
low = 0
high = dict_size - 1
tensor = create_random_int_lodtensor(lod, shape,
tensor = create_random_int_lodtensor(recursive_seq_lens, shape,
fluid.CPUPlace(), low, high)
self.assertEqual(tensor.recursive_sequence_lengths(), lod)
self.assertEqual(tensor.recursive_sequence_lengths(),
recursive_seq_lens)
self.assertEqual(tensor.shape(), [10, 1])
......
......@@ -51,3 +51,4 @@ py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SERIAL)
py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL)
set_tests_properties(test_listen_and_serv_op PROPERTIES TIMEOUT 20)
set_tests_properties(test_dist_mnist PROPERTIES TIMEOUT 180)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import time
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
SEED = 1
DTYPE = "float32"
paddle.dataset.mnist.fetch()
# random seed must set before configuring the network.
# fluid.default_startup_program().random_seed = SEED
def cnn_model(data):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=data,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
# TODO(dzhwinter) : refine the initializer and random seed settting
SIZE = 10
input_shape = conv_pool_2.shape
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
predict = fluid.layers.fc(
input=conv_pool_2,
size=SIZE,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)))
return predict
def get_model(batch_size):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
# Evaluator
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(
input=predict, label=label, total=batch_size_tensor)
inference_program = fluid.default_main_program().clone()
# Optimization
opt = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, beta1=0.9, beta2=0.999)
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
opt.minimize(avg_cost)
return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
t = fluid.DistributeTranspiler()
t.transpile(
trainer_id=trainer_id,
program=main_program,
pservers=pserver_endpoints,
trainers=trainers)
return t
def run_pserver(pserver_endpoints, trainers, current_endpoint):
get_model(batch_size=20)
t = get_transpiler(0,
fluid.default_main_program(), pserver_endpoints,
trainers)
pserver_prog = t.get_pserver_program(current_endpoint)
startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
exe.run(pserver_prog)
class TestDistMnist(unittest.TestCase):
def setUp(self):
self._trainers = 1
self._pservers = 1
self._ps_endpoints = "127.0.0.1:9123"
def start_pserver(self, endpoint):
p = Process(
target=run_pserver,
args=(self._ps_endpoints, self._trainers, endpoint))
p.start()
return p.pid
def _wait_ps_ready(self, pid):
retry_times = 5
while True:
assert retry_times >= 0, "wait ps ready failed"
time.sleep(1)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
retry_times -= 1
def stop_pserver(self, pid):
os.kill(pid, signal.SIGTERM)
def test_with_place(self):
p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
pserver_pid = self.start_pserver(self._ps_endpoints)
self._wait_ps_ready(pserver_pid)
self.run_trainer(p, 0)
self.stop_pserver(pserver_pid)
def run_trainer(self, place, trainer_id):
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
batch_size=20)
t = get_transpiler(trainer_id,
fluid.default_main_program(), self._ps_endpoints,
self._trainers)
trainer_prog = t.get_trainer_program()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
feed_var_list = [
var for var in trainer_prog.global_block().vars.itervalues()
if var.is_data
]
feeder = fluid.DataFeeder(feed_var_list, place)
for pass_id in xrange(10):
for batch_id, data in enumerate(train_reader()):
exe.run(trainer_prog, feed=feeder.feed(data))
if (batch_id + 1) % 10 == 0:
acc_set = []
avg_loss_set = []
for test_data in test_reader():
acc_np, avg_loss_np = exe.run(
program=test_program,
feed=feeder.feed(test_data),
fetch_list=[batch_acc, avg_cost])
acc_set.append(float(acc_np))
avg_loss_set.append(float(avg_loss_np))
# get test acc and loss
acc_val = np.array(acc_set).mean()
avg_loss_val = np.array(avg_loss_set).mean()
if float(acc_val
) > 0.8: # Smaller value to increase CI speed
return
else:
print(
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
format(pass_id, batch_id + 1,
float(avg_loss_val), float(acc_val)))
if math.isnan(float(avg_loss_val)):
assert ("got Nan loss, training failed.")
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid.core as core
import numpy
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.executor import Executor
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestFillZerosLikeOpForTensorArray(unittest.TestCase):
def place(self):
return core.CPUPlace()
def test_zero_filling_lod_tensor_array(self):
tensor = core.LoDTensor()
tensor.set(
numpy.arange(20).reshape(20, 1).astype('int32'), self.place())
tensor.set_lod([[0, 2, 5], [0, 3, 9, 11, 17, 20]])
expect = [
numpy.array(
[0, 0, 0, 0, 0], dtype='int32'), numpy.array(
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='int32'),
numpy.array(
[0, 0, 0], dtype='int32')
]
lod = [[[0, 2, 5]], [[0, 6, 12]], [[0, 3]]]
self.main(
tensor=tensor,
expect_array=expect,
expect_lod=lod,
expect_max_len=3)
def main(self, tensor, expect_array, expect_lod, expect_max_len, level=0):
place = self.place()
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[10])
x.persistable = True
table = layers.lod_rank_table(x, level=level)
max_len = layers.max_sequence_len(table)
max_len.persistable = True
array = layers.lod_tensor_to_array(x, table)
array = layers.fill_zeros_like(array)
array.persistable = True
result = layers.array_to_lod_tensor(array, table)
result.persistable = True
exe = Executor(place)
scope = core.Scope()
exe.run(program, feed={'x': tensor}, scope=scope)
var = scope.find_var(array.name)
array = var.get_lod_tensor_array()
if expect_array is not None and expect_lod is not None:
self.check_array_same(array, expect_array, expect_lod)
self.assertEqual(
numpy.array(scope.find_var(max_len.name).get_tensor())[0],
expect_max_len)
def check_array_same(self, array, expect_tensor, expect_lod):
self.assertEqual(len(expect_tensor), len(array))
for i, exp in enumerate(zip(expect_tensor, expect_lod)):
exp_tensor, exp_lod = exp
exp_tensor = numpy.expand_dims(exp_tensor, axis=1)
self.assertTrue(numpy.allclose(exp_tensor, numpy.array(array[i])))
self.assertEqual(exp_lod, array[i].lod())
if __name__ == '__main__':
unittest.main()
......@@ -401,7 +401,7 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_maxout(self):
def test_crop(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 5], dtype="float32")
......@@ -410,6 +410,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_mean_iou(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[16], dtype='float32')
y = layers.data(name='label', shape=[1], dtype='int64')
iou = layers.mean_iou(x, y, 2)
self.assertIsNotNone(iou)
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -315,7 +315,7 @@ class Trainer(object):
for ip in worker_ips.split(","):
worker_endpoints.append(':'.join([ip, port]))
self.num_trainers = len(worker_endpoints)
current_endpoint = os.getenv("POD_IP") + ":" + port
current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
worker_endpoints.remove(current_endpoint)
# TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
# in ParallelExecutor to start
......
......@@ -112,7 +112,7 @@ def fetch():
paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)
def convert(path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册