提交 40b8b634 编写于 作者: L Luo Tao

Merge branch 'develop' into refine_relu_test

......@@ -5,7 +5,7 @@
充分展现英特尔平台的优势,有效提升PaddlePaddle在英特尔架构上的性能。
<div align="center">
<img src="image/overview.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/overview.png"><br/>
Figure 1. PaddlePaddle on IA
</div>
......@@ -42,16 +42,43 @@ Figure 1. PaddlePaddle on IA
MKL,MKLML以及MKL-DNN三者关系如下表:
| Name | Open Source | License | Descriptions |
| :---------- | :--------------- | :---------- | :------------ |
| MKL | No | Proprietary | Accelerate math processing routines |
| MKLML | No | Proprietary | Small package of MKL, especially for Machine Learning |
| MKL-DNN | Yes | Apache 2.0 | Accelerate primitives processing routines especially for Deep Neural Networks |
<table>
<thead>
<tr>
<th>Name</th>
<th>Open Source</th>
<th>License</th>
<th>Descriptions</th>
</tr>
</thead>
<tbody>
<tr>
<td>MKL</td>
<td>No</td>
<td>Proprietary</td>
<td>Accelerate math processing routines</td>
</tr>
<tr>
<td>MKLML</td>
<td>No</td>
<td>Proprietary</td>
<td>Small package of MKL, especially for Machine Learning</td>
</tr>
<tr>
<td>MKL-DNN</td>
<td>Yes</td>
<td>Apache 2.0</td>
<td>Accelerate primitives processing routines especially for Deep Neural Networks</td>
</tr>
</tbody>
</table>
MKLML可以与MKL-DNN共同使用,以此达到最好的性能。
<div align="center">
<img src="image/engine.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/engine.png"><br/>
Figure 2. PaddlePaddle with MKL Engines
</div>
......@@ -103,7 +130,7 @@ MKL-DNN的库目前只有动态库`libmkldnn.so`。
所以我们定义了一个`MKLDNNMatrix`用于管理MKL-DNN数据的不同格式以及相互之间的转换。
<div align="center">
<img src="image/matrix.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/matrix.png"><br/>
Figure 3. MKLDNNMatrix
</div>
......@@ -113,7 +140,7 @@ Figure 3. MKLDNNMatrix
子类只需要使用定义好的接口,实现具体的函数功能即可。
<div align="center">
<img src="image/layers.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/layers.png"><br/>
Figure 4. MKLDNNLayer
</div>
......@@ -150,7 +177,7 @@ Figure 4. MKLDNNLayer
所以整体上,在实现每个子类的时候就不需要关心分支的事情了。
<div align="center">
<img src="image/gradients.png"><br/>
<img src="https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/v2/images/gradients.png"><br/>
Figure 5. Merge Gradients
</div>
......
digraph G{
subgraph cluster_timestep0 {
label="recurrent timestep i-1"
bgcolor=lightgray
node [style=filled,color=white]
fc0_0 [label="fc 0"]
fc0_1 [label="fc 1"]
fc0_2 [label="fc 2"]
fc0_0 -> fc0_1
fc0_1 -> fc0_2
}
subgraph cluster_timestep1 {
label="recurrent timestep i"
node [style=filled];
fc1_0 [label="fc 0"]
fc1_1 [label="fc 1"]
fc1_2 [label="fc 2"]
color=blue
fc1_0 -> fc1_1
fc1_1 -> fc1_2
}
subgraph cluster_timestep2 {
label="recurrent timestep i+1"
bgcolor=lightgray
node [style=filled,color=white]
fc2_0 [label="fc 0"]
fc2_1 [label="fc 1"]
fc2_2 [label="fc 2"]
fc2_0 -> fc2_1
fc2_1 -> fc2_2
}
fc0_1 -> fc1_1 [style="dotted" constraint=false]
fc1_1 -> fc2_1 [style="dotted" constraint=false]
}
\ No newline at end of file
digraph G{
subgraph cluster_timestep0 {
label="recurrent timestep i-1"
bgcolor=lightgray
node [style=filled,color=white]
fc0_0 [label="fc 0"]
fc0_1 [label="fc 1"]
fc0_2 [label="fc 2"]
m0 [label="memory"]
fc0_0 -> fc0_1
fc0_1 -> fc0_2
fc0_1 -> m0
m0 -> fc0_1
}
subgraph cluster_timestep1 {
label="recurrent timestep i"
node [style=filled];
fc1_0 [label="fc 0"]
fc1_1 [label="fc 1"]
fc1_2 [label="fc 2"]
m1 [label="memory"]
color=blue
fc1_0 -> fc1_1
fc1_1 -> fc1_2
fc1_1 -> m1
m1 -> fc1_1
}
subgraph cluster_timestep2 {
label="recurrent timestep i+1"
bgcolor=lightgray
node [style=filled,color=white]
fc2_0 [label="fc 0"]
fc2_1 [label="fc 1"]
fc2_2 [label="fc 2"]
m2 [label="memory"]
fc2_0 -> fc2_1
fc2_1 -> fc2_2
fc2_1 -> m2
m2 -> fc2_1
}
m0 -> m1 [style="dotted" constraint=false]
m1 -> m2 [style="dotted" constraint=false]
}
\ No newline at end of file
digraph G {
rankdir=LR;
subgraph cluster_t0 {
a [label="4"]
b [label="5"]
c [label="2"]
}
subgraph cluster_t1 {
d [label="0"]
e [label="9"]
}
subgraph cluster_t2 {
f [label="8"]
g [label="1"]
h [label="4"]
}
a -> b;
b -> c;
c -> d [constraint=false];
d -> e;
e -> f [constraint=false];
f -> g;
g -> h;
}
\ No newline at end of file
digraph G {
rankdir=LR;
a [label="4"]
b [label="5"]
c [label="2"]
d [label="0"]
e [label="9"]
f [label="8"]
g [label="1"]
h [label="4"]
a -> b;
b -> c;
c -> d;
d -> e;
e -> f;
f -> g;
g -> h;
}
\ No newline at end of file
......@@ -49,7 +49,9 @@ void FetchOpHandle::RunImpl() {
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(cpu_ctx);
if (var->generated_op_) {
var->generated_op_->Wait(cpu_ctx);
}
}
tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
......
......@@ -36,7 +36,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]);
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
}
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
......
......@@ -32,7 +32,9 @@ void SendOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
......
......@@ -21,6 +21,8 @@ endif()
if(WITH_TESTING)
add_subdirectory(tests/book)
# analysis test depends the models that generate by python/paddle/fluid/tests/book
add_subdirectory(analysis)
endif()
if (TENSORRT_FOUND)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle {
namespace inference {
namespace analysis {
size_t Dot::counter = 0;
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file implements some helper classes and methods for DOT programming
* support. It will give a visualization of the graph and that helps to debug
* the logics of each Pass.
*/
#pragma once
#include <glog/logging.h>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace inference {
namespace analysis {
/*
* A Dot template that helps to build a DOT graph definition.
*/
class Dot {
public:
static size_t counter;
struct Attr {
std::string key;
std::string value;
Attr(const std::string& key, const std::string& value)
: key(key), value(value) {}
std::string repr() const {
std::stringstream ss;
ss << key << "=" << '"' << value << '"';
return ss.str();
}
};
struct Node {
std::string name;
std::vector<Attr> attrs;
Node(const std::string& name, const std::vector<Attr>& attrs)
: name(name),
attrs(attrs),
id_("node_" + std::to_string(Dot::counter++)) {}
std::string id() const { return id_; }
std::string repr() const {
std::stringstream ss;
CHECK(!name.empty());
ss << id_;
for (size_t i = 0; i < attrs.size(); i++) {
if (i == 0) {
ss << "[label=" << '"' << name << '"' << " ";
}
ss << attrs[i].repr();
ss << ((i < attrs.size() - 1) ? " " : "]");
}
return ss.str();
}
private:
std::string id_;
};
struct Edge {
std::string source;
std::string target;
std::vector<Attr> attrs;
Edge(const std::string& source, const std::string& target,
const std::vector<Attr>& attrs)
: source(source), target(target), attrs(attrs) {}
std::string repr() const {
std::stringstream ss;
CHECK(!source.empty());
CHECK(!target.empty());
ss << source << "->" << target;
for (size_t i = 0; i < attrs.size(); i++) {
if (i == 0) {
ss << "[";
}
ss << attrs[i].repr();
ss << ((i < attrs.size() - 1) ? " " : "]");
}
return ss.str();
}
};
Dot() = default;
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& name, const std::vector<Attr>& attrs) {
CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'";
nodes_.emplace(name, Node{name, attrs});
}
void AddEdge(const std::string& source, const std::string& target,
const std::vector<Attr>& attrs) {
CHECK(!source.empty());
CHECK(!target.empty());
auto sid = nodes_.at(source).id();
auto tid = nodes_.at(target).id();
edges_.emplace_back(sid, tid, attrs);
}
// Compile to DOT language codes.
std::string Build() const {
std::stringstream ss;
const std::string indent = " ";
ss << "digraph G {" << '\n';
// Add graph attrs
for (const auto& attr : attrs_) {
ss << indent << attr.repr() << '\n';
}
// add nodes
for (auto& item : nodes_) {
ss << indent << item.second.repr() << '\n';
}
// add edges
for (auto& edge : edges_) {
ss << indent << edge.repr() << '\n';
}
ss << "} // end G";
return ss.str();
}
private:
std::unordered_map<std::string, Node> nodes_;
std::vector<Edge> edges_;
std::vector<Attr> attrs_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -14,11 +14,15 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
namespace inference {
struct Buffer;
enum class DeviceType { UNK = -1, CPU, GPU };
/*
* EngineBase is the base class of all inference engines. An inference engine
* takes a paddle program as input, and outputs the result in fluid Tensor
......@@ -45,8 +49,20 @@ class EngineBase {
// Execute the engine, that will run the inference network.
virtual void Execute(int batch_size) = 0;
// Return the IO buffer that allocated in engine. One can read/write directly
// on the buffer. If the buffer's buffer is nullptr, one can also allocate
// memory and maintain it outside the engine.
virtual Buffer& buffer(const std::string& name) = 0;
virtual ~EngineBase() {}
}; // class EngineBase
struct Buffer {
void* buffer{nullptr}; // buffer should be allocated only once.
int max_size; // buffer allocated space.
int size; // data size.
DeviceType device{DeviceType::UNK}; // tells which device this buffer is on.
};
} // namespace inference
} // namespace paddle
nv_library(tensorrt_engine SRCS engine.cc)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(convert)
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc io_converter.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
......@@ -30,16 +30,24 @@ void TensorRTEngine::Build(const DescType& paddle_model) {
}
void TensorRTEngine::Execute(int batch_size) {
infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
std::vector<void*> buffers;
for (auto& buf : buffers_) {
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
PADDLE_ENFORCE_GT(buf.max_size, 0);
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
buffers.push_back(buf.buffer);
}
infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
cudaStreamSynchronize(*stream_);
}
TensorRTEngine::~TensorRTEngine() {
// clean buffer
for (auto& buffer : buffers_) {
if (buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
buffer = nullptr;
for (auto& buf : buffers_) {
if (buf.buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
buf.buffer = nullptr;
buf.max_size = 0;
}
}
}
......@@ -59,7 +67,7 @@ void TensorRTEngine::FreezeNetwork() {
infer_context_.reset(infer_engine_->createExecutionContext());
// allocate GPU buffers.
buffers_.resize(buffer_sizes_.size(), nullptr);
buffers_.resize(buffer_sizes_.size());
for (auto& item : buffer_sizes_) {
if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
......@@ -67,7 +75,11 @@ void TensorRTEngine::FreezeNetwork() {
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
}
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
buf.size = buf.max_size = item.second;
buf.device = DeviceType::GPU;
}
}
......@@ -113,7 +125,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
}
void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name);
return buffer(name).buffer;
}
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
......@@ -123,11 +135,13 @@ void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, it->second,
cudaMemcpyDeviceToHost, *stream_));
}
void*& TensorRTEngine::buffer(const std::string& name) {
Buffer& TensorRTEngine::buffer(const std::string& name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
......@@ -137,9 +151,12 @@ void*& TensorRTEngine::buffer(const std::string& name) {
void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
size_t size) {
void* buf = buffer(name);
PADDLE_ENFORCE_EQ(
0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer);
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
cudaMemcpyHostToDevice, *stream_));
}
void TensorRTEngine::SetITensor(const std::string& name,
......
......@@ -87,7 +87,9 @@ class TensorRTEngine : public EngineBase {
// these memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
void*& buffer(const std::string& name);
Buffer& buffer(const std::string& name) override;
cudaStream_t* stream() { return stream_; }
// Fill an input from CPU memory with name and size.
void SetInputFromCPU(const std::string& name, void* data, size_t size);
......@@ -116,7 +118,7 @@ class TensorRTEngine : public EngineBase {
cudaStream_t* stream_;
nvinfer1::ILogger& logger_;
std::vector<void*> buffers_;
std::vector<Buffer> buffers_;
// max data size for the buffers.
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
......
......@@ -46,19 +46,6 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, *dev_ctx);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
framework::LoDTensor cpu_tensor;
cpu_tensor.ShareDataWith(*tensor);
cpu_tensor.set_lod(tensor->lod());
// reset tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
}
}
};
......
......@@ -463,7 +463,7 @@ void SetProfileListener() {
std::mt19937 rng;
rng.seed(std::random_device()());
std::uniform_int_distribution<std::mt19937::result_type> dist6(
1, std::numeric_limits<int64_t>::max());
1, std::numeric_limits<std::mt19937::result_type>::max());
profiler_lister_id = dist6(rng);
}
int64_t ListenerId() { return profiler_lister_id; }
......
......@@ -96,7 +96,7 @@ def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
TOTAL_DE_WORDS))
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
TOTAL_ENG_WORDS))
TOTAL_EN_WORDS))
return src_dict_size, trg_dict_size
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import numpy as np
import unittest
import paddle.fluid as fluid
......@@ -243,7 +243,7 @@ class TestParallelExecutorBase(unittest.TestCase):
begin = time.time()
first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name])
first_loss = numpy.array(first_loss)
first_loss = np.array(first_loss)
for i in xrange(iter):
run_executor(exe=exe, feed=feed_dict, fetch_list=[])
......@@ -256,7 +256,7 @@ class TestParallelExecutorBase(unittest.TestCase):
print "%.4f Instance per second" % (
(batch_size * iter + 2) / (end - begin))
last_loss = numpy.array(last_loss)
last_loss = np.array(last_loss)
print first_loss, last_loss
# self.assertGreater(first_loss[0], last_loss[0])
......@@ -284,8 +284,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_network_convergence(simple_fc_net)
self.check_network_convergence(simple_fc_net, allow_op_delay=True)
img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64')
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence(
simple_fc_net, feed_dict={"image": img,
"label": label})
......@@ -294,8 +294,8 @@ class TestMNIST(TestParallelExecutorBase):
self.check_simple_fc_convergence()
def check_simple_fc_parallel_accuracy(self):
img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64')
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
single_first_loss, single_last_loss = self.check_network_convergence(
method=simple_fc_net,
seed=1000,
......@@ -319,8 +319,8 @@ class TestMNIST(TestParallelExecutorBase):
def check_batchnorm_fc_convergence(self):
self.check_network_convergence(fc_with_batchnorm)
img = numpy.zeros(shape=[32, 784], dtype='float32')
label = numpy.ones(shape=[32, 1], dtype='int64')
img = np.zeros(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
self.check_network_convergence(
fc_with_batchnorm, feed_dict={"image": img,
"label": label})
......@@ -404,9 +404,6 @@ class ModelHyperParams(object):
dropout = 0.1
import numpy as np
def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Pad the instances to the max sequence length in batch, and generate the
......@@ -533,9 +530,8 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
opt.minimize(loss)
batch_size = 32
image = numpy.random.normal(size=(batch_size,
784)).astype('float32')
label = numpy.random.randint(0, 10, (batch_size, 1), dtype="int64")
image = np.random.normal(size=(batch_size, 784)).astype('float32')
label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......@@ -552,12 +548,12 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
for i in xrange(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict)
test_loss = numpy.array(test_loss)
test_loss = np.array(test_loss)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
train_loss = numpy.array(train_loss)
train_loss = np.array(train_loss)
self.assertTrue(
numpy.allclose(
np.allclose(
train_loss, test_loss, atol=1e-8),
"Train loss: " + str(train_loss) + "\n Test loss:" +
str(test_loss))
......@@ -712,7 +708,7 @@ class TestCRFModel(unittest.TestCase):
data = train_data()
for i in xrange(10):
cur_batch = next(data)
print map(numpy.array,
print map(np.array,
pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name]))[0]
......@@ -721,3 +717,84 @@ class TestCRFModel(unittest.TestCase):
def test_update_dense_parameter(self):
self.check_network_convergence(is_sparse=False)
# test fetch all the variables of global_block
import paddle.dataset.flowers as flowers
import math
def Lenet(data, class_dim):
conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None)
bn1 = fluid.layers.batch_norm(conv1, act='relu')
pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2)
conv2 = fluid.layers.conv2d(pool1, 50, 5, 1, act=None)
bn2 = fluid.layers.batch_norm(conv2, act='relu')
pool2 = fluid.layers.pool2d(bn2, 2, 'max', 2)
fc1 = fluid.layers.fc(pool2, size=500, act='relu')
fc2 = fluid.layers.fc(fc1, size=class_dim, act='softmax')
return fc2
class TestFetchOp(unittest.TestCase):
def parallel_exe(self, train_inputs, seed):
main = fluid.Program()
startup = fluid.Program()
startup.random_seed = seed
with fluid.program_guard(main, startup):
data = fluid.layers.data(
name='image', shape=[3, 224, 224], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = Lenet(data, class_dim=102)
loss = fluid.layers.cross_entropy(input=out, label=label)
loss = fluid.layers.mean(loss)
opt = fluid.optimizer.Momentum(
learning_rate=0.1,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
opt.minimize(loss)
# TODO(zcd): I found that onece the memory optimizer is open,
# parallel_exe doesn't fetch some variable, such as conv2d_0.b_0@GRAD,
# conv2d_1.b_0@GRAD. Those variables should not be pruned.
# fluid.memory_optimize(main)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup)
feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
pe = fluid.ParallelExecutor(
use_cuda=True, loss_name=loss.name, main_program=main)
fetch_list = []
all_vars = main.global_block().vars
for k, v in all_vars.iteritems():
if 'tmp' not in k and k[0] is not '_' or v.persistable:
fetch_list.append(k)
for data in train_inputs:
ret = pe.run(fetch_list, feed=feeder.feed(data))
for i in range(len(fetch_list)):
assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i]))
def test_update_sparse_parameter(self):
tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
tst_reader_iter = tst_reader()
iters = 3
train_inputs = []
for i in range(iters):
train_inputs.append(tst_reader_iter.next())
self.parallel_exe(train_inputs, seed=1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册