提交 796a448c 编写于 作者: T typhoonzero

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into refine_grpc_serde_code

......@@ -53,7 +53,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "v0.11"
GIT_TAG "v0.14"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
......
../../v2/build_and_install/paddleci.png
\ No newline at end of file
......@@ -125,12 +125,12 @@ Compile Time -> IR -> Runtime
## Operator/OpWithKernel/OpKernel
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/49caf1fb70820fb4a6c217634317c9306f361f36/op_op_with_kern_class_diagram.dot)
![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op_op_with_kern_class_diagram.dot)
---
## Operator
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op.dot)
* `Operator` is the fundamental building block of the user interface.
* Operator stores input/output variable names and attributes.
......@@ -141,7 +141,7 @@ Compile Time -> IR -> Runtime
## OpWithKernel/Kernel
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/9d7f4eba185cf41c8e2fbfb40ae21890dbddcd39/op_with_kernel.dot)
![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op_with_kernel.dot)
* `OpWithKernel` inherits `Operator`.
* `OpWithKernel` contains a Kernel map.
......
// Class diagram containing a single record node for Operator.
// The record has three compartments: the class name, its methods
// (InferShape, Run) and its data members (inputs_, outputs_, attrs_).
digraph sample {
// Top-down layout; every node is drawn as a record (compartmented box).
graph [rankdir=TD]; node [shape=record];
// "\l" left-justifies each line inside a compartment.
op [label="{Operator| InferShape()=0\lRun()=0\l | map<string, string[]> inputs_\lmap<string, string[]> outputs_ \l AttributeMap attrs_\l}"];
}
\ No newline at end of file
// Class diagram relating Operator, OpWithKernel, OpKernel and OpKernelKey,
// with MulOp / MulOpKernel shown as a concrete example.
digraph sample {
// Top-down layout; every node is drawn as a record (compartmented box).
graph [rankdir=TD]; node [shape=record];
op [label="{Operator| InferShape()=0\lRun()=0\l | map<string, string[]> inputs_\lmap<string, string[]> outputs_ \l AttributeMap attrs_\l}"];
op_with_kern [label="{OpWithKernel | InferShape()=0\lRun()\l | map<OpKernelKey,OpKernel>kernels_ }"]
op_kernel [label="{OpKernel | Compute()=0}"]
op_kernel_key [label="{OpKernelKey| Place place\n...}"]
// dir=back with an open-triangle tail draws the arrow pointing at `op`,
// i.e. the UML generalization notation (OpWithKernel derives from Operator).
op -> op_with_kern [dir=back, arrowtail=onormal]
op_with_kern -> op_kernel [arrowhead=vee, label="contains many"]
// Keep OpWithKernel and OpKernel on the same rank (same horizontal row).
{
rank=same;
op_with_kern
op_kernel
}
// Invisible edge: only influences layout ordering, nothing is drawn.
op_kernel -> op_kernel_key [style=invis]
{
rank=same;
op_kernel
op_kernel_key
}
op_with_kern -> op_kernel_key [arrowhead=vee, label ="\nas map key"]
// Concrete example: MulOp and its kernel template, each attached to the
// corresponding base node with the same generalization arrow style.
mul_op [label="MulOp"]
op_with_kern -> mul_op [dir=back, arrowtail=onormal]
mul_kernel [label="template <typename Place>\lclass MulOpKernel\l"]
op_kernel -> mul_kernel [dir=back, arrowtail=onormal]
mul_op -> mul_kernel [arrowhead=vee, label="register many"]
{
rank=same;
mul_op;
mul_kernel;
}
}
\ No newline at end of file
// Class diagram focused on OpWithKernel: Operator is reduced to a name-only
// record, while OpWithKernel, OpKernel and OpKernelKey show their members.
digraph sample {
// Top-down layout; every node is drawn as a record (compartmented box).
graph [rankdir=TD]; node [shape=record];
op [label="{Operator}"];
op_with_kern [label="{OpWithKernel | InferShape()=0\lRun()\l | map<OpKernelKey,OpKernel>kernels_ }"]
op_kernel [label="{OpKernel | Compute()=0}"]
op_kernel_key [label="{OpKernelKey| Place place\n...}"]
// dir=back with an open-triangle tail draws the arrow pointing at `op`,
// i.e. the UML generalization notation (OpWithKernel derives from Operator).
op -> op_with_kern [dir=back, arrowtail=onormal]
op_with_kern -> op_kernel [arrowhead=vee, label="contains many"]
// Keep OpWithKernel and OpKernel on the same rank (same horizontal row).
{
rank=same;
op_with_kern
op_kernel
}
// Invisible edge: only influences layout ordering, nothing is drawn.
op_kernel -> op_kernel_key [style=invis]
{
rank=same;
op_kernel
op_kernel_key
}
op_with_kern -> op_kernel_key [arrowhead=vee, label ="\nas map key"]
}
\ No newline at end of file
......@@ -142,7 +142,7 @@ gated_unit
-----------
.. autoclass:: paddle.v2.layer.gated_unit
:noindex:
Recurrent Layer Group
=====================
......@@ -354,7 +354,7 @@ dropout
--------
.. autoclass:: paddle.v2.layer.dropout
:noindex:
dot_prod
---------
.. autoclass:: paddle.v2.layer.dot_prod
......@@ -460,6 +460,11 @@ multi_binary_label_cross_entropy_cost
.. autoclass:: paddle.v2.layer.multi_binary_label_cross_entropy_cost
:noindex:
classification_cost
-------------------
.. autoclass:: paddle.v2.layer.classification_cost
:noindex:
huber_regression_cost
-------------------------
.. autoclass:: paddle.v2.layer.huber_regression_cost
......@@ -534,7 +539,7 @@ detection_output
----------------
.. autoclass:: paddle.v2.layer.detection_output
:noindex:
Check Layer
============
......
......@@ -5,7 +5,7 @@
充分展现英特尔平台的优势,有效提升PaddlePaddle在英特尔架构上的性能。
<div align="center">
<img src="image/overview.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/overview.png"><br/>
Figure 1. PaddlePaddle on IA
</div>
......@@ -42,16 +42,43 @@ Figure 1. PaddlePaddle on IA
MKL,MKLML以及MKL-DNN三者关系如下表:
| Name | Open Source | License | Descriptions |
| :---------- | :--------------- | :---------- | :------------ |
| MKL | No | Proprietary | Accelerate math processing routines |
| MKLML | No | Proprietary | Small package of MKL, especially for Machine Learning |
| MKL-DNN | Yes | Apache 2.0 | Accelerate primitives processing routines especially for Deep Neural Networks |
<table>
<thead>
<tr>
<th>Name</th>
<th>Open Source</th>
<th>License</th>
<th>Descriptions</th>
</tr>
</thead>
<tbody>
<tr>
<td>MKL</td>
<td>No</td>
<td>Proprietary</td>
<td>Accelerate math processing routines</td>
</tr>
<tr>
<td>MKLML</td>
<td>No</td>
<td>Proprietary</td>
<td>Small package of MKL, especially for Machine Learning</td>
</tr>
<tr>
<td>MKL-DNN</td>
<td>Yes</td>
<td>Apache 2.0</td>
<td>Accelerate primitives processing routines especially for Deep Neural Networks</td>
</tr>
</tbody>
</table>
MKLML可以与MKL-DNN共同使用,以此达到最好的性能。
<div align="center">
<img src="image/engine.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/engine.png"><br/>
Figure 2. PaddlePaddle with MKL Engines
</div>
......@@ -103,7 +130,7 @@ MKL-DNN的库目前只有动态库`libmkldnn.so`。
所以我们定义了一个`MKLDNNMatrix`用于管理MKL-DNN数据的不同格式以及相互之间的转换。
<div align="center">
<img src="image/matrix.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/matrix.png"><br/>
Figure 3. MKLDNNMatrix
</div>
......@@ -113,7 +140,7 @@ Figure 3. MKLDNNMatrix
子类只需要使用定义好的接口,实现具体的函数功能即可。
<div align="center">
<img src="image/layers.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/layers.png"><br/>
Figure 4. MKLDNNLayer
</div>
......@@ -150,7 +177,7 @@ Figure 4. MKLDNNLayer
所以整体上,在实现每个子类的时候就不需要关心分支的事情了。
<div align="center">
<img src="image/gradients.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/gradients.png"><br/>
Figure 5. Merge Gradients
</div>
......
// Three recurrent timesteps (i-1, i, i+1), each drawn as a cluster holding
// the chain fc 0 -> fc 1 -> fc 2. The "fc 1" node of each timestep is linked
// to the "fc 1" node of the next timestep with a dotted edge.
digraph G{
// Timestep i-1: gray background, white filled nodes.
subgraph cluster_timestep0 {
label="recurrent timestep i-1"
bgcolor=lightgray
node [style=filled,color=white]
fc0_0 [label="fc 0"]
fc0_1 [label="fc 1"]
fc0_2 [label="fc 2"]
fc0_0 -> fc0_1
fc0_1 -> fc0_2
}
// Timestep i: the only cluster with a blue border and no gray background.
subgraph cluster_timestep1 {
label="recurrent timestep i"
node [style=filled];
fc1_0 [label="fc 0"]
fc1_1 [label="fc 1"]
fc1_2 [label="fc 2"]
color=blue
fc1_0 -> fc1_1
fc1_1 -> fc1_2
}
// Timestep i+1: same styling as timestep i-1.
subgraph cluster_timestep2 {
label="recurrent timestep i+1"
bgcolor=lightgray
node [style=filled,color=white]
fc2_0 [label="fc 0"]
fc2_1 [label="fc 1"]
fc2_2 [label="fc 2"]
fc2_0 -> fc2_1
fc2_1 -> fc2_2
}
// Cross-timestep links; constraint=false keeps them from affecting ranking.
fc0_1 -> fc1_1 [style="dotted" constraint=false]
fc1_1 -> fc2_1 [style="dotted" constraint=false]
}
\ No newline at end of file
// Three recurrent timesteps (i-1, i, i+1), each drawn as a cluster holding
// the chain fc 0 -> fc 1 -> fc 2 plus a "memory" node that is connected to
// "fc 1" in both directions. Consecutive timesteps are linked through their
// memory nodes with dotted edges.
digraph G{
// Timestep i-1: gray background, white filled nodes.
subgraph cluster_timestep0 {
label="recurrent timestep i-1"
bgcolor=lightgray
node [style=filled,color=white]
fc0_0 [label="fc 0"]
fc0_1 [label="fc 1"]
fc0_2 [label="fc 2"]
m0 [label="memory"]
fc0_0 -> fc0_1
fc0_1 -> fc0_2
fc0_1 -> m0
m0 -> fc0_1
}
// Timestep i: the only cluster with a blue border and no gray background.
subgraph cluster_timestep1 {
label="recurrent timestep i"
node [style=filled];
fc1_0 [label="fc 0"]
fc1_1 [label="fc 1"]
fc1_2 [label="fc 2"]
m1 [label="memory"]
color=blue
fc1_0 -> fc1_1
fc1_1 -> fc1_2
fc1_1 -> m1
m1 -> fc1_1
}
// Timestep i+1: same styling as timestep i-1.
subgraph cluster_timestep2 {
label="recurrent timestep i+1"
bgcolor=lightgray
node [style=filled,color=white]
fc2_0 [label="fc 0"]
fc2_1 [label="fc 1"]
fc2_2 [label="fc 2"]
m2 [label="memory"]
fc2_0 -> fc2_1
fc2_1 -> fc2_2
fc2_1 -> m2
m2 -> fc2_1
}
// Cross-timestep links; constraint=false keeps them from affecting ranking.
m0 -> m1 [style="dotted" constraint=false]
m1 -> m2 [style="dotted" constraint=false]
}
\ No newline at end of file
// A left-to-right chain of eight nodes labelled 4, 5, 2, 0, 9, 8, 1, 4,
// grouped into three clusters of sizes 3, 2 and 3.
digraph G {
rankdir=LR;
subgraph cluster_t0 {
a [label="4"]
b [label="5"]
c [label="2"]
}
subgraph cluster_t1 {
d [label="0"]
e [label="9"]
}
subgraph cluster_t2 {
f [label="8"]
g [label="1"]
h [label="4"]
}
a -> b;
b -> c;
// Edges that cross a cluster boundary do not constrain node ranking.
c -> d [constraint=false];
d -> e;
e -> f [constraint=false];
f -> g;
g -> h;
}
\ No newline at end of file
// A left-to-right chain of eight nodes labelled 4, 5, 2, 0, 9, 8, 1, 4,
// drawn as one flat sequence (no clusters).
digraph G {
rankdir=LR;
a [label="4"]
b [label="5"]
c [label="2"]
d [label="0"]
e [label="9"]
f [label="8"]
g [label="1"]
h [label="4"]
a -> b;
b -> c;
c -> d;
d -> e;
e -> f;
f -> g;
g -> h;
}
\ No newline at end of file
......@@ -49,7 +49,9 @@ void FetchOpHandle::RunImpl() {
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(cpu_ctx);
if (var->generated_op_) {
var->generated_op_->Wait(cpu_ctx);
}
}
tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
......
......@@ -36,7 +36,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]);
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
}
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
......
......@@ -32,7 +32,9 @@ void SendOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
......
......@@ -20,7 +20,9 @@ if(NOT APPLE)
endif()
if(WITH_TESTING)
# both tests/book and analysis depend on the models generated by python/paddle/fluid/tests/book
add_subdirectory(tests/book)
add_subdirectory(analysis)
endif()
if (TENSORRT_FOUND)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle {
namespace inference {
namespace analysis {
// Out-of-line definition of the static member declared in class Dot.
// It starts at zero; Dot::Node's constructor post-increments it to build
// node ids of the form "node_<n>".
size_t Dot::counter = 0;
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements some helper classes and methods for DOT programming
* support. It gives a visualization of the graph, which helps to debug
* the logic of each Pass.
*/
#pragma once
#include <glog/logging.h>
#include <sstream>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace inference {
namespace analysis {
/*
* A Dot template that helps to build a DOT graph definition.
*/
class Dot {
 public:
  // Counter shared by every Dot instance; Node's constructor post-increments
  // it to mint ids of the form "node_<n>". It is only declared here and must
  // be defined in exactly one translation unit. The increment is not
  // synchronized.
  static size_t counter;

  // A key/value attribute, rendered as key="value".
  struct Attr {
    std::string key;
    std::string value;

    Attr(const std::string& key, const std::string& value)
        : key(key), value(value) {}

    // Returns key="value". The value is written verbatim between double
    // quotes; embedded quotes are not escaped.
    std::string repr() const {
      std::stringstream ss;
      ss << key << "=" << '"' << value << '"';
      return ss.str();
    }
  };

  // A graph node. `name` is the human readable label, while the identifier
  // used in the generated DOT text is an auto-generated unique id.
  struct Node {
    std::string name;
    std::vector<Attr> attrs;

    Node(const std::string& name, const std::vector<Attr>& attrs)
        : name(name),
          attrs(attrs),
          id_("node_" + std::to_string(Dot::counter++)) {}

    // Returns the generated identifier ("node_<n>").
    std::string id() const { return id_; }

    // Returns <id>[label="<name>" <attr> <attr> ...].
    // Aborts (CHECK) if `name` is empty.
    std::string repr() const {
      std::stringstream ss;
      CHECK(!name.empty());
      // Always emit the label. Previously it was written only from inside
      // the attribute loop, so a node without attributes was rendered as a
      // bare id and lost its name in the resulting graph.
      ss << id_ << "[label=" << '"' << name << '"';
      for (const auto& attr : attrs) {
        ss << " " << attr.repr();
      }
      ss << "]";
      return ss.str();
    }

   private:
    std::string id_;
  };

  // A directed edge. `source` and `target` hold node identifiers as they
  // should appear in the DOT text (AddEdge stores Node::id() values).
  struct Edge {
    std::string source;
    std::string target;
    std::vector<Attr> attrs;

    Edge(const std::string& source, const std::string& target,
         const std::vector<Attr>& attrs)
        : source(source), target(target), attrs(attrs) {}

    // Returns <source>-><target>, followed by [<attr> <attr> ...] only when
    // there is at least one attribute. Aborts (CHECK) on an empty endpoint.
    std::string repr() const {
      std::stringstream ss;
      CHECK(!source.empty());
      CHECK(!target.empty());
      ss << source << "->" << target;
      for (size_t i = 0; i < attrs.size(); i++) {
        if (i == 0) {
          ss << "[";
        }
        ss << attrs[i].repr();
        ss << ((i < attrs.size() - 1) ? " " : "]");
      }
      return ss.str();
    }
  };

  Dot() = default;

  // Creates a graph with the given graph-level attributes.
  explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}

  // Registers a node under `name`. Aborts (CHECK) if the name is already
  // registered.
  void AddNode(const std::string& name, const std::vector<Attr>& attrs) {
    CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'";
    nodes_.emplace(name, Node{name, attrs});
  }

  // Adds an edge between two previously registered nodes, addressed by
  // name. Aborts (CHECK) on an empty name; std::unordered_map::at throws
  // std::out_of_range if either node has not been added.
  void AddEdge(const std::string& source, const std::string& target,
               const std::vector<Attr>& attrs) {
    CHECK(!source.empty());
    CHECK(!target.empty());
    // Edges refer to nodes by generated id, not by user-facing name.
    auto sid = nodes_.at(source).id();
    auto tid = nodes_.at(target).id();
    edges_.emplace_back(sid, tid, attrs);
  }

  // Compile to DOT language codes.
  // Graph attributes come first, then nodes (in unordered_map iteration
  // order, which is unspecified), then edges in insertion order.
  std::string Build() const {
    std::stringstream ss;
    const std::string indent = " ";
    ss << "digraph G {" << '\n';

    // Add graph attrs
    for (const auto& attr : attrs_) {
      ss << indent << attr.repr() << '\n';
    }
    // add nodes
    for (auto& item : nodes_) {
      ss << indent << item.second.repr() << '\n';
    }
    // add edges
    for (auto& edge : edges_) {
      ss << indent << edge.repr() << '\n';
    }
    ss << "} // end G";
    return ss.str();
  }

 private:
  std::unordered_map<std::string, Node> nodes_;  // keyed by node name
  std::vector<Edge> edges_;
  std::vector<Attr> attrs_;  // graph-level attributes
};
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace inference {
struct Buffer;
enum class DeviceType { UNK = -1, CPU, GPU };
/*
* EngineBase is the base class of all inference engines. An inference engine
* takes a paddle program as input, and outputs the result in fluid Tensor
......@@ -45,8 +48,20 @@ class EngineBase {
// Execute the engine, that will run the inference network.
virtual void Execute(int batch_size) = 0;
// Return the IO buffer that is allocated in the engine. One can read/write
// directly on the buffer. If the Buffer's `buffer` pointer is nullptr, one can
// also allocate memory and maintain it outside the engine.
virtual Buffer& buffer(const std::string& name) = 0;
virtual ~EngineBase() {}
}; // class EngineBase
// Describes one IO buffer handed out by an inference engine.
// Every member has a default so that a default-constructed Buffer is in a
// well-defined "not allocated" state (null pointer, zero sizes, unknown
// device) instead of carrying indeterminate size fields.
struct Buffer {
  void* buffer{nullptr};               // buffer should be allocated only once.
  int max_size{0};                     // buffer allocated space.
  int size{0};                         // data size.
  DeviceType device{DeviceType::UNK};  // tells which device this buffer is on.
};
} // namespace inference
} // namespace paddle
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(convert)
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
......@@ -30,16 +30,24 @@ void TensorRTEngine::Build(const DescType& paddle_model) {
}
void TensorRTEngine::Execute(int batch_size) {
infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
std::vector<void*> buffers;
for (auto& buf : buffers_) {
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
PADDLE_ENFORCE_GT(buf.max_size, 0);
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
buffers.push_back(buf.buffer);
}
infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
cudaStreamSynchronize(*stream_);
}
TensorRTEngine::~TensorRTEngine() {
// clean buffer
for (auto& buffer : buffers_) {
if (buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
buffer = nullptr;
for (auto& buf : buffers_) {
if (buf.buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
buf.buffer = nullptr;
buf.max_size = 0;
}
}
}
......@@ -59,7 +67,7 @@ void TensorRTEngine::FreezeNetwork() {
infer_context_.reset(infer_engine_->createExecutionContext());
// allocate GPU buffers.
buffers_.resize(buffer_sizes_.size(), nullptr);
buffers_.resize(buffer_sizes_.size());
for (auto& item : buffer_sizes_) {
if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
......@@ -67,7 +75,11 @@ void TensorRTEngine::FreezeNetwork() {
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
}
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
buf.size = buf.max_size = item.second;
buf.device = DeviceType::GPU;
}
}
......@@ -113,7 +125,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
}
void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name);
return buffer(name).buffer;
}
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
......@@ -123,11 +135,13 @@ void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, it->second,
cudaMemcpyDeviceToHost, *stream_));
}
void*& TensorRTEngine::buffer(const std::string& name) {
Buffer& TensorRTEngine::buffer(const std::string& name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
......@@ -137,10 +151,12 @@ void*& TensorRTEngine::buffer(const std::string& name) {
void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
size_t size) {
void* buf = buffer(name);
cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_);
PADDLE_ENFORCE_EQ(
0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer);
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
cudaMemcpyHostToDevice, *stream_));
}
void TensorRTEngine::SetITensor(const std::string& name,
......
......@@ -87,7 +87,9 @@ class TensorRTEngine : public EngineBase {
// these memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
void*& buffer(const std::string& name);
Buffer& buffer(const std::string& name) override;
cudaStream_t* stream() { return stream_; }
// Fill an input from CPU memory with name and size.
void SetInputFromCPU(const std::string& name, void* data, size_t size);
......@@ -116,7 +118,7 @@ class TensorRTEngine : public EngineBase {
cudaStream_t* stream_;
nvinfer1::ILogger& logger_;
std::vector<void*> buffers_;
std::vector<Buffer> buffers_;
// max data size for the buffers.
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
......
......@@ -77,6 +77,37 @@ TEST_F(TensorRTEngineTest, add_layer) {
ASSERT_EQ(y_cpu, x_v * 2 + 3);
}
// Builds a network with a single FullyConnected layer (2 inputs, 2 outputs),
// feeds it one sample from CPU memory and checks the values copied back.
TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
  // Weight in CPU memory.
  // It seems tensorrt FC use col-major: [[1.0, 3.3], [1.1, 4.4]]
  // instead of row-major, which is [[1.0, 1.1], [3.3, 4.4]]
  float raw_weight[4] = {1.0, 1.1, 3.3, 4.4};
  float raw_bias[2] = {1.3, 2.4};

  TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 4);
  TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 2);
  auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                  nvinfer1::DimsCHW{1, 2, 1});
  auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, 2,
                                        weight.get(), bias.get());
  PADDLE_ENFORCE(fc_layer != nullptr);

  engine_->DeclareOutput(fc_layer, 0, "y");
  engine_->FreezeNetwork();
  // One binding for the input "x" and one for the output "y".
  ASSERT_EQ(engine_->engine()->getNbBindings(), 2);

  float x_v[2] = {1.0, 2.0};
  engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
                           2 * sizeof(float));
  engine_->Execute(1);

  LOG(INFO) << "to get output";
  // Sentinel values make it visible if the copy back did not happen.
  float y_cpu[2] = {-1., -1.};
  engine_->GetOutputInCPU("y", &y_cpu[0], sizeof(float) * 2);
  // Expected values:
  //   y[0] = 1.0 * 1.0 + 1.1 * 2.0 + 1.3 = 4.5
  //   y[1] = 3.3 * 1.0 + 4.4 * 2.0 + 2.4 = 14.5
  // Most of the operands are not exactly representable as float and the
  // result is computed on the GPU, so compare with a ULP tolerance instead
  // of exact equality.
  ASSERT_FLOAT_EQ(y_cpu[0], 4.5f);
  ASSERT_FLOAT_EQ(y_cpu[1], 14.5f);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -36,5 +36,5 @@ inference_test(label_semantic_roles)
inference_test(recognize_digits ARGS mlp conv)
inference_test(recommender_system)
#inference_test(rnn_encoder_decoder)
inference_test(understand_sentiment ARGS conv)
#inference_test(understand_sentiment ARGS conv)
inference_test(word2vec)
......@@ -187,7 +187,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(filter_slice, col_matrix, &out_slice);
blas.MatMul(filter_slice, false, col_matrix, false, T(1.0), &out_slice,
T(0.0));
}
}
}
......@@ -304,7 +305,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
col_matrix.ShareDataWith(in_grad_slice);
col_matrix.Resize(col_matrix_shape);
}
blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
blas.MatMul(filter_slice, true, out_grad_slice, false, T(1.0),
&col_matrix, T(0.0));
if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides,
......@@ -351,8 +353,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(out_grad_slice, false, col_matrix, true,
&filter_grad_slice);
blas.MatMul(out_grad_slice, false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
}
......
......@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_matrix = filter * input_batch
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
blas.MatMul(filter, true, input_batch, false, &col_matrix);
blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
&col_matrix, static_cast<T>(0.0));
if (data_dim == 2U) {
// col2im: col_matrix -> dy
......@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
// d, h, w)
blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
&input_grad_batch, static_cast<T>(0.0));
}
if (filter_grad) {
// input batch
......@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// or
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
// k_h * k_w)
blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
&filter_grad_, static_cast<T>(1.0));
}
}
}
......
......@@ -46,19 +46,6 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
framework::LoDTensor cpu_tensor;
cpu_tensor.ShareDataWith(*tensor);
cpu_tensor.set_lod(tensor->lod());
// reset tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
}
}
};
......
......@@ -63,6 +63,7 @@ __device__ T reduceSum(T val, int tid, int len) {
val += platform::CudaShuffleDownSync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
......
......@@ -463,7 +463,7 @@ void SetProfileListener() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int64_t>::max());
1, std::numeric_limits<std::mt19937::result_type>::max());
profiler_lister_id = dist6(rng);
}
int64_t ListenerId() { return profiler_lister_id; }
......
......@@ -398,7 +398,7 @@ function gen_dockerfile() {
cat <<EOF
========================================
Generate /paddle/build/Dockerfile ...
Generate ${PADDLE_ROOT}/build/Dockerfile ...
========================================
EOF
......@@ -422,7 +422,7 @@ EOF
CMD='"true"'
fi
cat >> /paddle/build/Dockerfile <<EOF
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
ADD python/dist/*.whl /
# run paddle version to install python packages first
RUN apt-get update &&\
......@@ -436,8 +436,14 @@ EOF
${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_GPU_ENV}
ENV NCCL_LAUNCH_MODE PARALLEL
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
EOF
if [[ ${WITH_GOLANG:-OFF} == "ON" ]]; then
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
ADD go/cmd/pserver/pserver /usr/bin/
ADD go/cmd/master/master /usr/bin/
EOF
fi
cat >> ${PADDLE_ROOT}/build/Dockerfile <<EOF
# default command shows the paddle version and exit
CMD [${CMD}]
EOF
......
......@@ -32,7 +32,7 @@ function start_build_docker() {
DOCKER_ENV=$(cat <<EOL
-e FLAGS_fraction_of_gpu_memory_to_use=0.15 \
-e CTEST_OUTPUT_ON_FAILURE=1 \
-e CTEST_PARALLEL_LEVEL=5 \
-e CTEST_PARALLEL_LEVEL=1 \
-e APT_MIRROR=${apt_mirror} \
-e WITH_GPU=ON \
-e CUDA_ARCH_NAME=Auto \
......
......@@ -96,7 +96,7 @@ def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
TOTAL_DE_WORDS))
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
TOTAL_ENG_WORDS))
TOTAL_EN_WORDS))
return src_dict_size, trg_dict_size
......
......@@ -299,14 +299,18 @@ class Executor(object):
if feed is None:
feed = {}
if not isinstance(feed, dict):
raise TypeError("feed should be a map")
raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" %
(type(feed)))
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()
if not isinstance(program, Program):
raise TypeError()
raise TypeError(
"Executor requires Program as its Parameter. But you passed in %s"
% (type(program)))
if scope is None:
scope = global_scope()
......
......@@ -47,6 +47,8 @@ class Optimizer(object):
raise TypeError("learning rate should be float or Variable")
self.regularization = regularization
self._learning_rate = learning_rate
# the learning rate type should be inferred from loss
self._dtype = None
# each program should have an independent learning rate
# program -> Variable(learning_rate)
self._learning_rate_map = dict()
......@@ -77,7 +79,7 @@ class Optimizer(object):
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype='float32',
dtype='float32' if self._dtype == None else self._dtype,
persistable=True)
def global_learning_rate(self, program=None):
......@@ -200,6 +202,7 @@ class Optimizer(object):
# Create any accumulators
program = loss.block.program
self._dtype = loss.dtype
with program_guard(program, startup_program):
global_block = framework.default_main_program().global_block()
start = len(global_block.ops)
......@@ -391,7 +394,7 @@ class AdamOptimizer(Optimizer):
beta_shape = [1]
self._beta1_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta1_pow_acc'),
dtype='float32',
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
......@@ -400,7 +403,7 @@ class AdamOptimizer(Optimizer):
self._beta2_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta2_pow_acc'),
dtype='float32',
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
......@@ -493,7 +496,7 @@ class AdamaxOptimizer(Optimizer):
beta_shape = [1]
self._beta1_pow_acc = self.helper.create_global_variable(
name=unique_name.generate('beta1_pow_acc'),
dtype='float32',
dtype='float32' if self._dtype == None else self._dtype,
shape=beta_shape,
lod_level=0,
persistable=True)
......@@ -900,8 +903,10 @@ class ModelAverage(Optimizer):
# param = (sum_1 + sum_2 + sum_3) / (num_accumulates + old_num_accumulates)
tmp = layers.sum(x=[num_accumulates, old_num_accumulates])
sum = layers.sum(x=[sum_1, sum_2, sum_3])
tmp = layers.cast(x=tmp, dtype='float32')
sum = layers.cast(x=sum, dtype='float32')
tmp = layers.cast(
x=tmp, dtype='float32' if self._dtype == None else self._dtype)
sum = layers.cast(
x=sum, dtype='float32' if self._dtype == None else self._dtype)
layers.elementwise_div(x=sum, y=tmp, out=param)
def _add_average_restore_op(self, block, param_grad):
......
......@@ -36,7 +36,7 @@ depth = 8
mix_hidden_lr = 1e-3
IS_SPARSE = True
PASS_NUM = 100
PASS_NUM = 10
BATCH_SIZE = 10
embedding_name = 'emb'
......
......@@ -18,7 +18,7 @@ import unittest
import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.memory_optimization_transpiler import memory_optimize
from paddle.fluid.transpiler import memory_optimize
class TestControlFlowGraph(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册