Commit 77200a70 authored by F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_double_buffer_for_cpp_reader
@@ -56,7 +56,7 @@ script:
export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/master/scripts/deploy/deploy_docs.sh
export DOCS_DIR=`pwd`
cd ..
-curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH $DOCS_DIR $DOCS_DIR/build/doc
+curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH $DOCS_DIR $DOCS_DIR/build/doc/v2
notifications:
email:
on_success: change
......
@@ -144,6 +144,8 @@ include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(external/cares)
include(external/grpc)
+include(external/snappy) # download snappy
+include(external/snappystream)
include(cudnn) # set cudnn libraries, must before configure
include(cupti)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -138,13 +138,14 @@ def main():
avg_cost = fluid.layers.mean(x=cost)
# Evaluator
-accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
+batch_size = fluid.layers.create_tensor(dtype='int64')
+batch_acc = fluid.layers.accuracy(
+input=predict, label=label, total=batch_size)
# inference program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
-test_target = accuracy.metrics + accuracy.states
-inference_program = fluid.io.get_inference_program(test_target)
+inference_program = fluid.io.get_inference_program(batch_acc)
# Optimization
optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
@@ -157,27 +158,30 @@ def main():
# test
def test(exe):
-accuracy.reset(exe)
+test_pass_acc = fluid.average.WeightedAverage()
for batch_id, data in enumerate(test_reader()):
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([-1, 1])
-exe.run(inference_program,
+outs = exe.run(inference_program,
feed={"pixel": img_data,
-"label": y_data})
+"label": y_data},
+fetch_list=[batch_acc, batch_size])
+test_pass_acc.add(value=np.array(outs[0]), weight=np.array(outs[1]))
-return accuracy.eval(exe)
+return test_pass_acc.eval()
def train_loop(exe, trainer_prog):
iters = 0
ts = time.time()
+train_pass_acc = fluid.average.WeightedAverage()
for pass_id in range(args.num_passes):
# train
start_time = time.time()
num_samples = 0
-accuracy.reset(exe)
+train_pass_acc.reset()
with profiler.profiler("CPU", 'total') as prof:
for batch_id, data in enumerate(train_reader()):
ts = time.time()
@@ -187,13 +191,14 @@ def main():
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
y_data = y_data.reshape([-1, 1])
-loss, acc = exe.run(
+loss, acc, b_size = exe.run(
trainer_prog,
feed={"pixel": img_data,
"label": y_data},
-fetch_list=[avg_cost] + accuracy.metrics)
+fetch_list=[avg_cost, batch_acc, batch_size])
iters += 1
num_samples += len(data)
+train_pass_acc.add(value=acc, weight=b_size)
print(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s"
% (pass_id, iters, loss, acc,
@@ -201,7 +206,7 @@ def main():
) # The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed = time.time() - start_time
-pass_train_acc = accuracy.eval(exe)
+pass_train_acc = train_pass_acc.eval()
pass_test_acc = test(exe)
print(
"Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n"
......
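The benchmark change above drops the deprecated fluid.evaluator.Accuracy and instead fetches a per-batch accuracy plus the batch size, feeding both into fluid.average.WeightedAverage. A minimal C++ sketch of the bookkeeping that accumulator is assumed to perform; the class name, members, and sample numbers are illustrative, not Paddle API:

#include <cassert>

// Hypothetical mirror of fluid.average.WeightedAverage: accumulate per-batch
// accuracies weighted by batch size, then report the pass-level average.
class WeightedAverage {
 public:
  void Reset() { weighted_sum_ = 0.0; weight_sum_ = 0.0; }
  void Add(double value, double weight) {
    weighted_sum_ += value * weight;
    weight_sum_ += weight;
  }
  double Eval() const {
    assert(weight_sum_ > 0.0 && "Eval() called before any Add()");
    return weighted_sum_ / weight_sum_;
  }

 private:
  double weighted_sum_ = 0.0;
  double weight_sum_ = 0.0;
};

int main() {
  WeightedAverage pass_acc;
  pass_acc.Reset();
  pass_acc.Add(/*value=*/0.50, /*weight=*/128);  // batch 1: acc 0.50 over 128 samples
  pass_acc.Add(/*value=*/0.75, /*weight=*/64);   // batch 2: acc 0.75 over 64 samples
  // (0.50*128 + 0.75*64) / 192 = 0.5833...
  assert(pass_acc.Eval() > 0.58 && pass_acc.Eval() < 0.59);
  return 0;
}

The pass-level accuracy is the sample-weighted mean of the per-batch accuracies, which is why the fetch list in the script now also returns batch_size.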
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF(MOBILE_INFERENCE)
return()
ENDIF()
include (ExternalProject)
# NOTE: snappy is needed when linking with recordio
SET(SNAPPY_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy)
SET(SNAPPY_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy)
SET(SNAPPY_INCLUDE_DIR "${SNAPPY_INSTALL_DIR}/include/" CACHE PATH "snappy include directory." FORCE)
ExternalProject_Add(
extern_snappy
GIT_REPOSITORY "https://github.com/google/snappy"
GIT_TAG "1.1.7"
PREFIX ${SNAPPY_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF
-DSNAPPY_BUILD_TESTS:BOOL=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${SNAPPY_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
BUILD_COMMAND make -j8
INSTALL_COMMAND make install
)
add_library(snappy STATIC IMPORTED GLOBAL)
set_property(TARGET snappy PROPERTY IMPORTED_LOCATION
"${SNAPPY_INSTALL_DIR}/lib/libsnappy.a")
include_directories(${SNAPPY_INCLUDE_DIR})
add_dependencies(snappy extern_snappy)
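The NOTE above says snappy is pulled in so recordio can link against it. For orientation, a minimal round trip through the snappy C++ API that the imported `snappy` target provides; compiling against ${SNAPPY_INCLUDE_DIR} and linking the static library installed above are assumptions taken from this cmake file:

#include <cassert>
#include <string>

#include <snappy.h>  // header installed by extern_snappy above (assumed include path)

int main() {
  const std::string original(1000, 'x');  // highly compressible input

  // Compress into a std::string.
  std::string compressed;
  snappy::Compress(original.data(), original.size(), &compressed);
  assert(snappy::IsValidCompressedBuffer(compressed.data(), compressed.size()));

  // Decompress and check we got the original bytes back.
  std::string restored;
  bool ok = snappy::Uncompress(compressed.data(), compressed.size(), &restored);
  assert(ok && restored == original);
  return 0;
}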
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
IF(MOBILE_INFERENCE)
return()
ENDIF()
include (ExternalProject)
# NOTE: snappy is needed when linking with recordio
SET(SNAPPYSTREAM_SOURCES_DIR ${THIRD_PARTY_PATH}/snappy_stream)
SET(SNAPPYSTREAM_INSTALL_DIR ${THIRD_PARTY_PATH}/install/snappy_stream)
SET(SNAPPYSTREAM_INCLUDE_DIR "${SNAPPYSTREAM_INSTALL_DIR}/include/" CACHE PATH "snappy stream include directory." FORCE)
ExternalProject_Add(
extern_snappystream
GIT_REPOSITORY "https://github.com/hoxnox/snappystream.git"
GIT_TAG "0.2.8"
PREFIX ${SNAPPYSTREAM_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_INSTALL_PREFIX=${SNAPPY_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${SNAPPY_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DSNAPPY_ROOT=${SNAPPY_INSTALL_DIR}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=${SNAPPYSTREAM_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPYSTREAM_INSTALL_DIR}/lib
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
BUILD_COMMAND make -j8
INSTALL_COMMAND make install
DEPENDS snappy
)
add_library(snappystream STATIC IMPORTED GLOBAL)
set_property(TARGET snappystream PROPERTY IMPORTED_LOCATION
"${SNAPPYSTREAM_INSTALL_DIR}/lib/libsnappystream.a")
include_directories(${SNAPPYSTREAM_INCLUDE_DIR})
add_dependencies(snappystream extern_snappystream)
-add_subdirectory(api)
add_subdirectory(v2)
@@ -47,3 +47,5 @@ sphinx_add_target(paddle_docs_cn
${SPHINX_CACHE_DIR_CN}
${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_CN})
+add_subdirectory(api)
@@ -8,7 +8,7 @@ set(SPHINX_CACHE_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/_doctrees")
set(SPHINX_HTML_DIR_EN "${CMAKE_CURRENT_BINARY_DIR}/en/html")
configure_file(
-"${CMAKE_CURRENT_SOURCE_DIR}/../templates/conf.py.en.in"
+"${CMAKE_CURRENT_SOURCE_DIR}/../../templates/conf.py.en.in"
"${BINARY_BUILD_DIR_EN}/conf.py"
@ONLY)
......
@@ -5,7 +5,7 @@ API
:maxdepth: 1
overview.rst
-v2/model_configs.rst
-v2/data.rst
-v2/run_logic.rst
+model_configs.rst
+data.rst
+run_logic.rst
fluid/index.rst
@@ -5,3 +5,4 @@ add_subdirectory(operators)
add_subdirectory(pybind)
add_subdirectory(inference)
add_subdirectory(string)
+add_subdirectory(recordio)
@@ -5,14 +5,14 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-if (WITH_GPU)
+if(WITH_GPU)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto)
else()
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto)
-endif ()
+endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
-if (WITH_GPU)
+if(WITH_GPU)
nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
else()
cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
@@ -39,8 +39,13 @@ cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
nv_test(data_device_transform_test SRCS data_device_transform_test.cu
DEPS operator op_registry init math_function)
-cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
-cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
+if(WITH_GPU)
+nv_library(data_type_transform SRCS data_type_transform.cu DEPS tensor)
+nv_test(data_type_transform_test SRCS data_type_transform_test.cc data_type_transform_test.cu DEPS data_type_transform)
+else()
+cc_library(data_type_transform SRCS data_type_transform.cc DEPS tensor)
+cc_test(data_type_transform_test SRCS data_type_transform_test.cc DEPS data_type_transform)
+endif()
cc_library(data_layout_transform SRCS data_layout_transform.cc DEPS tensor math_function)
cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_layout_transform)
......
@@ -28,24 +28,19 @@ class Channel {
virtual bool Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
+virtual void Lock() = 0;
+virtual void Unlock() = 0;
virtual void Close() = 0;
virtual ~Channel() {}
};
// Forward declaration of channel implementations.
-namespace details {
template <typename T>
-class Buffered;
-template <typename T>
-class UnBuffered;
-} // namespace details
+class ChannelImpl;
template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) {
-if (buffer_size > 0) {
-return new details::Buffered<T>(buffer_size);
-}
-return new details::UnBuffered<T>();
+return new ChannelImpl<T>(buffer_size);
}
template <typename T>
@@ -89,6 +84,19 @@ class ChannelHolder {
if (IsInitialized()) holder_->Close();
}
+size_t Cap() {
+if (IsInitialized()) return holder_->Cap();
+return -1;
+}
+void Lock() {
+if (IsInitialized()) holder_->Lock();
+}
+void Unlock() {
+if (IsInitialized()) holder_->Unlock();
+}
inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() {
@@ -106,6 +114,9 @@ class ChannelHolder {
virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0;
virtual void Close() = 0;
+virtual void Lock() = 0;
+virtual void Unlock() = 0;
+virtual size_t Cap() = 0;
};
template <typename T>
@@ -115,11 +126,28 @@ class ChannelHolder {
}
virtual const std::type_index Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual void Close() {
if (channel_) channel_->Close();
}
+virtual size_t Cap() {
+if (channel_)
+return channel_->Cap();
+else
+return -1;
+}
+virtual void Lock() {
+if (channel_) channel_->Lock();
+}
+virtual void Unlock() {
+if (channel_) channel_->Unlock();
+}
std::unique_ptr<Channel<T>> channel_;
const std::type_index type_;
};
@@ -131,5 +159,4 @@ class ChannelHolder {
} // namespace framework
} // namespace paddle
-#include "paddle/fluid/framework/details/buffered_channel.h"
-#include "paddle/fluid/framework/details/unbuffered_channel.h"
+#include "paddle/fluid/framework/channel_impl.h"
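After this change MakeChannel always returns a ChannelImpl behind the Channel<T> interface, regardless of buffer_size. A small sketch of the calling pattern, consistent with how channel_test.cc below uses the API; the helper function is illustrative, not part of the framework:

#include "paddle/fluid/framework/channel.h"

using paddle::framework::Channel;
using paddle::framework::CloseChannel;
using paddle::framework::MakeChannel;

void BufferedChannelRoundTrip() {
  // A capacity of 4 gives a buffered channel; Send() does not block
  // until the buffer is full.
  Channel<int>* ch = MakeChannel<int>(4);

  for (int i = 0; i < 4; ++i) {
    int value = i;
    ch->Send(&value);  // returns true while the channel is open
  }

  int received = 0;
  while (ch->Receive(&received)) {
    // Values come back in FIFO order: 0, 1, 2, 3.
    if (received == 3) break;  // stop before an extra Receive would block
  }

  CloseChannel(ch);  // Send() on a closed channel returns false
  delete ch;         // callers own the pointer, as in the tests below
}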
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <atomic>
#include <condition_variable>
#include <deque>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T>
class ChannelImpl : public paddle::framework::Channel<T> {
friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T> *);
public:
virtual bool Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
virtual void Unlock();
virtual void Close();
ChannelImpl(size_t);
virtual ~ChannelImpl();
private:
struct QueueMessage {
T *data;
std::condition_variable_any cond;
bool chan_closed = false;
bool completed = false;
QueueMessage(T *item) : data(item) {}
void Wait(std::unique_lock<std::recursive_mutex> &lock) {
cond.wait(lock, [this]() { return completed; });
}
void Notify() {
completed = true;
cond.notify_all();
}
};
bool send_return(bool value) {
send_ctr--;
destructor_cond_.notify_all();
return value;
}
bool recv_return(bool value) {
recv_ctr--;
destructor_cond_.notify_all();
return value;
}
size_t cap_;
std::recursive_mutex mu_;
bool closed_;
std::deque<T> buf_;
std::deque<std::shared_ptr<QueueMessage>> recvq;
std::deque<std::shared_ptr<QueueMessage>> sendq;
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
std::condition_variable_any destructor_cond_;
};
template <typename T>
ChannelImpl<T>::ChannelImpl(size_t capacity)
: cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {
PADDLE_ENFORCE_GE(capacity, 0);
}
template <typename T>
bool ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed, do nothing
if (closed_) {
lock.unlock();
// TODO(abhinavarora) Should panic on closed channel
return send_return(false);
}
// If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any
if (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
// Do the data transfer
*(m->data) = std::move(*item);
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return send_return(true);
}
// Unbuffered channel will always bypass this
// If buffered channel has space in buffer,
// write the element to the buffer.
if (buf_.size() < cap_) {
// Copy to buffer
buf_.push_back(std::move(*item));
// Release lock and return true
lock.unlock();
return send_return(true);
}
// Block on channel, because some receiver will complete
// the operation for us
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel
return send_return(!m->chan_closed);
}
template <typename T>
bool ChannelImpl<T>::Receive(T *item) {
recv_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed and buffer is empty or
// channel is unbuffered
if (closed_ && buf_.empty()) {
lock.unlock();
return recv_return(false);
}
// If there is a sender, directly receive the value we want
// from the sender, bypassing the channel buffer if any
if (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
// Do the data transfer
*item = std::move(*(m->data));
// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return recv_return(true);
}
// If this is a buffered channel and there are items in buffer
if (buf_.size() > 0) {
// Directly read from buffer
*item = std::move(buf_.front());
buf_.pop_front();
// Release lock and return true
lock.unlock();
return recv_return(true);
}
// No sender available, block on this channel
// Some sender will complete the operation for us
auto m = std::make_shared<QueueMessage>(item);
recvq.push_back(m);
m->Wait(lock);
return recv_return(!m->chan_closed);
}
template <typename T>
void ChannelImpl<T>::Lock() {
mu_.lock();
}
template <typename T>
void ChannelImpl<T>::Unlock() {
mu_.unlock();
}
template <typename T>
void ChannelImpl<T>::Close() {
std::unique_lock<std::recursive_mutex> lock{mu_};
if (closed_) {
// TODO(abhinavarora): closing an already closed channel should panic
lock.unlock();
return;
}
closed_ = true;
// Empty the readers
while (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
m->chan_closed = true;
m->Notify();
}
// Empty the senders
while (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
m->chan_closed = true;
m->Notify();
}
}
template <typename T>
ChannelImpl<T>::~ChannelImpl() {
Close();
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
std::unique_lock<std::recursive_mutex> lock{mu_};
destructor_cond_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
} // namespace framework
} // namespace paddle
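For an unbuffered channel (cap_ == 0) the implementation above gives Go-style rendezvous semantics: a Send blocks until a Receive takes the value directly through recvq, and vice versa. A two-thread sketch of that behaviour using only the Channel API; the scheduling comments are assumptions, not guarantees:

#include <thread>

#include "paddle/fluid/framework/channel.h"

using paddle::framework::Channel;
using paddle::framework::CloseChannel;
using paddle::framework::MakeChannel;

void UnbufferedRendezvous() {
  Channel<int>* ch = MakeChannel<int>(0);  // cap_ == 0: no buffer

  std::thread sender([ch]() {
    int value = 42;
    // Blocks here until the Receive below arrives; the value is moved
    // straight into the receiver's slot (the recvq fast path in Send()).
    ch->Send(&value);
  });

  int got = 0;
  ch->Receive(&got);  // pairs with the blocked Send above; got == 42

  sender.join();
  CloseChannel(ch);  // further Send()/Receive() calls now return false
  delete ch;
}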
@@ -23,8 +23,19 @@ using paddle::framework::Channel;
using paddle::framework::ChannelHolder;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
-using paddle::framework::details::Buffered;
-using paddle::framework::details::UnBuffered;
+TEST(Channel, ChannelCapacityTest) {
+const size_t buffer_size = 10;
+auto ch = MakeChannel<size_t>(buffer_size);
+EXPECT_EQ(ch->Cap(), buffer_size);
+CloseChannel(ch);
+delete ch;
+ch = MakeChannel<size_t>(0);
+EXPECT_EQ(ch->Cap(), 0U);
+CloseChannel(ch);
+delete ch;
+}
void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
unsigned sum_send = 0;
@@ -35,38 +46,17 @@ void RecevingOrderEqualToSendingOrder(Channel<int> *ch) {
}
});
for (int i = 0; i < 5; i++) {
-int recv;
+int recv = 999;
EXPECT_EQ(ch->Receive(&recv), true);
EXPECT_EQ(recv, i);
}
+std::this_thread::sleep_for(std::chrono::milliseconds(200));
CloseChannel(ch);
t.join();
EXPECT_EQ(sum_send, 10U);
delete ch;
}
-TEST(Channel, MakeAndClose) {
-using paddle::framework::details::Buffered;
-using paddle::framework::details::UnBuffered;
-{
-// MakeChannel should return a buffered channel is buffer_size > 0.
-auto ch = MakeChannel<int>(10);
-EXPECT_NE(dynamic_cast<Buffered<int> *>(ch), nullptr);
-EXPECT_EQ(dynamic_cast<UnBuffered<int> *>(ch), nullptr);
-CloseChannel(ch);
-delete ch;
-}
-{
-// MakeChannel should return an un-buffered channel is buffer_size = 0.
-auto ch = MakeChannel<int>(0);
-EXPECT_EQ(dynamic_cast<Buffered<int> *>(ch), nullptr);
-EXPECT_NE(dynamic_cast<UnBuffered<int> *>(ch), nullptr);
-CloseChannel(ch);
-delete ch;
-}
-}
TEST(Channel, SufficientBufferSizeDoesntBlock) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
@@ -166,7 +156,6 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
const size_t buffer_size = 10;
auto ch = MakeChannel<size_t>(buffer_size);
-size_t sum = 0;
std::thread t([&]() {
// Try to write more than buffer size.
for (size_t i = 0; i < 2 * buffer_size; ++i) {
@@ -174,12 +163,9 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
EXPECT_EQ(ch->Send(&i), true); // should block after 10 iterations
else
EXPECT_EQ(ch->Send(&i), false);
-sum += i;
}
});
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
-EXPECT_EQ(sum, 45U);
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
CloseChannel(ch);
t.join();
delete ch;
@@ -211,7 +197,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
},
&thread_ended[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all the threads are blocked
for (size_t i = 0; i < num_threads; i++) {
@@ -222,7 +208,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
// This should unblock all receivers
CloseChannel(ch);
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
@@ -232,10 +218,7 @@ void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
for (size_t i = 0; i < num_threads; i++) t[i].join();
}
-void ChannelCloseUnblocksSendersTest(Channel<int> *ch) {
-using paddle::framework::details::Buffered;
-using paddle::framework::details::UnBuffered;
+void ChannelCloseUnblocksSendersTest(Channel<int> *ch, bool isBuffered) {
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
@@ -253,9 +236,9 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch) {
},
&thread_ended[i], &send_success[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
-if (dynamic_cast<Buffered<int> *>(ch)) {
+if (isBuffered) {
// If ch is Buffered, atleast 4 threads must be blocked.
int ct = 0;
for (size_t i = 0; i < num_threads; i++) {
@@ -272,14 +255,14 @@ void ChannelCloseUnblocksSendersTest(Channel<int> *ch) {
// This should unblock all senders
CloseChannel(ch);
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], true);
}
-if (dynamic_cast<Buffered<int> *>(ch)) {
+if (isBuffered) {
// Verify that only 1 send was successful
int ct = 0;
for (size_t i = 0; i < num_threads; i++) {
@@ -304,7 +287,7 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
// any senders waiting for channel to have write space
TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
auto ch = MakeChannel<int>(1);
-ChannelCloseUnblocksSendersTest(ch);
+ChannelCloseUnblocksSendersTest(ch, true);
delete ch;
}
@@ -320,7 +303,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) {
// unblocks any senders waiting for senders
TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) {
auto ch = MakeChannel<int>(0);
-ChannelCloseUnblocksReceiversTest(ch);
+ChannelCloseUnblocksSendersTest(ch, false);
delete ch;
}
@@ -342,7 +325,7 @@ TEST(Channel, UnbufferedLessReceiveMoreSendTest) {
ch->Receive(&recv);
EXPECT_EQ(recv, i);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
EXPECT_EQ(sum_send, 3U);
CloseChannel(ch);
@@ -368,7 +351,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
ch->Send(&i);
sum_send += i;
}
-std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
EXPECT_EQ(sum_send, 10U);
EXPECT_EQ(sum_receive, 10U);
// send three more elements
@@ -386,7 +369,7 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
// This tests that destroying a channel unblocks
// any senders waiting for channel to have write space
-void ChannelDestroyUnblockSenders(Channel<int> *ch) {
+void ChannelDestroyUnblockSenders(Channel<int> *ch, bool isBuffered) {
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
@@ -405,11 +388,9 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch) {
&thread_ended[i], &send_success[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
-bool is_buffered_channel = false;
-if (dynamic_cast<Buffered<int> *>(ch)) is_buffered_channel = true;
-if (is_buffered_channel) {
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
+if (isBuffered) {
// If channel is buffered, verify that atleast 4 threads are blocked
int ct = 0;
for (size_t i = 0; i < num_threads; i++) {
@@ -432,13 +413,13 @@ void ChannelDestroyUnblockSenders(Channel<int> *ch) {
EXPECT_EQ(thread_ended[i], true);
}
-// Count number of successfuld sends
+// Count number of successful sends
int ct = 0;
for (size_t i = 0; i < num_threads; i++) {
if (send_success[i]) ct++;
}
-if (is_buffered_channel) {
+if (isBuffered) {
// Only 1 send must be successful
EXPECT_EQ(ct, 1);
} else {
@@ -495,7 +476,7 @@ TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) {
TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) {
size_t buffer_size = 1;
auto ch = MakeChannel<int>(buffer_size);
-ChannelDestroyUnblockSenders(ch);
+ChannelDestroyUnblockSenders(ch, true);
}
// This tests that destroying an unbuffered channel also unblocks
@@ -507,7 +488,20 @@ TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) {
TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
auto ch = MakeChannel<int>(0);
-ChannelDestroyUnblockSenders(ch);
+ChannelDestroyUnblockSenders(ch, false);
+}
+TEST(ChannelHolder, ChannelHolderCapacityTest) {
+const size_t buffer_size = 10;
+ChannelHolder *ch = new ChannelHolder();
+ch->Reset<int>(buffer_size);
+EXPECT_EQ(ch->Cap(), buffer_size);
+delete ch;
+ch = new ChannelHolder();
+ch->Reset<int>(0);
+EXPECT_EQ(ch->Cap(), 0U);
+delete ch;
}
void ChannelHolderSendReceive(ChannelHolder *ch) {
@@ -641,7 +635,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
},
&thread_ended[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all the threads are blocked
for (size_t i = 0; i < num_threads; i++) {
@@ -652,7 +646,7 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
// This should unblock all receivers
ch->close();
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
@@ -663,9 +657,6 @@ void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
}
void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
-using paddle::framework::details::Buffered;
-using paddle::framework::details::UnBuffered;
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
@@ -683,7 +674,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
},
&thread_ended[i], &send_success[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
if (isBuffered) {
// If ch is Buffered, atleast 4 threads must be blocked.
@@ -702,7 +693,7 @@ void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
// This should unblock all senders
ch->close();
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
@@ -775,7 +766,7 @@ void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
&thread_ended[i], &send_success[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait 0.2 sec
if (isBuffered) {
// If channel is buffered, verify that atleast 4 threads are blocked
int ct = 0;
@@ -836,7 +827,7 @@ void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) {
},
&thread_ended[i]);
}
-std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
// Verify that all threads are blocked
for (size_t i = 0; i < num_threads; i++) {
......
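The new ChannelHolderCapacityTest above exercises ChannelHolder, the type-erased wrapper that now also forwards Cap/Lock/Unlock to the held channel. A minimal sketch restricted to the members visible in this diff (Reset, Cap, IsInitialized, close); the test name is illustrative:

#include "gtest/gtest.h"
#include "paddle/fluid/framework/channel.h"

using paddle::framework::ChannelHolder;

TEST(ChannelHolder, HolderWrapsATypedChannel) {
  ChannelHolder holder;
  EXPECT_FALSE(holder.IsInitialized());  // nothing held yet

  holder.Reset<int>(8);  // now holds a Channel<int> with capacity 8
  EXPECT_TRUE(holder.IsInitialized());
  EXPECT_EQ(holder.Cap(), 8U);

  holder.close();  // closes the underlying channel
}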
@@ -42,6 +42,7 @@ void DataTransform(const OpKernelType& expected_kernel_type,
PassTensorData(&out, &in);
}
+// do data type transform
if (expected_kernel_type.data_type_ != kernel_type_for_var.data_type_) {
TransDataType(kernel_type_for_var, expected_kernel_type, in, &out);
transformed = true;
......
@@ -16,13 +16,16 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
inline proto::VarType::Type ToDataType(std::type_index type) {
using namespace paddle::framework::proto;
-if (typeid(float).hash_code() == type.hash_code()) {
+if (typeid(platform::float16).hash_code() == type.hash_code()) {
+return proto::VarType::FP16;
+} else if (typeid(float).hash_code() == type.hash_code()) {
return proto::VarType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return proto::VarType::FP64;
@@ -40,6 +43,8 @@ inline proto::VarType::Type ToDataType(std::type_index type) {
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
using namespace paddle::framework::proto;
switch (type) {
+case proto::VarType::FP16:
+return typeid(platform::float16);
case proto::VarType::FP32:
return typeid(float);
case proto::VarType::FP64:
@@ -59,6 +64,9 @@ template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
using namespace paddle::framework::proto;
switch (type) {
+case proto::VarType::FP16:
+visitor.template operator()<platform::float16>();
+break;
case proto::VarType::FP32:
visitor.template operator()<float>();
break;
......
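VisitDataType switches on the runtime proto::VarType::Type and instantiates the visitor's templated operator() with the matching C++ type, now including platform::float16. CastDataType in data_type_transform.cc below is the real visitor; the struct here is a stripped-down hypothetical one, only to show the shape of the contract:

#include <iostream>
#include <typeinfo>

#include "paddle/fluid/framework/data_type.h"

// A hypothetical visitor: VisitDataType will call operator()<T>() with T bound
// to the C++ type that corresponds to the runtime VarType value.
struct PrintCppTypeVisitor {
  template <typename T>
  void operator()() const {
    std::cout << "dispatched to " << typeid(T).name() << "\n";
  }
};

void Demo() {
  namespace fw = paddle::framework;
  // For FP16 this instantiates operator()<paddle::platform::float16>(),
  // for FP32 operator()<float>(), and so on, per the switch above.
  fw::VisitDataType(fw::proto::VarType::FP16, PrintCppTypeVisitor());
  fw::VisitDataType(fw::proto::VarType::FP32, PrintCppTypeVisitor());
}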
@@ -47,9 +47,15 @@ struct CastDataType {
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans(*context, in_begin, in_end, out_begin,
CastDataTypeFunctor<InType, OutType>());
+#ifdef __NVCC__
+} else if (platform::is_gpu_place(in_.place())) {
+platform::Transform<platform::CUDADeviceContext> trans;
+auto* context = static_cast<const platform::CUDADeviceContext*>(ctx_);
+trans(*context, in_begin, in_end, out_begin,
+CastDataTypeFunctor<InType, OutType>());
+#endif
} else {
-// TODO(dzhwinter): enhance Copy CPU<->GPU with different data type?
-PADDLE_THROW("Unsupport CPU <-> GPU!");
+PADDLE_THROW("Unsupported place!");
}
}
};
@@ -65,6 +71,10 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
auto ctx = pool.Get(in.place());
switch (src_type) {
+case proto::VarType::FP16:
+framework::VisitDataType(dst_type,
+CastDataType<platform::float16>(in, out, ctx));
+break;
case proto::VarType::FP32:
framework::VisitDataType(dst_type, CastDataType<float>(in, out, ctx));
break;
......
data_type_transform.cc
\ No newline at end of file
@@ -22,32 +22,145 @@ TEST(DataTypeTransform, CPUTransform) {
auto place = CPUPlace();
-Tensor in;
-Tensor out;
-float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
-int data_number = 2 * 3;
-for (int i = 0; i < data_number; ++i) {
-ptr[i] = i / 3;
-}
+auto kernel_fp16 = OpKernelType(proto::VarType::FP16, place,
+DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp32 = OpKernelType(proto::VarType::FP32, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::VarType::FP64, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::VarType::INT32, place,
DataLayout::kAnyLayout, LibraryType::kPlain);
+auto kernel_int64 = OpKernelType(proto::VarType::INT64, place,
+DataLayout::kAnyLayout, LibraryType::kPlain);
+auto kernel_bool = OpKernelType(proto::VarType::BOOL, place,
+DataLayout::kAnyLayout, LibraryType::kPlain);
-TransDataType(kernel_fp32, kernel_fp64, in, &out);
-double* out_data_double = out.data<double>();
-for (int i = 0; i < data_number; ++i) {
-ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
-}
-TransDataType(kernel_fp32, kernel_int32, in, &out);
-int* out_data_int = out.data<int>();
-for (int i = 0; i < data_number; ++i) {
-ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
-}
+// data type transform from float32
+{
+Tensor in;
+Tensor out;
+float* ptr = in.mutable_data<float>(make_ddim({2, 3}), place);
+int data_number = 2 * 3;
+for (int i = 0; i < data_number; ++i) {
+ptr[i] = i / 3;
+}
+TransDataType(kernel_fp32, kernel_fp64, in, &out);
+double* out_data_double = out.data<double>();
+for (int i = 0; i < data_number; ++i) {
+ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3));
+}
+TransDataType(kernel_fp32, kernel_int32, in, &out);
+int* out_data_int = out.data<int>();
+for (int i = 0; i < data_number; ++i) {
+ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3));
+}
+}
+// data type transform from/to float16
+{
+Tensor in;
+Tensor out;
float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), place);
int data_number = 2 * 3;
for (int i = 0; i < data_number; ++i) {
ptr[i] = i;
}
// transform from float16 to other data types
TransDataType(kernel_fp16, kernel_fp32, in, &out);
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_fp64, in, &out);
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int32, in, &out);
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int64, in, &out);
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_bool, in, &out);
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to float16
float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
TransDataType(kernel_fp32, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
}
// transform double to float16
double* in_data_double = in.mutable_data<double>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
TransDataType(kernel_fp64, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
}
// transform int to float16
int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
TransDataType(kernel_int32, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
}
// transform int64 to float16
int64_t* in_data_int64 = in.mutable_data<int64_t>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
TransDataType(kernel_int64, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
}
// transform bool to float16
bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
TransDataType(kernel_bool, kernel_fp16, in, &out);
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "gtest/gtest.h"
TEST(DataTypeTransform, GPUTransform) {
using namespace paddle::framework;
using namespace paddle::platform;
auto cpu_place = CPUPlace();
auto gpu_place = CUDAPlace(0);
CUDADeviceContext context(gpu_place);
auto kernel_fp16 = OpKernelType(proto::VarType::FP16, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp32 = OpKernelType(proto::VarType::FP32, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_fp64 = OpKernelType(proto::VarType::FP64, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int32 = OpKernelType(proto::VarType::INT32, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_int64 = OpKernelType(proto::VarType::INT64, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
auto kernel_bool = OpKernelType(proto::VarType::BOOL, gpu_place,
DataLayout::kAnyLayout, LibraryType::kPlain);
// data type transform from float32
{
Tensor in;
Tensor in_gpu;
Tensor out_gpu;
Tensor out;
float* in_ptr = in.mutable_data<float>(make_ddim({2, 3}), cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
int data_number = sizeof(arr) / sizeof(arr[0]);
memcpy(in_ptr, arr, sizeof(arr));
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_fp32, kernel_fp64, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(arr[i]));
}
TransDataType(kernel_fp32, kernel_int32, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(arr[i]));
}
}
// data type transform from/to float16
{
Tensor in;
Tensor in_gpu;
Tensor out_gpu;
Tensor out;
float16* ptr = in.mutable_data<float16>(make_ddim({2, 3}), cpu_place);
float16 arr[6] = {float16(0), float16(1), float16(2),
float16(3), float16(4), float16(5)};
int data_number = sizeof(arr) / sizeof(arr[0]);
memcpy(ptr, arr, sizeof(arr));
TensorCopy(in, gpu_place, context, &in_gpu);
// transform from float16 to other data types
TransDataType(kernel_fp16, kernel_fp32, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
float* out_data_float = out.data<float>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_fp64, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
double* out_data_double = out.data<double>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int32, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
int* out_data_int = out.data<int>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_int64, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
int64_t* out_data_int64 = out.data<int64_t>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
}
TransDataType(kernel_fp16, kernel_bool, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
bool* out_data_bool = out.data<bool>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
}
// transform float to float16
float* in_data_float = in.mutable_data<float>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_float[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_fp32, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x);
}
// transform double to float16
double* in_data_double =
in.mutable_data<double>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_double[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_fp64, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x);
}
// transform int to float16
int* in_data_int = in.mutable_data<int>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_int[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_int32, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x);
}
// transform int64 to float16
int64_t* in_data_int64 =
in.mutable_data<int64_t>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_int64[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_int64, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x);
}
// transform bool to float16
bool* in_data_bool = in.mutable_data<bool>(make_ddim({2, 3}), cpu_place);
for (int i = 0; i < data_number; ++i) {
in_data_bool[i] = i;
}
TensorCopy(in, gpu_place, context, &in_gpu);
TransDataType(kernel_bool, kernel_fp16, in_gpu, &out_gpu);
TensorCopy(out_gpu, cpu_place, context, &out);
context.Wait();
ptr = out.data<float16>();
for (int i = 0; i < data_number; ++i) {
ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x);
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <deque>
#include <mutex>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace details {
// Four properties of a Buffered Channel:
// - A send to a full channel blocks until a receive from the
// channel occurs or the channel is closed.
// - A receive from an empty channel blocks until a send to the
// channel occurs or the channel is closed.
// - A send to a closed channel returns false immediately.
// - A receive from a closed channel returns false immediately.
template <typename T>
class Buffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return cap_; }
virtual void Close();
virtual ~Buffered();
private:
size_t cap_;
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
std::condition_variable destructor_cond_var_;
std::deque<T> channel_;
std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
}
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
template <typename T>
bool Buffered<T>::Send(T* item) {
bool ret = false;
if (closed_) {
return ret;
}
send_ctr++;
std::unique_lock<std::mutex> lock(mu_);
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
if (!closed_) {
channel_.push_back(std::move(*item));
lock.unlock();
empty_cond_var_.notify_one();
ret = true;
}
send_ctr--;
destructor_cond_var_.notify_one();
return ret;
}
template <typename T>
bool Buffered<T>::Receive(T* item) {
bool ret = false;
// Once the channel has been closed and all data has been consumed,
// just return false. Don't even try acquiring the mutex.
if (closed_ && channel_.empty()) {
return false;
}
recv_ctr++;
std::unique_lock<std::mutex> lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
if (!channel_.empty()) {
*item = std::move(channel_.front());
channel_.pop_front();
full_cond_var_.notify_one();
ret = true;
}
recv_ctr--;
destructor_cond_var_.notify_one();
return ret;
}
template <typename T>
void Buffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
NotifyAllParticipants(&lock);
}
template <typename T>
Buffered<T>::~Buffered() {
std::unique_lock<std::mutex> lock(mu_);
closed_ = true;
channel_.clear();
NotifyAllParticipants(&lock);
// The destructor must wait for all readers and writers to complete their tasks.
// The channel has been closed, so we will not accept new readers and writers.
lock.lock();
destructor_cond_var_.wait(
lock, [this]() { return send_ctr == 0 && recv_ctr == 0; });
}
template <typename T>
void Buffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
full_cond_var_.notify_all();
empty_cond_var_.notify_all();
}
} // namespace details
} // namespace framework
} // namespace paddle
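// A minimal usage sketch for the buffered channel above (illustrative only;
// it assumes the MakeChannel/CloseChannel factory functions declared as
// friends are available from "paddle/fluid/framework/channel.h", and
// ownership/cleanup handling is omitted).
#include <thread>
#include "paddle/fluid/framework/channel.h"
void BufferedChannelUsageSketch() {
  auto* ch = paddle::framework::MakeChannel<int>(4);  // buffered, capacity 4
  std::thread producer([ch]() {
    for (int i = 0; i < 8; ++i) {
      int v = i;
      if (!ch->Send(&v)) break;  // Send() returns false once the channel is closed
    }
    paddle::framework::CloseChannel(ch);
  });
  int out;
  while (ch->Receive(&out)) {
    // Receive() keeps draining buffered items after Close() and returns
    // false only when the channel is both closed and empty.
  }
  producer.join();
}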
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <condition_variable>
#include <mutex>
#include "paddle/fluid/framework/channel.h"
namespace paddle {
namespace framework {
namespace details {
// Four properties of an UnBuffered Channel:
// - A send to a channel blocks until a receive from the
// channel occurs or the channel is closed.
// - A receive from a channel blocks until a send to the
// channel occurs or the channel is closed.
// - A send to a closed channel returns false immediately.
// - A receive from a closed channel returns false immediately.
template <typename T>
class UnBuffered : public paddle::framework::Channel<T> {
friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T>*);
public:
virtual bool Send(T*);
virtual bool Receive(T*);
virtual size_t Cap() { return 0; }
virtual void Close();
virtual ~UnBuffered();
private:
std::mutex mu_ch_;
// Mutexes that make additional readers and writers wait until the current
// reader and writer complete their exchange
std::recursive_mutex mu_read_, mu_write_;
// reader_found_ is set true when a reader is ready to accept data
// writer_found_ is set true when a writer is ready to send data
// A transaction occurs only when both are true
std::atomic<bool> reader_found_{false}, writer_found_{false};
std::condition_variable cv_channel_;
std::condition_variable_any cv_reader_, cv_writer_, cv_destructor_;
T* item{nullptr};
std::atomic<bool> closed_{false};
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
UnBuffered() : closed_(false) {}
void NotifyAllParticipants(std::unique_lock<std::mutex>*);
};
// This function implements the concept of how data should
// be sent from a writer to a reader.
template <typename T>
bool UnBuffered<T>::Send(T* data) {
bool ret = false;
if (closed_) {
return ret;
}
send_ctr++;
// Prevent other writers from entering
std::unique_lock<std::recursive_mutex> writer_lock(mu_write_);
writer_found_ = true;
std::unique_lock<std::recursive_mutex> cv_lock(mu_write_);
// If the writer comes first, it should wait until a reader arrives
cv_writer_.wait(cv_lock,
[this]() { return reader_found_ == true || closed_; });
cv_reader_.notify_one();
if (!closed_) {
std::unique_lock<std::mutex> channel_lock(mu_ch_);
item = data;
channel_lock.unlock();
cv_channel_.notify_one();
channel_lock.lock();
cv_channel_.wait(channel_lock,
[this]() { return item == nullptr || closed_; });
ret = true;
}
writer_found_ = false;
send_ctr--;
cv_destructor_.notify_one();
return ret;
}
// This function implements the concept of how
// data that was sent by a writer is read from a reader.
template <typename T>
bool UnBuffered<T>::Receive(T* data) {
bool ret = false;
// If channel is closed, we don't even want any reader to enter.
// Unlike a buffered channel, an unbuffered channel does not allow
// readers to read after closing because there is no buffer to be consumed.
if (closed_) return ret;
recv_ctr++;
// Prevent other readers from entering
std::unique_lock<std::recursive_mutex> read_lock{mu_read_};
reader_found_ = true;
std::unique_lock<std::recursive_mutex> cv_lock{mu_read_};
// If the reader comes first, it should wait until a writer arrives
cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one();
if (!closed_) {
std::unique_lock<std::mutex> lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data
cv_channel_.wait(lock_ch, [this]() { return item != nullptr || closed_; });
if (!closed_) {
*data = std::move(*item);
item = nullptr;
lock_ch.unlock();
ret = true;
}
cv_channel_.notify_one();
}
reader_found_ = false;
recv_ctr--;
cv_destructor_.notify_one();
return ret;
}
// This function implements the sequence of events
// that take place once the channel is closed.
template <typename T>
void UnBuffered<T>::Close() {
if (closed_) {
return;
}
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
}
// This function implements the sequence of events
// that are executed once the object of an UnBuffered
// channel is destroyed.
template <typename T>
UnBuffered<T>::~UnBuffered() {
std::unique_lock<std::mutex> lock(mu_ch_);
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
lock.lock();
cv_destructor_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
// This function notifies all the readers, writers and
// the channel condition variables.
template <typename T>
void UnBuffered<T>::NotifyAllParticipants(std::unique_lock<std::mutex>* lock) {
lock->unlock();
cv_writer_.notify_all();
cv_channel_.notify_all();
cv_reader_.notify_all();
}
} // namespace details
} // namespace framework
} // namespace paddle
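// A minimal usage sketch for the unbuffered channel above (illustrative only;
// MakeChannel with a capacity of 0 is assumed to produce an UnBuffered
// channel, matching the Cap() == 0 convention in the class).
#include <thread>
#include "paddle/fluid/framework/channel.h"
void UnBufferedChannelUsageSketch() {
  auto* ch = paddle::framework::MakeChannel<int>(0);  // capacity 0 -> rendezvous
  std::thread receiver([ch]() {
    int v;
    ch->Receive(&v);  // blocks until a sender hands over a value
  });
  int payload = 42;
  ch->Send(&payload);  // blocks until the receiver above has taken the value
  receiver.join();
  paddle::framework::CloseChannel(ch);
}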
...@@ -25,135 +25,5 @@ DDim ReaderBase::shape(size_t idx) const { ...@@ -25,135 +25,5 @@ DDim ReaderBase::shape(size_t idx) const {
return shapes_[idx]; return shapes_[idx];
} }
void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
buffer_.clear();
buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimized.
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
out->clear();
if (!buffer_.empty()) {
std::swap(*out, buffer_[iteration_pos_++]);
}
// If buffer_ is empty, 'out' will be returned as an empty vector.
}
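// One possible replacement for the std::random_shuffle call flagged in the
// TODO above is std::shuffle with an explicitly seeded engine; this is only a
// sketch of the idea, not the project's chosen fix.
#include <algorithm>
#include <random>
#include <vector>
template <typename T>
void ShuffleBuffer(std::vector<T>* buffer) {
  static std::mt19937 engine(std::random_device{}());  // seeded once per process
  std::shuffle(buffer->begin(), buffer->end(), engine);
}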
void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// Concat instances
out->clear();
if (buffer_.empty()) {
// If buffer_ is empty, 'out' will be returned as an empty vector.
return;
}
int out_num = buffer_[0].size();
out->reserve(out_num);
for (int j = 0; j < out_num; ++j) {
// Merge shapes and check data types
std::type_index batch_type = buffer_[0][j].type();
DDim batch_shape = buffer_[0][j].dims();
for (size_t i = 1; i < buffer_.size(); ++i) {
std::type_index ins_type = buffer_[i][j].type();
DDim ins_shape = buffer_[i][j].dims();
PADDLE_ENFORCE_EQ(batch_type, ins_type);
PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
slice_ddim(ins_shape, 1, ins_shape.size()));
PADDLE_ENFORCE_GT(ins_shape[0], 0);
batch_shape[0] += ins_shape[0];
}
LoDTensor out_tensor;
out_tensor.Resize(batch_shape);
out_tensor.mutable_data(platform::CPUPlace(), batch_type);
int64_t dst_offset = 0;
// Merge lod and data
LoD batch_lod;
for (size_t i = 0; i < buffer_.size(); ++i) {
DDim ins_shape = buffer_[i][j].dims();
LoD ins_lod = buffer_[i][j].lod();
if (i == 0) {
batch_lod = ins_lod;
} else {
PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
auto& lod_level = batch_lod[level_idx];
for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
}
}
}
Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0];
}
out_tensor.set_lod(batch_lod);
out->push_back(out_tensor);
}
}
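// A small worked example of the LoD merging above (values are illustrative,
// with one sequence per instance): if instance 0 carries the lod level
// {0, 5} and instance 1 carries {0, 4}, the merged level becomes {0, 5, 9}.
// Each appended offset is the instance's offset plus the value currently at
// the end of the merged level, so the offsets stay consistent with the
// tensors concatenated along dimension 0.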
void DoubleBufferReader::ReadNext(std::vector<LoDTensor>* out) {
std::unique_lock<std::mutex> lck(mtx_);
while (write_pos_ == read_pos_) {
buffer_not_empty_.wait(lck);
}
out->clear();
out->reserve(buffer_[read_pos_].size());
// TODO(fengjiayi): This copy shall be reduced.
for (size_t i = 0; i < buffer_[read_pos_].size(); ++i) {
LoDTensor dst;
TensorCopy(buffer_[read_pos_][i], platform::CPUPlace(), &dst);
dst.set_lod(buffer_[read_pos_][i].lod());
out->push_back(dst);
}
++read_pos_;
if (read_pos_ >= kDoubleBufferSize) {
read_pos_ = 0;
}
buffer_not_full_.notify_all();
}
bool DoubleBufferReader::HasNext() const {
return reader_->HasNext() || !buffer_.empty();
}
void DoubleBufferReader::ProducerThreadFunc() {
while (reader_->HasNext()) {
std::unique_lock<std::mutex> lck(mtx_);
while (((write_pos_ + 1) % kDoubleBufferSize) == read_pos_) {
buffer_not_full_.wait(lck);
}
reader_->ReadNext(&buffer_[write_pos_]);
++write_pos_;
if (write_pos_ >= kDoubleBufferSize) {
write_pos_ = 0;
}
buffer_not_empty_.notify_all();
}
}
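// The two wait conditions above implement a ring buffer that distinguishes
// "empty" from "full" by leaving one slot unused: read_pos_ == write_pos_
// means the buffer is empty, while (write_pos_ + 1) % kDoubleBufferSize ==
// read_pos_ means it is full. With kDoubleBufferSize slots, at most
// kDoubleBufferSize - 1 prefetched batches can be pending at any time.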
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -63,105 +63,8 @@ class DecoratedReader : public ReaderBase { ...@@ -63,105 +63,8 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_; ReaderBase* reader_;
}; };
// file readers // The ReaderHolder is used as readers' unified wrapper,
// making it easier to access different type readers in Variables.
template <typename T>
class RandomDataGenerator : public FileReader {
public:
RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(min_, max_);
}
void ReadNext(std::vector<LoDTensor>* out) override {
out->clear();
out->reserve(shapes_.size());
for (const DDim& shape : shapes_) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
"The rank of reader's output data should be 2 at least.(Now it's %d)",
shape.size());
LoDTensor out_tensor;
out_tensor.Resize(shape);
T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
int64_t numel = product(shape);
for (int64_t i = 0; i < numel; ++i) {
data[i] = dist_(engine_);
}
out->push_back(out_tensor);
}
}
bool HasNext() const override { return true; }
void ReInit() override { return; }
private:
float min_;
float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
};
// decorated readers
class ShuffleReader : public DecoratedReader {
public:
ShuffleReader(ReaderBase* reader, int buffer_size)
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
}
void ReadNext(std::vector<LoDTensor>* out) override;
private:
int buffer_size_;
std::vector<std::vector<LoDTensor>> buffer_;
size_t iteration_pos_;
};
class BatchReader : public DecoratedReader {
public:
BatchReader(ReaderBase* reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
}
void ReadNext(std::vector<LoDTensor>* out) override;
private:
int batch_size_;
std::vector<std::vector<LoDTensor>> buffer_;
};
class DoubleBufferReader : public DecoratedReader {
public:
explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader), buffer_(kDoubleBufferSize) {
framework::Async(std::bind(&DoubleBufferReader::ProducerThreadFunc, this));
}
void ReadNext(std::vector<LoDTensor>* out) override;
bool HasNext() const override;
private:
void ProducerThreadFunc();
std::vector<std::vector<LoDTensor>> buffer_;
size_t write_pos_;
size_t read_pos_;
std::mutex mtx_;
std::condition_variable buffer_not_full_;
std::condition_variable buffer_not_empty_;
};
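// A sketch of how the readers above can be composed (illustrative only; the
// shapes and sizes are made up and ownership handling is omitted).
#include <vector>
#include "paddle/fluid/framework/reader.h"
void ComposedReaderSketch() {
  using namespace paddle::framework;
  std::vector<DDim> shapes = {make_ddim({1, 4}), make_ddim({1, 1})};
  ReaderBase* file_reader = new RandomDataGenerator<float>(shapes, 0.0f, 1.0f);
  ReaderBase* shuffled = new ShuffleReader(file_reader, /*buffer_size=*/100);
  ReaderBase* batched = new BatchReader(shuffled, /*batch_size=*/32);
  std::vector<LoDTensor> batch;
  batched->ReadNext(&batch);  // each output concatenates up to 32 instances
}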
// The ReaderHolder is used as readers' unified wrapper,
// making it easier to access different type readers in Variables.
class ReaderHolder { class ReaderHolder {
public: public:
void Reset(ReaderBase* reader) { reader_.reset(reader); } void Reset(ReaderBase* reader) { reader_.reset(reader); }
......
...@@ -235,27 +235,53 @@ TEST(TensorToVector, Tensor) { ...@@ -235,27 +235,53 @@ TEST(TensorToVector, Tensor) {
TEST(TensorContainsNAN, CPU) { TEST(TensorContainsNAN, CPU) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; {
float* buf = src.mutable_data<float>({3}, CPUPlace()); Tensor src;
buf[0] = 0.0; float* buf = src.mutable_data<float>({3}, CPUPlace());
buf[1] = NAN; buf[0] = 0.0;
buf[2] = 0.0; buf[1] = NAN;
ASSERT_TRUE(TensorContainsNAN(src)); buf[2] = 0.0;
buf[1] = 0.0; ASSERT_TRUE(TensorContainsNAN(src));
ASSERT_FALSE(TensorContainsNAN(src)); buf[1] = 0.0;
ASSERT_FALSE(TensorContainsNAN(src));
}
{
Tensor src;
float16* buf = src.mutable_data<float16>({3}, CPUPlace());
buf[0] = 0.0;
buf[1].x = 0x7fff;
buf[2] = 0.0;
ASSERT_TRUE(TensorContainsNAN(src));
buf[1] = 0.0;
ASSERT_FALSE(TensorContainsNAN(src));
}
} }
TEST(TensorContainsInf, CPU) { TEST(TensorContainsInf, CPU) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; {
double* buf = src.mutable_data<double>({3}, CPUPlace()); Tensor src;
buf[0] = 1.0; double* buf = src.mutable_data<double>({3}, CPUPlace());
buf[1] = INFINITY; buf[0] = 1.0;
buf[2] = 0.0; buf[1] = INFINITY;
ASSERT_TRUE(TensorContainsInf(src)); buf[2] = 0.0;
buf[1] = 1.0; ASSERT_TRUE(TensorContainsInf(src));
ASSERT_FALSE(TensorContainsInf(src)); buf[1] = 1.0;
ASSERT_FALSE(TensorContainsInf(src));
}
{
Tensor src;
float16* buf = src.mutable_data<float16>({3}, CPUPlace());
buf[0] = 1.0;
buf[1].x = 0x7c00;
buf[2] = 0.0;
ASSERT_TRUE(TensorContainsInf(src));
buf[1] = 1.0;
ASSERT_FALSE(TensorContainsInf(src));
}
} }
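// In IEEE-754 half precision, 0x7c00 sets all exponent bits with a zero
// mantissa and therefore encodes +infinity, while 0x7fff sets all exponent
// bits with a non-zero mantissa and therefore encodes a NaN. Writing these
// raw bit patterns into float16::x is how the tests above (and the GPU fill
// kernels below) plant Inf/NaN values without going through float-to-half
// conversion.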
TEST(Tensor, FromAndToStream) { TEST(Tensor, FromAndToStream) {
......
...@@ -25,32 +25,65 @@ static __global__ void FillNAN(float* buf) { ...@@ -25,32 +25,65 @@ static __global__ void FillNAN(float* buf) {
buf[1] = 0.1; buf[1] = 0.1;
buf[2] = NAN; buf[2] = NAN;
} }
static __global__ void FillInf(float* buf) { static __global__ void FillInf(float* buf) {
buf[0] = 0.0; buf[0] = 0.0;
buf[1] = INFINITY; buf[1] = INFINITY;
buf[2] = 0.5; buf[2] = 0.5;
} }
static __global__ void FillNAN(platform::float16* buf) {
buf[0] = 0.0;
buf[1] = 0.1;
buf[2].x = 0x7fff;
}
static __global__ void FillInf(platform::float16* buf) {
buf[0] = 0.0;
buf[1].x = 0x7c00;
buf[2] = 0.5;
}
TEST(TensorContainsNAN, GPU) { TEST(TensorContainsNAN, GPU) {
Tensor tensor; using namespace paddle::platform;
platform::CUDAPlace gpu(0); CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu); auto* cuda_ctx = pool.GetByPlace(gpu);
float* buf = tensor.mutable_data<float>({3}, gpu); {
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); Tensor tensor;
cuda_ctx->Wait(); float* buf = tensor.mutable_data<float>({3}, gpu);
ASSERT_TRUE(TensorContainsNAN(tensor)); FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsNAN(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsNAN(tensor));
}
} }
TEST(TensorContainsInf, GPU) { TEST(TensorContainsInf, GPU) {
Tensor tensor; using namespace paddle::platform;
platform::CUDAPlace gpu(0); CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance(); auto& pool = DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu); auto* cuda_ctx = pool.GetByPlace(gpu);
float* buf = tensor.mutable_data<float>({3}, gpu); {
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); Tensor tensor;
cuda_ctx->Wait(); float* buf = tensor.mutable_data<float>({3}, gpu);
ASSERT_TRUE(TensorContainsInf(tensor)); FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsInf(tensor));
}
{
Tensor tensor;
float16* buf = tensor.mutable_data<float16>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(TensorContainsInf(tensor));
}
} }
} // namespace framework } // namespace framework
......
...@@ -70,7 +70,7 @@ function(op_library TARGET) ...@@ -70,7 +70,7 @@ function(op_library TARGET)
endif() endif()
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "create_reader_op") foreach(manual_pybind_op "net_op" "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()
...@@ -128,8 +128,8 @@ else() ...@@ -128,8 +128,8 @@ else()
set(DEPS_OPS ${DEPS_OPS} nccl_op) set(DEPS_OPS ${DEPS_OPS} nccl_op)
endif() endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(detail)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_op DEPS ${DISTRIBUTE_DEPS})
...@@ -170,7 +170,6 @@ op_library(recurrent_op DEPS executor) ...@@ -170,7 +170,6 @@ op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
op_library(create_reader_op DEPS reader)
if (WITH_GPU) if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv) op_library(conv_op DEPS vol2col depthwise_conv)
...@@ -189,7 +188,12 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) ...@@ -189,7 +188,12 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
op_library(${src}) op_library(${src})
endforeach() endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n") file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
add_subdirectory(reader)
foreach(src ${READER_LIBRARY})
set(OP_LIBRARY ${src} ${OP_LIBRARY})
endforeach()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace paddle {
namespace operators {
static std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
std::vector<framework::DDim> res;
int offset = 0;
for (int len : ranks) {
auto start_it = shape_concat.begin() + offset;
auto end_it = start_it + len;
res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
offset += len;
}
return res;
}
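// A small worked example of RestoreShapes (the values are illustrative):
// shape_concat = {2, 3, 4, 5, 6} with ranks = {3, 2} is split into the two
// DDims [2, 3, 4] and [5, 6]; each entry of 'ranks' consumes that many
// leading elements of 'shape_concat'.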
// general infershape for file readers
class CreateFileReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat =
ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);
if (ctx->IsRuntime()) {
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(
lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
}
};
// general infershape for decorated readers
class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
if (ctx->IsRuntime()) {
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
}
};
// general var type inference for file readers
class CreateFileReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarType::READER);
}
};
// general var type inference for decorated readers
class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes());
}
};
template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
}
};
class CreateRandomDataGeneratorOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat",
"The concat of all data's shapes.");
AddAttr<std::vector<int>>(
"ranks",
"The ranks of each data."
"e.g."
"shape_concat = [2,3,4,5,6]"
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of actually reading from files.
Generated data follow a uniform distribution between 'min' and 'max'.
)DOC");
}
};
class CreateShuffleReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
Attr<int>("buffer_size")));
}
};
class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a shuffle reader.");
AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
AddComment(R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC");
}
};
class CreateBatchReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::BatchReader(underlying_reader.Get(),
Attr<int>("batch_size")));
}
};
class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput(
"UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
AddComment(R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_data_generator,
ops::CreateRandomDataGeneratorOp<float>,
ops::CreateFileReaderInferShape,
ops::CreateRandomDataGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::CreateFileReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
ops::CreateDecoratedReaderInferShape,
ops::CreateShuffleReaderOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::CreateDecoratedReaderInferVarType);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
ops::CreateDecoratedReaderInferShape,
ops::CreateBatchReaderOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::CreateDecoratedReaderInferVarType);
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
endif()
...@@ -142,7 +142,15 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -142,7 +142,15 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("MAP", AddOutput("MAP",
"(Tensor) A tensor with shape [1], store the mAP evaluate " "(Tensor) A tensor with shape [1], store the mAP evaluate "
"result of the detection."); "result of the detection.");
AddAttr<int>("class_num",
"(int) "
"The class number.");
AddAttr<int>(
"background_label",
"(int, defalut: 0) "
"The index of background label, the background label will be ignored. "
"If set to -1, then all categories will be considered.")
.SetDefault(0);
AddAttr<float>( AddAttr<float>(
"overlap_threshold", "overlap_threshold",
"(float) " "(float) "
......
...@@ -69,6 +69,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -69,6 +69,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
float overlap_threshold = ctx.Attr<float>("overlap_threshold"); float overlap_threshold = ctx.Attr<float>("overlap_threshold");
float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult"); float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type")); auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type"));
int class_num = ctx.Attr<int>("class_num");
auto label_lod = in_label->lod(); auto label_lod = in_label->lod();
auto detect_lod = in_detect->lod(); auto detect_lod = in_detect->lod();
...@@ -95,17 +96,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -95,17 +96,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
if (in_pos_count != nullptr && state) { if (in_pos_count != nullptr && state) {
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count, GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
true_pos, false_pos); true_pos, false_pos, class_num);
} }
CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult, CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult,
overlap_threshold, label_pos_count, true_pos, overlap_threshold, label_pos_count, true_pos,
false_pos); false_pos);
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos); int background_label = ctx.Attr<int>("background_label");
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos,
background_label);
GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count, GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
*out_true_pos, *out_false_pos); *out_true_pos, *out_false_pos, class_num);
T* map_data = out_map->mutable_data<T>(ctx.GetPlace()); T* map_data = out_map->mutable_data<T>(ctx.GetPlace());
map_data[0] = map; map_data[0] = map;
...@@ -190,24 +193,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -190,24 +193,20 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
const std::map<int, std::vector<std::pair<T, int>>>& false_pos, const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
framework::Tensor& output_pos_count, framework::Tensor& output_pos_count,
framework::LoDTensor& output_true_pos, framework::LoDTensor& output_true_pos,
framework::LoDTensor& output_false_pos) const { framework::LoDTensor& output_false_pos, const int class_num) const {
int max_class_id = 0;
int true_pos_count = 0; int true_pos_count = 0;
int false_pos_count = 0; int false_pos_count = 0;
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { for (auto it = true_pos.begin(); it != true_pos.end(); ++it) {
int label = it->first; auto tp = it->second;
if (label > max_class_id) max_class_id = label; true_pos_count += tp.size();
int label_num_pos = it->second; }
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) for (auto it = false_pos.begin(); it != false_pos.end(); ++it) {
continue; auto fp = it->second;
auto label_true_pos = true_pos.find(label)->second; false_pos_count += fp.size();
auto label_false_pos = false_pos.find(label)->second;
true_pos_count += label_true_pos.size();
false_pos_count += label_false_pos.size();
} }
int* pos_count_data = output_pos_count.mutable_data<int>( int* pos_count_data = output_pos_count.mutable_data<int>(
framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace()); framework::make_ddim({class_num, 1}), ctx.GetPlace());
T* true_pos_data = output_true_pos.mutable_data<T>( T* true_pos_data = output_true_pos.mutable_data<T>(
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace()); framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
...@@ -217,7 +216,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -217,7 +216,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
false_pos_count = 0; false_pos_count = 0;
std::vector<size_t> true_pos_starts = {0}; std::vector<size_t> true_pos_starts = {0};
std::vector<size_t> false_pos_starts = {0}; std::vector<size_t> false_pos_starts = {0};
for (int i = 0; i <= max_class_id; ++i) { for (int i = 0; i < class_num; ++i) {
auto it_count = label_pos_count.find(i); auto it_count = label_pos_count.find(i);
pos_count_data[i] = 0; pos_count_data[i] = 0;
if (it_count != label_pos_count.end()) { if (it_count != label_pos_count.end()) {
...@@ -258,17 +257,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -258,17 +257,16 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
return; return;
} }
void GetInputPos( void GetInputPos(const framework::Tensor& input_pos_count,
const framework::Tensor& input_pos_count, const framework::LoDTensor& input_true_pos,
const framework::LoDTensor& input_true_pos, const framework::LoDTensor& input_false_pos,
const framework::LoDTensor& input_false_pos, std::map<int, int>& label_pos_count,
std::map<int, int>& label_pos_count, std::map<int, std::vector<std::pair<T, int>>>& true_pos,
std::map<int, std::vector<std::pair<T, int>>>& true_pos, std::map<int, std::vector<std::pair<T, int>>>& false_pos,
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const { const int class_num) const {
constexpr T kEPS = static_cast<T>(1e-6); constexpr T kEPS = static_cast<T>(1e-6);
int class_number = input_pos_count.dims()[0];
const int* pos_count_data = input_pos_count.data<int>(); const int* pos_count_data = input_pos_count.data<int>();
for (int i = 0; i < class_number; ++i) { for (int i = 0; i < class_num; ++i) {
label_pos_count[i] = pos_count_data[i]; label_pos_count[i] = pos_count_data[i];
} }
...@@ -391,17 +389,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -391,17 +389,19 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
} }
} }
T CalcMAP( T CalcMAP(APType ap_type, const std::map<int, int>& label_pos_count,
APType ap_type, const std::map<int, int>& label_pos_count, const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
const std::map<int, std::vector<std::pair<T, int>>>& true_pos, const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
const std::map<int, std::vector<std::pair<T, int>>>& false_pos) const { const int background_label) const {
T mAP = 0.0; T mAP = 0.0;
int count = 0; int count = 0;
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first; int label = it->first;
int label_num_pos = it->second; int label_num_pos = it->second;
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) if (label_num_pos == background_label ||
true_pos.find(label) == true_pos.end()) {
continue; continue;
}
auto label_true_pos = true_pos.find(label)->second; auto label_true_pos = true_pos.find(label)->second;
auto label_false_pos = false_pos.find(label)->second; auto label_false_pos = false_pos.find(label)->second;
// Compute average precision. // Compute average precision.
...@@ -450,7 +450,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { ...@@ -450,7 +450,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
} }
} }
if (count != 0) mAP /= count; if (count != 0) mAP /= count;
return mAP * 100; return mAP;
} }
}; // namespace operators }; // namespace operators
......
...@@ -40,80 +40,14 @@ class ElementwiseMulKernel : public framework::OpKernel<T> { ...@@ -40,80 +40,14 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}; };
template <typename T> template <typename T>
struct ElementwiseMulGradFunctor { struct IdentityGrad_DX {
template <typename Device, typename X, typename Y, typename Z, typename dX, HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; }
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e * y_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = x_e * dz_e;
}
}
}; };
template <typename T> template <typename T>
struct ElementwiseMulBroadCastGradFunctor { struct IdentityGrad_DY {
template <typename Device, typename X, typename Y, typename Z, typename dX, HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; }
typename dY, typename dZ, typename Pre, typename N>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e * y_e_bcast;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (x_e * dz_e)
.reshape(Eigen::DSizes<int, 2>(pre, n))
.sum(Eigen::array<int, 1>{{0}});
}
}
}; };
template <typename T>
struct ElementwiseMulBroadCast2GradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N, typename Post>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
Post post) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e * y_e_bcast;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (x_e * dz_e)
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
.sum(Eigen::array<int, 2>{{0, 2}});
}
}
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel<T> { class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public: public:
...@@ -127,12 +61,11 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> { ...@@ -127,12 +61,11 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>, ElemwiseGradCompute<DeviceContext, T, IdentityGrad_DX<T>,
ElementwiseMulBroadCastGradFunctor<T>, IdentityGrad_DY<T>>(ctx, *x, *y, *out, *dout, axis, dx,
ElementwiseMulBroadCast2GradFunctor<T>>( dy, IdentityGrad_DX<T>(),
ctx, x, y, out, dout, axis, dx, dy); IdentityGrad_DY<T>());
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
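// The two functors encode the product rule for the elementwise product
// out = x * y: since d(out)/dx = y and d(out)/dy = x, the chain rule gives
// grad_x = dout * y and grad_y = dout * x, which is exactly what
// IdentityGrad_DX and IdentityGrad_DY return; broadcasting and the reduction
// over broadcast dimensions are handled inside ElemwiseGradCompute.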
...@@ -301,7 +301,7 @@ struct ElemwiseGradNoBroadcast { ...@@ -301,7 +301,7 @@ struct ElemwiseGradNoBroadcast {
dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]);
} }
if (dy_ != nullptr) { if (dy_ != nullptr) {
dy_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]);
} }
} }
......
...@@ -245,11 +245,13 @@ template struct SetConstant<platform::CPUDeviceContext, int>; ...@@ -245,11 +245,13 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>; template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>; template struct SetConstant<platform::CPUDeviceContext, bool>;
#define DEFINE_CPU_TRANS(RANK) \ #define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \ template struct Transpose<platform::CPUDeviceContext, platform::float16, \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \ RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \ template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \ template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; template struct Transpose<platform::CPUDeviceContext, bool, RANK>;
DEFINE_CPU_TRANS(1); DEFINE_CPU_TRANS(1);
......
...@@ -324,7 +324,7 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -324,7 +324,7 @@ class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
" Please note, M is equal to the 1st dimension of BBoxes. "); " Please note, M is equal to the 1st dimension of BBoxes. ");
AddAttr<int>( AddAttr<int>(
"background_label", "background_label",
"(int64_t, defalut: 0) " "(int, defalut: 0) "
"The index of background label, the background label will be ignored. " "The index of background label, the background label will be ignored. "
"If set to -1, then all categories will be considered.") "If set to -1, then all categories will be considered.")
.SetDefault(0); .SetDefault(0);
......
...@@ -16,5 +16,50 @@ limitations under the License. */ ...@@ -16,5 +16,50 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
namespace platform {} // namespace platform namespace platform {
namespace {
// TODO(panyx0718): Where to destroy them.
std::unique_ptr<std::vector<ncclComm_t>> global_comms;
std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
bool inited = false;
size_t last_num_gpus = -1;
// TODO(panyx0718): Need to decide whether Paddle supports parallel
// runs with different numbers of GPUs. If so, the current solution is not enough.
std::mutex comm_mu;
}
int Communicator::GetCommId(int device_id) const {
std::lock_guard<std::mutex> guard(comm_mu);
return comm_id_map->at(device_id);
}
void Communicator::InitAll(const std::vector<int>& gpus) {
std::lock_guard<std::mutex> guard(comm_mu);
if (inited && last_num_gpus == gpus.size()) {
return;
}
last_num_gpus = gpus.size();
if (global_comms) {
for (size_t i = 0; i < global_comms->size(); ++i) {
// FIXME(dzh) : PADDLE_ENFORCE return void
dynload::ncclCommDestroy((*global_comms)[i]);
}
}
global_comms.reset(new std::vector<ncclComm_t>());
comm_id_map.reset(new std::unordered_map<int, int>());
global_comms->resize(gpus.size());
for (size_t i = 0; i < gpus.size(); ++i) {
(*comm_id_map)[gpus[i]] = i;
}
PADDLE_ENFORCE(
dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
inited = true;
}
const std::vector<ncclComm_t>& Communicator::comms() const {
std::lock_guard<std::mutex> guard(comm_mu);
return *global_comms;
}
} // namespace platform
} // namespace paddle } // namespace paddle
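// A minimal sketch of how the refactored Communicator is used (illustrative
// only; the device ids are made up and the header that declares Communicator
// is assumed to be included).
#include <vector>
void CommunicatorUsageSketch() {
  paddle::platform::Communicator comm;
  std::vector<int> gpus = {0, 1};  // participating device ids
  comm.InitAll(gpus);              // builds one ncclComm_t per device, mutex-guarded
  int idx = comm.GetCommId(1);     // index of device 1 inside comm.comms()
  // comm.comms().at(idx) is what the nccl_op kernels below pass to
  // ncclAllReduce / ncclReduce / ncclBcast.
  (void)idx;
}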
...@@ -29,39 +29,16 @@ limitations under the License. */ ...@@ -29,39 +29,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
constexpr int kInvalidGPUId = -1; constexpr int kInvalidGPUId = -1;
struct Communicator { struct Communicator {
std::vector<ncclComm_t> comms_;
std::unordered_map<int, int> comm_id_map_;
bool inited_;
Communicator() {} Communicator() {}
int GetCommId(int device_id) const { return comm_id_map_.at(device_id); } int GetCommId(int device_id) const;
void InitAll(const std::vector<int>& gpus) {
comms_.resize(gpus.size());
inited_ = false;
for (size_t i = 0; i < gpus.size(); ++i) {
comm_id_map_[gpus[i]] = i;
}
PADDLE_ENFORCE(
dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
inited_ = true;
}
~Communicator() { void InitAll(const std::vector<int>& gpus);
if (inited_) {
for (size_t i = 0; i < comms_.size(); ++i) {
// FIXME(dzh) : PADDLE_ENFORCE return void
dynload::ncclCommDestroy(comms_[i]);
}
}
}
DISABLE_COPY_AND_ASSIGN(Communicator); const std::vector<ncclComm_t>& comms() const;
}; };
} // namespace platform } // namespace platform
......
...@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()), ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_, outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
comm->comms_[idx], stream)); comm->comms().at(idx), stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " VLOG(1) << "gpu : "
...@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
std::hash<std::string> hasher; std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
if (root == platform::kInvalidGPUId) { if (root == platform::kInvalidGPUId) {
root = hasher(ins_names[i]) % comm->comms_.size(); root = hasher(ins_names[i]) % comm->comms().size();
} }
T* recvbuffer = nullptr; T* recvbuffer = nullptr;
if (root == gpu_id) { if (root == gpu_id) {
...@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::dynload::ncclReduce( PADDLE_ENFORCE(platform::dynload::ncclReduce(
ins[i]->data<T>(), recvbuffer, ins[i]->numel(), ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx], NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
stream)); stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
...@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> { ...@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
VLOG(1) << " before ncclBcast"; VLOG(1) << " before ncclBcast";
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type, (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream)); root, comm->comms().at(idx), stream));
VLOG(1) << " after ncclBcast"; VLOG(1) << " after ncclBcast";
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
...@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> { ...@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::dynload::ncclBcast( PADDLE_ENFORCE(platform::dynload::ncclBcast(
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(), outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream)); NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv " VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
......
...@@ -17,8 +17,15 @@ limitations under the License. */ ...@@ -17,8 +17,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
int PoolOutputSize(int input_size, int filter_size, int padding, int stride) { int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
int output_size = (input_size - filter_size + 2 * padding) / stride + 1; bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
}
PADDLE_ENFORCE(output_size > 0, PADDLE_ENFORCE(output_size > 0,
"Due to the settings of padding(%d), filter_size(%d) and " "Due to the settings of padding(%d), filter_size(%d) and "
"stride(%d), the output size is less than 0, please check " "stride(%d), the output size is less than 0, please check "
...@@ -38,6 +45,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -38,6 +45,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize"); std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
bool ceil_mode = ctx->Attrs().Get<bool>("ceil_mode");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D tensor."); "Pooling intput should be 4-D or 5-D tensor.");
...@@ -59,8 +67,8 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -59,8 +67,8 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]}); std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back( output_shape.push_back(PoolOutputSize(in_x_dims[i + 2], ksize[i],
PoolOutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); paddings[i], strides[i], ceil_mode));
} }
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
...@@ -167,6 +175,12 @@ Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker) ...@@ -167,6 +175,12 @@ Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>(
"ceil_mode",
"(bool, default false) Wether to use the ceil function to calculate "
"output height and width. False is the default. If it is set to False, "
"the floor function will be used.")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
...@@ -187,16 +201,21 @@ Parameters(ksize, strides, paddings) are two elements. ...@@ -187,16 +201,21 @@ Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively. These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different. The input(X) size and output(Out) size may be different.
Example: Example:
Input: Input:
X shape: $(N, C, H_{in}, W_{in})$ X shape: $(N, C, H_{in}, W_{in})$
Output: Output:
Out shape: $(N, C, H_{out}, W_{out})$ Out shape: $(N, C, H_{out}, W_{out})$
Where For ceil_mode = false:
$$ $$
H_{out} = \frac{(H_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ H_{out} = \frac{(H_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 W_{out} = \frac{(W_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1
$$ $$
For ceil_mode = true:
$$
H_{out} = \frac{(H_{in} - ksize[0] + 2 * paddings[0] + strides[0] - 1)}{strides[0]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
$$
)DOC"); )DOC");
} }
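// A worked example of the two formulas above (the numbers are illustrative):
// with H_in = 6, ksize = 3, padding = 0 and stride = 2,
//   ceil_mode = false: (6 - 3 + 0) / 2 + 1 = 2
//   ceil_mode = true:  (6 - 3 + 0 + 2 - 1) / 2 + 1 = 3
// so ceil_mode keeps the last, only partially covered window instead of
// dropping it.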
...@@ -251,6 +270,12 @@ Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker) ...@@ -251,6 +270,12 @@ Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>(
"ceil_mode",
"(bool, default false) Wether to use the ceil function to calculate "
"output height and width. False is the default. If it is set to False, "
"the floor function will be used.")
.SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
...@@ -267,8 +292,8 @@ The pooling3d operation calculates the output based on ...@@ -267,8 +292,8 @@ The pooling3d operation calculates the output based on
the input, pooling_type, ksize, strides, and paddings parameters. the input, pooling_type, ksize, strides, and paddings parameters.
Input(X) and output(Out) are in NCDHW format, where N is batch Input(X) and output(Out) are in NCDHW format, where N is batch
size, C is the number of channels, and D, H and W are the depth, height and size, C is the number of channels, and D, H and W are the depth, height and
width of the feature, respectively. Parameters(ksize, strides, paddings) width of the feature, respectively. Parameters(ksize, strides, paddings)
are three elements. These three elements represent depth, height and are three elements. These three elements represent depth, height and
width, respectively. The input(X) size and output(Out) size may be different. width, respectively. The input(X) size and output(Out) size may be different.
Example: Example:
...@@ -276,12 +301,18 @@ Example: ...@@ -276,12 +301,18 @@ Example:
X shape: $(N, C, D_{in}, H_{in}, W_{in})$ X shape: $(N, C, D_{in}, H_{in}, W_{in})$
Output: Output:
Out shape: $(N, C, D_{out}, H_{out}, W_{out})$ Out shape: $(N, C, D_{out}, H_{out}, W_{out})$
Where For ceil_mode = false:
$$ $$
D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\
H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 \\ H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1 W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1
$$ $$
For ceil_mode = true:
$$
D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0] + strides[0] - 1)}{strides[0]} + 1 \\
H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1 \\
W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2] + strides[2] - 1)}{strides[2]} + 1
$$
)DOC"); )DOC");
} }
......
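The output-size formulas above follow the same per-dimension rule for both the 2-D and 3-D poolers; only the number of spatial dimensions differs. A minimal Python sketch of that rule (pool_output_size is a hypothetical helper written here only for illustration, mirroring the ceil_mode semantics described above):

def pool_output_size(input_size, ksize, padding, stride, ceil_mode=False):
    # ceil_mode adds (stride - 1) before the integer division, which rounds up.
    adj = stride - 1 if ceil_mode else 0
    return (input_size - ksize + 2 * padding + adj) // stride + 1

# Example: input_size = 6, ksize = 3, padding = 0, stride = 2
# floor mode -> (6 - 3) // 2 + 1 = 2;  ceil mode -> (6 - 3 + 1) // 2 + 1 = 3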
cc_library(reader_op_registry SRCS reader_op_registry.cc DEPS operator op_registry reader)
op_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc DEPS reader_op_registry)
op_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc DEPS reader_op_registry)
op_library(create_batch_reader_op SRCS create_batch_reader_op.cc DEPS reader_op_registry)
set(READER_LIBRARY create_random_data_generator_op create_shuffle_reader_op create_batch_reader_op PARENT_SCOPE)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class BatchReader : public framework::DecoratedReader {
public:
BatchReader(ReaderBase* reader, int batch_size)
: DecoratedReader(reader), batch_size_(batch_size) {
buffer_.reserve(batch_size_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
private:
int batch_size_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
};
class CreateBatchReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(
new BatchReader(underlying_reader.Get(), Attr<int>("batch_size")));
}
};
class CreateBatchReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("batch_size",
"How many instances the batch reader yields each time.")
.GreaterThan(0);
AddComment(R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
}
};
void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// Concat instances
out->clear();
if (buffer_.empty()) {
// If buffer_ is empty, 'out' will be returned as an empty vector.
return;
}
int out_num = buffer_[0].size();
out->reserve(out_num);
for (int j = 0; j < out_num; ++j) {
// Merge shapes and check data types
std::type_index batch_type = buffer_[0][j].type();
framework::DDim batch_shape = buffer_[0][j].dims();
for (size_t i = 1; i < buffer_.size(); ++i) {
std::type_index ins_type = buffer_[i][j].type();
framework::DDim ins_shape = buffer_[i][j].dims();
PADDLE_ENFORCE_EQ(batch_type, ins_type);
PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
slice_ddim(ins_shape, 1, ins_shape.size()));
PADDLE_ENFORCE_GT(ins_shape[0], 0);
batch_shape[0] += ins_shape[0];
}
framework::LoDTensor out_tensor;
out_tensor.Resize(batch_shape);
out_tensor.mutable_data(platform::CPUPlace(), batch_type);
int64_t dst_offset = 0;
// Merge lod and data
framework::LoD batch_lod;
for (size_t i = 0; i < buffer_.size(); ++i) {
framework::DDim ins_shape = buffer_[i][j].dims();
framework::LoD ins_lod = buffer_[i][j].lod();
if (i == 0) {
batch_lod = ins_lod;
} else {
PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
auto& lod_level = batch_lod[level_idx];
for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
}
}
}
auto dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
TensorCopy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0];
}
out_tensor.set_lod(batch_lod);
out->push_back(out_tensor);
}
}
} // namespace reader
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_batch_reader,
ops::CreateBatchReaderOp,
ops::CreateBatchReaderOpMaker);
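For intuition, the merging performed by BatchReader::ReadNext above, concatenating instances along dimension 0 and chaining each instance's LoD offsets onto the running last offset, can be sketched with numpy. This is a hypothetical single-LoD-level illustration, not part of the operator:

import numpy as np

def merge_instances(instances):
    # Each instance is a dict with a 'data' ndarray and a single-level 'lod'
    # offset list, e.g. {'data': np.zeros((2, 3)), 'lod': [0, 2]}.
    data = np.concatenate([ins['data'] for ins in instances], axis=0)
    batch_lod = list(instances[0]['lod'])
    for ins in instances[1:]:
        for offset in ins['lod'][1:]:
            batch_lod.append(offset + batch_lod[-1])
    return data, batch_lod

# Two instances with LoD [0, 2] and [0, 3] merge into LoD [0, 2, 5].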
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
template <typename T>
class RandomDataGenerator : public framework::FileReader {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max)
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()();
engine_.seed(seed);
dist_ = std::uniform_real_distribution<float>(min_, max_);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
out->clear();
out->reserve(shapes_.size());
for (const framework::DDim& shape : shapes_) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
"The rank of reader's output data should be 2 at least.(Now it's %d)",
shape.size());
framework::LoDTensor out_tensor;
out_tensor.Resize(shape);
T* data = out_tensor.mutable_data<T>(platform::CPUPlace());
int64_t numel = framework::product(shape);
for (int64_t i = 0; i < numel; ++i) {
data[i] = dist_(engine_);
}
out->push_back(out_tensor);
}
}
bool HasNext() const override { return true; }
void ReInit() override { return; }
private:
float min_;
float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
};
template <typename T>
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
}
};
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
public:
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of actually reading from files.
Generated data follow a uniform distribution between 'min' and 'max'.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_random_data_generator,
ops::CreateRandomDataGeneratorOp<float>,
ops::CreateRandomDataGeneratorOpMaker);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
class ShuffleReader : public framework::DecoratedReader {
public:
ShuffleReader(ReaderBase* reader, int buffer_size)
: DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
buffer_.reserve(buffer_size);
}
void ReadNext(std::vector<framework::LoDTensor>* out) override;
private:
int buffer_size_;
std::vector<std::vector<framework::LoDTensor>> buffer_;
size_t iteration_pos_;
};
void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
buffer_.clear();
buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<framework::LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimized.
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
out->clear();
if (!buffer_.empty()) {
std::swap(*out, buffer_[iteration_pos_++]);
}
// If buffer_ is empty, 'out' will be returned as an empty vector.
}
class CreateShuffleReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(
new ShuffleReader(underlying_reader.Get(), Attr<int>("buffer_size")));
}
};
class CreateShuffleReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
AddComment(R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC");
}
};
} // namespace reader
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_shuffle_reader,
ops::CreateShuffleReaderOp,
ops::CreateShuffleReaderOpMaker);
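The buffered shuffle above, filling a buffer of up to buffer_size instances, shuffling it, and serving items one by one before refilling, is roughly equivalent to the following Python generator. shuffled_reader is a hypothetical sketch for illustration only:

import random

def shuffled_reader(reader, buffer_size):
    # 'reader' is any iterable of instances; items are yielded in shuffled
    # order within each buffer of up to 'buffer_size' items.
    buf = []
    for item in reader:
        buf.append(item)
        if len(buf) == buffer_size:
            random.shuffle(buf)
            for it in buf:
                yield it
            buf = []
    random.shuffle(buf)
    for it in buf:
        yield it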
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "reader_op_registry.h"
namespace paddle {
namespace operators {
namespace reader {
std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
const std::vector<int>& ranks) {
std::vector<framework::DDim> res;
int offset = 0;
for (int len : ranks) {
auto start_it = shape_concat.begin() + offset;
auto end_it = start_it + len;
res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it)));
offset += len;
}
return res;
}
FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat", "The concat of all data's shapes.");
AddAttr<std::vector<int>>(
"ranks",
"The ranks of each data."
"e.g."
"shape_concat = [2,3,4,5,6]"
"ranks = [3,2]"
"It means the reader will generate two data each time,"
"whose shapes are [2,3,4] and [5,6] respectively.");
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
}
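The shape_concat/ranks encoding documented above is exactly what RestoreShapes undoes. A quick, hypothetical Python equivalent for illustration:

def restore_shapes(shape_concat, ranks):
    # Split the concatenated shape list back into one shape per output.
    shapes, offset = [], 0
    for rank in ranks:
        shapes.append(shape_concat[offset:offset + rank])
        offset += rank
    return shapes

# restore_shapes([2, 3, 4, 5, 6], [3, 2]) -> [[2, 3, 4], [5, 6]]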
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);
if (ctx->IsRuntime()) {
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const {
std::string reader_name = op_desc.Output("Out")[0];
framework::VarDesc* reader = block->FindVarRecursive(reader_name);
reader->SetType(framework::proto::VarType::READER);
}
void DecoratedReaderInferShape::operator()(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
if (ctx->IsRuntime()) {
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
}
void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
std::string out_reader_name = op_desc.Output("Out")[0];
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
out_reader->SetType(framework::proto::VarType::READER);
out_reader->SetDataTypes(in_reader->GetDataTypes());
}
DecoratedReaderMakerBase::DecoratedReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddInput("UnderlyingReader",
"(ReaderHolder) The underlying reader for creating a batch reader.");
AddOutput("Out", "(ReaderHolder) The created batch reader.");
}
} // namespace reader
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
namespace paddle {
namespace operators {
namespace reader {
extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);
class FileReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public:
FileReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
};
class FileReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override;
};
class FileReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override;
};
// general infershape for decorated reader
class DecoratedReaderInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override;
};
// general var type inference for decorated reader
class DecoratedReaderInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override;
};
class DecoratedReaderMakerBase : public framework::OpProtoAndCheckerMaker {
public:
DecoratedReaderMakerBase(OpProto* op_proto, OpAttrChecker* op_checker);
};
} // namespace reader
} // namespace operators
} // namespace paddle
#define REGISTER_FILE_READER_OPERATOR(op_name, ...) \
REGISTER_OPERATOR(op_name, __VA_ARGS__, \
paddle::operators::reader::FileReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::FileReaderInferVarType)
#define REGISTER_DECORATED_READER_OPERATOR(op_name, ...) \
REGISTER_OPERATOR(op_name, __VA_ARGS__, \
paddle::operators::reader::DecoratedReaderInferShape, \
paddle::framework::EmptyGradOpMaker, \
paddle::operators::reader::DecoratedReaderInferVarType)
...@@ -192,6 +192,12 @@ class DeviceTracerImpl : public DeviceTracer { ...@@ -192,6 +192,12 @@ class DeviceTracerImpl : public DeviceTracer {
} }
void AddCPURecords(const char *anno, uint64_t start_ns, uint64_t end_ns) { void AddCPURecords(const char *anno, uint64_t start_ns, uint64_t end_ns) {
if (!anno) {
// TODO(panyx0718): Nested annotations are not supported yet. An
// upper-level annotation can be cleared by a lower-level one, so we
// may get a nullptr here.
return;
}
std::lock_guard<std::mutex> l(trace_mu_); std::lock_guard<std::mutex> l(trace_mu_);
cpu_records_.push_back( cpu_records_.push_back(
CPURecord{anno, start_ns, end_ns, CPURecord{anno, start_ns, end_ns,
......
...@@ -20,10 +20,6 @@ limitations under the License. */ ...@@ -20,10 +20,6 @@ limitations under the License. */
#include <cuda.h> #include <cuda.h>
#endif // PADDLE_WITH_CUDA #endif // PADDLE_WITH_CUDA
#include "unsupported/Eigen/CXX11/Tensor"
#include "paddle/fluid/platform/hostdevice.h"
#ifdef __GNUC__ #ifdef __GNUC__
#define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__) #define PADDLE_GNUC_VER (__GNUC__ * 10 + __GNUC_MINOR__)
#else #else
...@@ -64,6 +60,18 @@ limitations under the License. */ ...@@ -64,6 +60,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// Forward declare float16 for eigen.h
struct float16;
} // namespace platform
} // namespace paddle
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace platform {
// Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated // Use PADDLE_ALIGNED(2) to ensure that each float16 will be allocated
// and aligned at least on a 2-byte boundary, which leads to efficient // and aligned at least on a 2-byte boundary, which leads to efficient
// memory access of float16 struct and also makes float16 compatible // memory access of float16 struct and also makes float16 compatible
...@@ -729,6 +737,22 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) { ...@@ -729,6 +737,22 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
} }
#endif #endif
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hisnan(half(a));
#else
return (a.x & 0x7fff) > 0x7c00;
#endif
}
HOSTDEVICE inline bool(isinf)(const float16& a) {
return (a.x & 0x7fff) == 0x7c00;
}
HOSTDEVICE inline bool(isfinite)(const float16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -750,3 +774,27 @@ struct is_pod<paddle::platform::float16> { ...@@ -750,3 +774,27 @@ struct is_pod<paddle::platform::float16> {
}; };
} // namespace std } // namespace std
namespace Eigen {
namespace numext {
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(
const paddle::platform::float16& a) {
return (paddle::platform::isnan)(a);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(
const paddle::platform::float16& a) {
return (paddle::platform::isinf)(a);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(
const paddle::platform::float16& a) {
return (paddle::platform::isfinite)(a);
}
} // namespace numext
} // namespace Eigen
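The bit-level checks added above rely on the IEEE half-precision layout: after masking off the sign bit with 0x7fff, a value of exactly 0x7c00 means all exponent bits are set with a zero mantissa (infinity), and anything greater means NaN. This can be sanity-checked with numpy's float16; the snippet is only an illustration, not part of the library:

import numpy as np

def f16_bits(x):
    # Reinterpret the 16-bit pattern of a float16 and drop the sign bit.
    return int(np.float16(x).view(np.uint16)) & 0x7fff

assert f16_bits(float('inf')) == 0x7c00   # isinf: masked bits == 0x7c00
assert f16_bits(float('nan')) > 0x7c00    # isnan: masked bits >  0x7c00
assert f16_bits(65504.0) < 0x7c00         # largest finite float16 stays below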
# internal library.
cc_library(header SRCS header.cc)
cc_test(header_test SRCS header_test.cc DEPS header)
cc_library(chunk SRCS chunk.cc DEPS snappystream snappy header zlib)
cc_test(chunk_test SRCS chunk_test.cc DEPS chunk)
cc_library(recordio DEPS chunk header)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/chunk.h"
#include <memory>
#include <sstream>
#include "paddle/fluid/platform/enforce.h"
#include "snappystream.hpp"
#include "zlib.h"
namespace paddle {
namespace recordio {
constexpr size_t kMaxBufSize = 1024;
template <typename Callback>
static void ReadStreamByBuf(std::istream& in, int limit, Callback callback) {
char buf[kMaxBufSize];
std::streamsize actual_size;
size_t counter = 0;
do {
auto actual_max =
limit > 0 ? std::min(limit - counter, kMaxBufSize) : kMaxBufSize;
actual_size = in.readsome(buf, actual_max);
if (actual_size == 0) {
break;
}
callback(buf, actual_size);
if (limit > 0) {
counter += actual_size;
}
} while (actual_size == kMaxBufSize);
}
static void PipeStream(std::istream& in, std::ostream& os) {
ReadStreamByBuf(
in, -1, [&os](const char* buf, size_t len) { os.write(buf, len); });
}
static uint32_t Crc32Stream(std::istream& in, int limit = -1) {
auto crc = crc32(0, nullptr, 0);
ReadStreamByBuf(in, limit, [&crc](const char* buf, size_t len) {
crc = crc32(crc, reinterpret_cast<const Bytef*>(buf), len);
});
return crc;
}
bool Chunk::Write(std::ostream& os, Compressor ct) const {
// NOTE(dzhwinter): we check records_.empty() rather than num_bytes_,
// because zero-length records are allowed.
if (records_.empty()) {
return false;
}
std::stringstream sout;
std::unique_ptr<std::ostream> compressed_stream;
switch (ct) {
case Compressor::kNoCompress:
break;
case Compressor::kSnappy:
compressed_stream.reset(new snappy::oSnappyStream(sout));
break;
default:
PADDLE_THROW("Not implemented");
}
std::ostream& buf_stream = compressed_stream ? *compressed_stream : sout;
for (auto& record : records_) {
size_t sz = record.size();
buf_stream.write(reinterpret_cast<const char*>(&sz), sizeof(uint32_t))
.write(record.data(), record.size());
}
if (compressed_stream) {
compressed_stream.reset();
}
auto end_pos = sout.tellg();
sout.seekg(0, std::ios::beg);
uint32_t len = static_cast<uint32_t>(end_pos - sout.tellg());
uint32_t crc = Crc32Stream(sout);
sout.seekg(0, std::ios::beg);
Header hdr(static_cast<uint32_t>(records_.size()), crc, ct, len);
hdr.Write(os);
PipeStream(sout, os);
return true;
}
void Chunk::Parse(std::istream& sin) {
Header hdr;
hdr.Parse(sin);
auto beg_pos = sin.tellg();
auto crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
Clear();
sin.seekg(beg_pos, std::ios::beg);
std::unique_ptr<std::istream> compressed_stream;
switch (hdr.CompressType()) {
case Compressor::kNoCompress:
break;
case Compressor::kSnappy:
compressed_stream.reset(new snappy::iSnappyStream(sin));
break;
default:
PADDLE_THROW("Not implemented");
}
std::istream& stream = compressed_stream ? *compressed_stream : sin;
for (uint32_t i = 0; i < hdr.NumRecords(); ++i) {
uint32_t rec_len;
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
Add(buf);
}
}
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/recordio/header.h"
namespace paddle {
namespace recordio {
// A Chunk contains the Header and optionally compressed records.
class Chunk {
public:
Chunk() : num_bytes_(0) {}
void Add(std::string buf) {
records_.push_back(buf);
num_bytes_ += buf.size();
}
// Dumps the chunk into fo. It does not clear the chunk itself; call
// Clear() to make it ready for the next Add() invocation.
bool Write(std::ostream& fo, Compressor ct) const;
void Clear() {
records_.clear();
num_bytes_ = 0;
}
void Parse(std::istream& sin);
size_t NumBytes() { return num_bytes_; }
const std::string& Record(int i) const { return records_[i]; }
private:
std::vector<std::string> records_;
// sum of record lengths in bytes.
size_t num_bytes_;
DISABLE_COPY_AND_ASSIGN(Chunk);
};
size_t CompressData(const char* in, size_t in_length, Compressor ct, char* out);
void DeflateData(const char* in, size_t in_length, Compressor ct, char* out);
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/chunk.h"
#include <sstream>
#include "gtest/gtest.h"
using namespace paddle::recordio;
TEST(Chunk, SaveLoad) {
Chunk ch;
ch.Add(std::string("12345", 6));
ch.Add(std::string("123", 4));
std::stringstream ss;
ch.Write(ss, Compressor::kNoCompress);
ch.Clear();
ch.Parse(ss);
ASSERT_EQ(ch.NumBytes(), 10U);
}
TEST(Chunk, Compressor) {
Chunk ch;
ch.Add(std::string("12345", 6));
ch.Add(std::string("123", 4));
ch.Add(std::string("123", 4));
ch.Add(std::string("123", 4));
std::stringstream ss;
ch.Write(ss, Compressor::kSnappy);
std::stringstream ss2;
ch.Write(ss2, Compressor::kNoCompress);
ASSERT_LE(ss.tellp(), ss2.tellp());  // Compressed output should be no larger than the uncompressed one.
ch.Clear();
ch.Parse(ss);
ASSERT_EQ(ch.NumBytes(), 18);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/header.h"
namespace paddle {
namespace recordio {
Header::Header()
: num_records_(0),
checksum_(0),
compressor_(Compressor::kNoCompress),
compress_size_(0) {}
Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)
: num_records_(num), checksum_(sum), compressor_(c), compress_size_(cs) {}
void Header::Parse(std::istream& is) {
is.read(reinterpret_cast<char*>(&num_records_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&checksum_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&compressor_), sizeof(uint32_t))
.read(reinterpret_cast<char*>(&compress_size_), sizeof(uint32_t));
}
void Header::Write(std::ostream& os) const {
os.write(reinterpret_cast<const char*>(&num_records_), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&checksum_), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&compressor_), sizeof(uint32_t))
.write(reinterpret_cast<const char*>(&compress_size_), sizeof(uint32_t));
}
std::ostream& operator<<(std::ostream& os, Header h) {
os << h.NumRecords() << h.Checksum()
<< static_cast<uint32_t>(h.CompressType()) << h.CompressSize();
return os;
}
bool operator==(Header l, Header r) {
return l.NumRecords() == r.NumRecords() && l.Checksum() == r.Checksum() &&
l.CompressType() == r.CompressType() &&
l.CompressSize() == r.CompressSize();
}
} // namespace recordio
} // namespace paddle
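Putting Header::Write and Chunk::Write together, an uncompressed chunk on disk is four uint32 fields (num_records, checksum, compressor, compress_size) followed, for each record, by a uint32 length and the raw bytes. A rough Python reader for the kNoCompress case, assuming the little-endian byte order the C++ code effectively writes on common platforms (hypothetical, for illustration only):

import struct

def read_uncompressed_chunk(stream):
    # Header: num_records, checksum, compressor, compress_size (4 x uint32).
    num_records, checksum, compressor, compress_size = struct.unpack(
        '<IIII', stream.read(16))
    records = []
    for _ in range(num_records):
        (rec_len,) = struct.unpack('<I', stream.read(4))
        records.append(stream.read(rec_len))
    return records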
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <sstream>
namespace paddle {
namespace recordio {
// Default ChunkSize
constexpr size_t kDefaultMaxChunkSize = 32 * 1024 * 1024;
// MagicNumber for memory checking
constexpr uint32_t kMagicNumber = 0x01020304;
enum class Compressor : uint32_t {
// NoCompression means writing raw chunk data into files.
// With other choices, chunks are compressed before being written.
kNoCompress = 0,
// Snappy has been the default compression algorithm widely
// used in Google. It is a compromise between speed and
// compression ratio.
kSnappy = 1,
// Gzip is a well-known compression algorithm. It is
// recommended only if you are looking for a higher compression ratio.
kGzip = 2,
};
// Header is the metadata of Chunk
class Header {
public:
Header();
Header(uint32_t num, uint32_t sum, Compressor ct, uint32_t cs);
void Write(std::ostream& os) const;
void Parse(std::istream& is);
uint32_t NumRecords() const { return num_records_; }
uint32_t Checksum() const { return checksum_; }
Compressor CompressType() const { return compressor_; }
uint32_t CompressSize() const { return compress_size_; }
private:
uint32_t num_records_;
uint32_t checksum_;
Compressor compressor_;
uint32_t compress_size_;
};
// Make Header loggable.
std::ostream& operator<<(std::ostream& os, Header h);
bool operator==(Header l, Header r);
} // namespace recordio
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/recordio/header.h"
#include <sstream>
#include "gtest/gtest.h"
using namespace paddle::recordio;
TEST(Recordio, ChunkHead) {
Header hdr(0, 1, Compressor::kGzip, 3);
std::stringstream ss;
hdr.Write(ss);
ss.seekg(0, std::ios::beg);
Header hdr2;
hdr2.Parse(ss);
EXPECT_TRUE(hdr == hdr2);
}
...@@ -14,78 +14,4 @@ make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs ...@@ -14,78 +14,4 @@ make -j `nproc` paddle_docs paddle_docs_cn paddle_api_docs
# check websites for broken links # check websites for broken links
linkchecker doc/v2/en/html/index.html linkchecker doc/v2/en/html/index.html
linkchecker doc/v2/cn/html/index.html linkchecker doc/v2/cn/html/index.html
linkchecker doc/api/en/html/index.html linkchecker doc/v2/api/en/html/index.html
# Parse the GitHub URL
REPO=`git config remote.origin.url`
SSH_REPO=${REPO/https:\/\/github.com\//git@github.com:}
SHA=`git rev-parse --verify HEAD`
# Documentation branch name
# The gh-pages branch is used for PaddlePaddle.org. The English version of
# the documentation goes into the `doc` directory, and the Chinese version
# into the `doc_cn` directory.
TARGET_BRANCH="gh-pages"
# Only deploy master branch to build latest documentation.
SOURCE_BRANCH="master"
# Clone the repo to output directory
mkdir output
git clone $REPO output
cd output
function deploy_docs() {
SOURCE_BRANCH=$1
DIR=$2
# Do not deploy for GitHub pull requests.
if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then
exit 0
fi
# Skip if this is not the watched branch.
if [ "$TRAVIS_BRANCH" != "$SOURCE_BRANCH" ]; then
return
fi
# Check out the GitHub Pages branch
git checkout $TARGET_BRANCH || git checkout --orphan $TARGET_BRANCH
mkdir -p ${DIR}
# Remove the old docs and copy in the new docs.
set +e
rm -rf ${DIR}/doc ${DIR}/doc_cn ${DIR}/api_doc
set -e
cp -r ../doc/v2/cn/html ${DIR}/doc_cn
cp -r ../doc/v2/en/html ${DIR}/doc
cp -r ../doc/api/en/html ${DIR}/api_doc
git add .
}
deploy_docs "master" "."
deploy_docs "develop" "./develop/"
# Check whether anything has changed.
set +e
git diff --cached --exit-code >/dev/null
if [ $? -eq 0 ]; then
echo "No changes to the output on this push; exiting."
exit 0
fi
set -e
if [ -n "$SSL_KEY" ]; then # Only push updated docs for github.com/PaddlePaddle/Paddle.
# Commit
git add .
git config user.name "Travis CI"
git config user.email "paddle-dev@baidu.com"
git commit -m "Deploy to GitHub Pages: ${SHA}"
# Set ssh private key
openssl aes-256-cbc -K $SSL_KEY -iv $SSL_IV -in ../../paddle/scripts/travis/deploy_key.enc -out deploy_key -d
chmod 600 deploy_key
eval `ssh-agent -s`
ssh-add deploy_key
# Push
git push $SSH_REPO $TARGET_BRANCH
fi
...@@ -28,6 +28,7 @@ import nets ...@@ -28,6 +28,7 @@ import nets
import optimizer import optimizer
import backward import backward
import regularizer import regularizer
import average
from param_attr import ParamAttr, WeightNormParamAttr from param_attr import ParamAttr, WeightNormParamAttr
from data_feeder import DataFeeder from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace from core import LoDTensor, CPUPlace, CUDAPlace
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
"""
Classes for all kinds of averages.
All averages are implemented purely in Python. They do not change
Paddle's Program, nor modify the NN model's configuration; they are
simply wrappers around Python functions.
"""
def _is_number_(var):
return isinstance(var, int) or isinstance(var, float) or (isinstance(
var, np.ndarray) and var.shape == (1, ))
def _is_number_or_matrix_(var):
return _is_number_(var) or isinstance(var, np.ndarray)
class WeightedAverage(object):
def __init__(self):
self.reset()
def reset(self):
self.numerator = None
self.denominator = None
def add(self, value, weight):
if not _is_number_or_matrix_(value):
raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.")
if not _is_number_(weight):
raise ValueError("The 'weight' must be a number(int, float).")
if self.numerator is None or self.denominator is None:
self.numerator = value * weight
self.denominator = weight
else:
self.numerator += value * weight
self.denominator += weight
def eval(self):
if self.numerator is None or self.denominator is None:
raise ValueError(
"There is no data to be averaged in WeightedAverage.")
return self.numerator / self.denominator
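A minimal usage sketch of the WeightedAverage class defined above, matching how it is used in the training loops elsewhere in this change (the values and weights are made up):

acc_avg = WeightedAverage()
acc_avg.add(value=0.90, weight=128)  # accuracy over a batch of 128 samples
acc_avg.add(value=0.75, weight=64)   # accuracy over a batch of 64 samples
print(acc_avg.eval())                # (0.90 * 128 + 0.75 * 64) / 192 = 0.85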
...@@ -486,7 +486,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -486,7 +486,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
params_and_grads = [] params_and_grads = []
for param in parameters: for param in parameters:
if param not in grad_info_map: if param not in grad_info_map:
raise ValueError("param %s is not in map" % param) continue
grad_info = grad_info_map[param] grad_info = grad_info_map[param]
grad_block = grad_info[1] grad_block = grad_info[1]
if not grad_block.has_var(grad_info[0]): if not grad_block.has_var(grad_info[0]):
......
...@@ -108,44 +108,6 @@ class Evaluator(object): ...@@ -108,44 +108,6 @@ class Evaluator(object):
return state return state
class Accuracy(Evaluator):
"""
Average Accuracy for multiple mini-batches.
"""
def __init__(self, input, label, k=1, **kwargs):
super(Accuracy, self).__init__("accuracy", **kwargs)
main_program = self.helper.main_program
if main_program.current_block().idx != 0:
raise ValueError("You can only invoke Evaluator in root block")
self.total = self.create_state(dtype='int64', shape=[1], suffix='total')
self.correct = self.create_state(
dtype='int64', shape=[1], suffix='correct')
total = self.helper.create_tmp_variable(dtype='int')
correct = self.helper.create_tmp_variable(dtype='int')
acc = layers.accuracy(
input=input, label=label, k=k, total=total, correct=correct)
total = layers.cast(x=total, dtype='int64')
correct = layers.cast(x=correct, dtype='int64')
layers.sums(input=[self.total, total], out=self.total)
layers.sums(input=[self.correct, correct], out=self.correct)
self.metrics.append(acc)
def eval(self, executor, eval_program=None):
if eval_program is None:
eval_program = Program()
block = eval_program.current_block()
with program_guard(main_program=eval_program):
total = _clone_var_(block, self.total)
correct = _clone_var_(block, self.correct)
total = layers.cast(total, dtype='float32')
correct = layers.cast(correct, dtype='float32')
out = layers.elementwise_div(x=correct, y=total)
return np.array(executor.run(eval_program, fetch_list=[out])[0])
class ChunkEvaluator(Evaluator): class ChunkEvaluator(Evaluator):
""" """
Accumulate counter numbers output by chunk_eval from mini-batches and Accumulate counter numbers output by chunk_eval from mini-batches and
...@@ -312,6 +274,10 @@ class DetectionMAP(Evaluator): ...@@ -312,6 +274,10 @@ class DetectionMAP(Evaluator):
bounding box (bbox), which is a LoDTensor [N, 1]. bounding box (bbox), which is a LoDTensor [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax]. LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
class_num (int): The number of classes.
background_label (int): The index of the background label; the background
label will be ignored. If set to -1, then all categories will be
considered. 0 by default.
overlap_threshold (float): The threshold for deciding true/false overlap_threshold (float): The threshold for deciding true/false
positive, 0.5 by default. positive, 0.5 by default.
evaluate_difficult (bool): Whether to consider difficult ground truth evaluate_difficult (bool): Whether to consider difficult ground truth
...@@ -345,6 +311,8 @@ class DetectionMAP(Evaluator): ...@@ -345,6 +311,8 @@ class DetectionMAP(Evaluator):
gt_label, gt_label,
gt_box, gt_box,
gt_difficult, gt_difficult,
class_num,
background_label=0,
overlap_threshold=0.5, overlap_threshold=0.5,
evaluate_difficult=True, evaluate_difficult=True,
ap_version='integral'): ap_version='integral'):
...@@ -358,6 +326,8 @@ class DetectionMAP(Evaluator): ...@@ -358,6 +326,8 @@ class DetectionMAP(Evaluator):
map = layers.detection_map( map = layers.detection_map(
input, input,
label, label,
class_num,
background_label,
overlap_threshold=overlap_threshold, overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult, evaluate_difficult=evaluate_difficult,
ap_version=ap_version) ap_version=ap_version)
...@@ -377,6 +347,8 @@ class DetectionMAP(Evaluator): ...@@ -377,6 +347,8 @@ class DetectionMAP(Evaluator):
accum_map = layers.detection_map( accum_map = layers.detection_map(
input, input,
label, label,
class_num,
background_label,
overlap_threshold=overlap_threshold, overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult, evaluate_difficult=evaluate_difficult,
has_state=self.has_state, has_state=self.has_state,
......
...@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True): ...@@ -163,6 +163,22 @@ def fetch_var(name, scope=None, return_numpy=True):
return tensor return tensor
def get_program_cache_key(feed, fetch_list):
feed_var_names = feed.keys()
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
else:
raise TypeError(str(var) + " should be Variable or str")
fetch_var_names = map(to_name_str, fetch_list)
return str(feed_var_names + fetch_var_names)
class Executor(object): class Executor(object):
def __init__(self, places): def __init__(self, places):
if not isinstance(places, list) and not isinstance(places, tuple): if not isinstance(places, list) and not isinstance(places, tuple):
...@@ -177,6 +193,7 @@ class Executor(object): ...@@ -177,6 +193,7 @@ class Executor(object):
# TODO(dzhwinter) : only use the first place # TODO(dzhwinter) : only use the first place
self.executor = core.Executor(act_places[0]) self.executor = core.Executor(act_places[0])
self.places = places self.places = places
self.program_caches = dict()
def aslodtensor(self, data): def aslodtensor(self, data):
def accumulate(data): def accumulate(data):
...@@ -225,9 +242,30 @@ class Executor(object): ...@@ -225,9 +242,30 @@ class Executor(object):
feed_var_name='feed', feed_var_name='feed',
fetch_var_name='fetch', fetch_var_name='fetch',
scope=None, scope=None,
return_numpy=True): return_numpy=True,
use_program_cache=False):
""" Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used.
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
to this list.
:param feed_var_name: the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
:param return_numpy: if convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
:return: result according to fetch_list.
"""
if feed is None: if feed is None:
feed = {} feed = {}
if not isinstance(feed, dict):
raise TypeError("feed should be a map")
if fetch_list is None: if fetch_list is None:
fetch_list = [] fetch_list = []
...@@ -240,35 +278,64 @@ class Executor(object): ...@@ -240,35 +278,64 @@ class Executor(object):
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
program = program.clone() program_cache = None
global_block = program.global_block() program_cache_key = get_program_cache_key(feed, fetch_list)
if feed_var_name in global_block.vars: if use_program_cache:
feed_var = global_block.var(feed_var_name) # find program cache by cache_key
program_cache = self.program_caches.get(program_cache_key, None)
# TODO(qiao): Should check program_cache and program are exactly the same.
else: else:
feed_var = global_block.create_var( self.program_caches.pop(program_cache_key, None)
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars: if program_cache is None:
fetch_var = global_block.var(fetch_var_name) program_cache = program.clone()
else:
fetch_var = global_block.create_var( if use_program_cache:
name=fetch_var_name, self.program_caches[program_cache_key] = program_cache
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True) global_block = program_cache.global_block()
if not has_feed_operators(global_block, feed, feed_var_name): if feed_var_name in global_block.vars:
for i, name in enumerate(feed): feed_var = global_block.var(feed_var_name)
out = global_block.var(name) else:
global_block.prepend_op( feed_var = global_block.create_var(
type='feed', name=feed_var_name,
inputs={'X': [feed_var]}, type=core.VarDesc.VarType.FEED_MINIBATCH,
outputs={'Out': [out]}, persistable=True)
attrs={'col': i})
if fetch_var_name in global_block.vars:
for op in global_block.ops: fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list,
fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
# feed var to framework
for op in program_cache.global_block().ops:
if op.desc.type() == 'feed': if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0] feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name] cur_feed = feed[feed_target_name]
...@@ -279,17 +346,7 @@ class Executor(object): ...@@ -279,17 +346,7 @@ class Executor(object):
else: else:
break break
if not has_fetch_operators(global_block, fetch_list, fetch_var_name): self.executor.run(program_cache.desc, scope, 0, True, True)
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
self.executor.run(program.desc, scope, 0, True, True)
outs = [ outs = [
core.get_fetch_variable(scope, fetch_var_name, i) core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list)) for i in xrange(len(fetch_list))
......
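A hedged usage sketch of the new use_program_cache flag: when the program and the feed/fetch structure stay the same across iterations, the executor can reuse the cached clone with feed and fetch operators already inserted. The variables exe, feeder, train_reader, avg_cost and batch_acc are assumed to be set up as in the MNIST example elsewhere in this change:

# Hypothetical training-loop snippet; the program and feed/fetch lists are
# unchanged between steps, so the cached program clone can be reused.
for data in train_reader():
    loss, acc = exe.run(fluid.default_main_program(),
                        feed=feeder.feed(data),
                        fetch_list=[avg_cost, batch_acc],
                        use_program_cache=True)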
...@@ -28,6 +28,8 @@ import math_op_patch ...@@ -28,6 +28,8 @@ import math_op_patch
from math_op_patch import * from math_op_patch import *
import detection import detection
from detection import * from detection import *
import metric
from metric import *
from learning_rate_scheduler import * from learning_rate_scheduler import *
__all__ = [] __all__ = []
...@@ -39,4 +41,5 @@ __all__ += control_flow.__all__ ...@@ -39,4 +41,5 @@ __all__ += control_flow.__all__
__all__ += ops.__all__ __all__ += ops.__all__
__all__ += device.__all__ __all__ += device.__all__
__all__ += detection.__all__ __all__ += detection.__all__
__all__ += metric.__all__
__all__ += learning_rate_scheduler.__all__ __all__ += learning_rate_scheduler.__all__
...@@ -151,6 +151,8 @@ def detection_output(loc, ...@@ -151,6 +151,8 @@ def detection_output(loc,
@autodoc() @autodoc()
def detection_map(detect_res, def detection_map(detect_res,
label, label,
class_num,
background_label=0,
overlap_threshold=0.3, overlap_threshold=0.3,
evaluate_difficult=True, evaluate_difficult=True,
has_state=None, has_state=None,
...@@ -192,7 +194,8 @@ def detection_map(detect_res, ...@@ -192,7 +194,8 @@ def detection_map(detect_res,
attrs={ attrs={
'overlap_threshold': overlap_threshold, 'overlap_threshold': overlap_threshold,
'evaluate_difficult': evaluate_difficult, 'evaluate_difficult': evaluate_difficult,
'ap_type': ap_version 'ap_type': ap_version,
'class_num': class_num,
}) })
return map_out return map_out
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
All layers just related to metric.
"""
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable
from ..param_attr import ParamAttr
__all__ = ['accuracy']
def accuracy(input, label, k=1, correct=None, total=None):
"""
This function computes the accuracy of 'input' with respect to 'label'
using the top-k results. It returns the accuracy variable; the optional
'correct' and 'total' variables receive the per-batch counters.
"""
helper = LayerHelper("accuracy", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
acc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="accuracy",
inputs={
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
},
outputs={
"Accuracy": [acc_out],
"Correct": [correct],
"Total": [total],
})
return acc_out
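A short usage sketch of the relocated accuracy layer, following the pattern used by the MNIST test updated in this change; predict and label are assumed to be existing variables:

# 'total' receives the per-batch sample count, which can then be used as the
# weight for fluid.average.WeightedAverage.
batch_size = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)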
...@@ -35,7 +35,6 @@ __all__ = [ ...@@ -35,7 +35,6 @@ __all__ = [
'cos_sim', 'cos_sim',
'cross_entropy', 'cross_entropy',
'square_error_cost', 'square_error_cost',
'accuracy',
'chunk_eval', 'chunk_eval',
'sequence_conv', 'sequence_conv',
'conv2d', 'conv2d',
...@@ -1022,40 +1021,6 @@ def square_error_cost(input, label): ...@@ -1022,40 +1021,6 @@ def square_error_cost(input, label):
return square_out return square_out
def accuracy(input, label, k=1, correct=None, total=None):
"""
This function computes the accuracy using the input and label.
The output is the top_k inputs and their indices.
"""
helper = LayerHelper("accuracy", **locals())
topk_out = helper.create_tmp_variable(dtype=input.dtype)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="top_k",
inputs={"X": [input]},
outputs={"Out": [topk_out],
"Indices": [topk_indices]},
attrs={"k": k})
acc_out = helper.create_tmp_variable(dtype="float32")
if correct is None:
correct = helper.create_tmp_variable(dtype="int64")
if total is None:
total = helper.create_tmp_variable(dtype="int64")
helper.append_op(
type="accuracy",
inputs={
"Out": [topk_out],
"Indices": [topk_indices],
"Label": [label]
},
outputs={
"Accuracy": [acc_out],
"Correct": [correct],
"Total": [total],
})
return acc_out
def chunk_eval(input, def chunk_eval(input,
label, label,
chunk_scheme, chunk_scheme,
...@@ -1438,6 +1403,7 @@ def pool2d(input, ...@@ -1438,6 +1403,7 @@ def pool2d(input,
pool_padding=0, pool_padding=0,
global_pooling=False, global_pooling=False,
use_cudnn=True, use_cudnn=True,
ceil_mode=False,
name=None): name=None):
""" """
This function adds the operator for pooling in 2 dimensions, using the This function adds the operator for pooling in 2 dimensions, using the
...@@ -1474,7 +1440,8 @@ def pool2d(input, ...@@ -1474,7 +1440,8 @@ def pool2d(input,
"global_pooling": global_pooling, "global_pooling": global_pooling,
"strides": pool_stride, "strides": pool_stride,
"paddings": pool_padding, "paddings": pool_padding,
"use_cudnn": use_cudnn "use_cudnn": use_cudnn,
"ceil_mode": ceil_mode
}) })
return pool_out return pool_out
...@@ -3180,7 +3147,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): ...@@ -3180,7 +3147,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
data = fluid.layers.data(name='data', shape=[128], dtype='float32') data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[100], dtype='int64') label = fluid.layers.data(name='label', shape=[100], dtype='int64')
fc = fluid.layers.fc(input=data, size=100) fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.smooth_l1(logits=fc, label=label) out = fluid.layers.smooth_l1(x=fc, y=label)
""" """
helper = LayerHelper('smooth_l1_loss', **locals()) helper = LayerHelper('smooth_l1_loss', **locals())
diff = helper.create_tmp_variable(dtype=x.dtype) diff = helper.create_tmp_variable(dtype=x.dtype)
......
...@@ -122,7 +122,8 @@ avg_cost = fluid.layers.mean(cost) ...@@ -122,7 +122,8 @@ avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.Adam(learning_rate=0.001) optimizer = fluid.optimizer.Adam(learning_rate=0.001)
opts = optimizer.minimize(avg_cost) opts = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label) batch_size = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size)
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
...@@ -144,13 +145,17 @@ feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) ...@@ -144,13 +145,17 @@ feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
i = 0 i = 0
accuracy = fluid.average.WeightedAverage()
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
accuracy.reset(exe) accuracy.reset()
for data in train_reader(): for data in train_reader():
loss, acc = exe.run(fluid.default_main_program(), loss, acc, weight = exe.run(
feed=feeder.feed(data), fluid.default_main_program(),
fetch_list=[avg_cost] + accuracy.metrics) feed=feeder.feed(data),
pass_acc = accuracy.eval(exe) fetch_list=[avg_cost, batch_acc, batch_size])
accuracy.add(value=acc, weight=weight)
pass_acc = accuracy.eval()
print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
pass_acc)) pass_acc))
# this model is slow, so if we can train two mini batch, we think it works properly. # this model is slow, so if we can train two mini batch, we think it works properly.
......
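The switch from fluid.evaluator.Accuracy to an explicit fluid.average.WeightedAverage makes the batch-size weighting visible: averaging per-batch accuracies without weights gives small batches (typically the last batch of a pass) the same influence as full ones. A toy illustration with made-up numbers:

    # Two batches: 100 samples at 90% accuracy, then 10 samples at 50%.
    accs = [0.90, 0.50]
    sizes = [100, 10]

    unweighted = sum(accs) / len(accs)
    weighted = sum(a * n for a, n in zip(accs, sizes)) / sum(sizes)
    print("unweighted: %.3f  weighted: %.3f" % (unweighted, weighted))  # 0.700 vs 0.864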
...@@ -158,7 +158,7 @@ class TestDetectionMAP(unittest.TestCase): ...@@ -158,7 +158,7 @@ class TestDetectionMAP(unittest.TestCase):
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
map_out = layers.detection_map(detect_res=detect_res, label=label) map_out = layers.detection_map(detect_res, label, 21)
self.assertIsNotNone(map_out) self.assertIsNotNone(map_out)
self.assertEqual(map_out.shape, (1, )) self.assertEqual(map_out.shape, (1, ))
print(str(program)) print(str(program))
......
...@@ -22,8 +22,8 @@ from op_test import OpTest ...@@ -22,8 +22,8 @@ from op_test import OpTest
class TestDetectionMAPOp(OpTest): class TestDetectionMAPOp(OpTest):
def set_data(self): def set_data(self):
self.class_num = 4
self.init_test_case() self.init_test_case()
self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)] self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)]
self.label = np.array(self.label).astype('float32') self.label = np.array(self.label).astype('float32')
self.detect = np.array(self.detect).astype('float32') self.detect = np.array(self.detect).astype('float32')
...@@ -53,7 +53,8 @@ class TestDetectionMAPOp(OpTest): ...@@ -53,7 +53,8 @@ class TestDetectionMAPOp(OpTest):
self.attrs = { self.attrs = {
'overlap_threshold': self.overlap_threshold, 'overlap_threshold': self.overlap_threshold,
'evaluate_difficult': self.evaluate_difficult, 'evaluate_difficult': self.evaluate_difficult,
'ap_type': self.ap_type 'ap_type': self.ap_type,
'class_num': self.class_num
} }
self.out_class_pos_count = np.array(self.out_class_pos_count).astype( self.out_class_pos_count = np.array(self.out_class_pos_count).astype(
...@@ -126,12 +127,7 @@ class TestDetectionMAPOp(OpTest): ...@@ -126,12 +127,7 @@ class TestDetectionMAPOp(OpTest):
return class_pos_count_dict, true_pos_dict, false_pos_dict return class_pos_count_dict, true_pos_dict, false_pos_dict
def get_output_pos(label_count, true_pos, false_pos): def get_output_pos(label_count, true_pos, false_pos):
max_label = 0 label_number = self.class_num
for (label, label_pos_num) in label_count.items():
if max_label < label:
max_label = label
label_number = max_label + 1
out_class_pos_count = [] out_class_pos_count = []
out_true_pos_lod = [0] out_true_pos_lod = [0]
...@@ -220,11 +216,16 @@ class TestDetectionMAPOp(OpTest): ...@@ -220,11 +216,16 @@ class TestDetectionMAPOp(OpTest):
mAP += average_precisions mAP += average_precisions
count += 1 count += 1
self.out_class_pos_count, self.out_true_pos, self.out_true_pos_lod, self.out_false_pos, self.out_false_pos_lod = get_output_pos( pcnt, tp, tp_lod, fp, fp_lod = get_output_pos(label_count, true_pos,
label_count, true_pos, false_pos) false_pos)
self.out_class_pos_count = pcnt
self.out_true_pos = tp
self.out_true_pos_lod = tp_lod
self.out_false_pos = fp
self.out_false_pos_lod = fp_lod
if count != 0: if count != 0:
mAP /= count mAP /= count
return mAP * 100.0 return mAP
def setUp(self): def setUp(self):
self.op_type = "detection_map" self.op_type = "detection_map"
......
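Two behavioural points in this test change: the number of classes now comes from an explicit class_num attribute instead of being inferred from whichever labels happen to appear in the batch, and the reference mAP is returned as a fraction in [0, 1] rather than a percentage. A toy illustration of why the inference was fragile (hypothetical counts, not the test's data):

    # Positive counts per ground-truth class id observed in one batch;
    # classes 2 and 3 have no instances here.
    label_count = {0: 2, 1: 5}
    class_num = 4

    old_label_number = max(label_count) + 1  # 2 -- output arrays sized too small
    new_label_number = class_num             # 4 -- one slot per declared class
    print("old: %d  new: %d" % (old_label_number, new_label_number))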
...@@ -89,7 +89,7 @@ class TestLearningRateDecay(unittest.TestCase): ...@@ -89,7 +89,7 @@ class TestLearningRateDecay(unittest.TestCase):
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for step in range(10): for step in range(10):
lr_val, = exe.run(fluid.default_main_program(), lr_val, = exe.run(fluid.default_main_program(),
feed=[], feed={},
fetch_list=[decayed_lr]) fetch_list=[decayed_lr])
python_decayed_lr = python_decay_fn( python_decayed_lr = python_decay_fn(
global_step=float(step), **kwargs) global_step=float(step), **kwargs)
......
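The feed argument of Executor.run is a dict mapping data-variable names to values, so an empty feed is written as {} rather than []. With exe and decayed_lr as defined in the test above, the call reads:

    # Nothing to feed: the decayed learning rate depends only on the global
    # step counter maintained inside the program itself.
    lr_val, = exe.run(fluid.default_main_program(),
                      feed={},
                      fetch_list=[decayed_lr])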
...@@ -19,12 +19,21 @@ import paddle.fluid.core as core ...@@ -19,12 +19,21 @@ import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): def max_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False):
N, C, H, W = x.shape N, C, H, W = x.shape
if global_pool == 1: if global_pool == 1:
ksize = [H, W] ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 ) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out)) out = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out): for i in xrange(H_out):
for j in xrange(W_out): for j in xrange(W_out):
...@@ -38,12 +47,21 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): ...@@ -38,12 +47,21 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
return out return out
def avg_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0): def avg_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False):
N, C, H, W = x.shape N, C, H, W = x.shape
if global_pool == 1: if global_pool == 1:
ksize = [H, W] ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 ) / strides[0] + 1 if ceil_mode else (H - ksize[0] + 2 *
paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (W - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out)) out = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out): for i in xrange(H_out):
for j in xrange(W_out): for j in xrange(W_out):
...@@ -65,12 +83,13 @@ class TestPool2d_Op(OpTest): ...@@ -65,12 +83,13 @@ class TestPool2d_Op(OpTest):
self.init_global_pool() self.init_global_pool()
self.init_op_type() self.init_op_type()
self.init_pool_type() self.init_pool_type()
self.init_ceil_mode()
if self.global_pool: if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))] self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output = self.pool2D_forward_naive(input, self.ksize, self.strides, output = self.pool2D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.paddings, self.global_pool,
self.global_pool).astype("float32") self.ceil_mode).astype("float32")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = { self.attrs = {
...@@ -80,6 +99,7 @@ class TestPool2d_Op(OpTest): ...@@ -80,6 +99,7 @@ class TestPool2d_Op(OpTest):
'pooling_type': self.pool_type, 'pooling_type': self.pool_type,
'global_pooling': self.global_pool, 'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'ceil_mode': self.ceil_mode,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
} }
...@@ -116,6 +136,9 @@ class TestPool2d_Op(OpTest): ...@@ -116,6 +136,9 @@ class TestPool2d_Op(OpTest):
def init_global_pool(self): def init_global_pool(self):
self.global_pool = True self.global_pool = True
def init_ceil_mode(self):
self.ceil_mode = False
class TestCase1(TestPool2d_Op): class TestCase1(TestPool2d_Op):
def init_test_case(self): def init_test_case(self):
...@@ -217,5 +240,25 @@ class TestCUDNNCase6(TestCase5): ...@@ -217,5 +240,25 @@ class TestCUDNNCase6(TestCase5):
self.op_type = "pool2d" self.op_type = "pool2d"
class TestCeilModeCase1(TestCUDNNCase1):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCeilModeCase2(TestCUDNNCase2):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCeilModeCase3(TestCase1):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCeilModeCase4(TestCase2):
def init_ceil_mode(self):
self.ceil_mode = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -19,13 +19,24 @@ import paddle.fluid.core as core ...@@ -19,13 +19,24 @@ import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): def max_pool3D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False):
N, C, D, H, W = x.shape N, C, D, H, W = x.shape
if global_pool == 1: if global_pool == 1:
ksize = [D, H, W] ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 ) / strides[0] + 1 if ceil_mode else (D - ksize[0] + 2 *
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (H - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) / strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 *
paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out)) out = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out): for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0)) d_start = np.max((k * strides[0] - paddings[0], 0))
...@@ -42,13 +53,24 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): ...@@ -42,13 +53,24 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0):
return out return out
def avg_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0): def avg_pool3D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False):
N, C, D, H, W = x.shape N, C, D, H, W = x.shape
if global_pool == 1: if global_pool == 1:
ksize = [D, H, W] ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 D_out = (D - ksize[0] + 2 * paddings[0] + strides[0] - 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 ) / strides[0] + 1 if ceil_mode else (D - ksize[0] + 2 *
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1] + strides[1] - 1
) / strides[1] + 1 if ceil_mode else (H - ksize[1] + 2 *
paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2] + strides[2] - 1
) / strides[2] + 1 if ceil_mode else (W - ksize[2] + 2 *
paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out)) out = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out): for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0)) d_start = np.max((k * strides[0] - paddings[0], 0))
...@@ -73,13 +95,14 @@ class TestPool3d_Op(OpTest): ...@@ -73,13 +95,14 @@ class TestPool3d_Op(OpTest):
self.init_global_pool() self.init_global_pool()
self.init_op_type() self.init_op_type()
self.init_pool_type() self.init_pool_type()
self.init_ceil_mode()
if self.global_pool: if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))] self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype("float32")
output = self.pool3D_forward_naive(input, self.ksize, self.strides, output = self.pool3D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.paddings, self.global_pool,
self.global_pool).astype("float32") self.ceil_mode).astype("float32")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = { self.attrs = {
...@@ -89,6 +112,7 @@ class TestPool3d_Op(OpTest): ...@@ -89,6 +112,7 @@ class TestPool3d_Op(OpTest):
'pooling_type': self.pool_type, 'pooling_type': self.pool_type,
'global_pooling': self.global_pool, 'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'ceil_mode': self.ceil_mode,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter 'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
} }
...@@ -125,6 +149,9 @@ class TestPool3d_Op(OpTest): ...@@ -125,6 +149,9 @@ class TestPool3d_Op(OpTest):
def init_global_pool(self): def init_global_pool(self):
self.global_pool = True self.global_pool = True
def init_ceil_mode(self):
self.ceil_mode = False
class TestCase1(TestPool3d_Op): class TestCase1(TestPool3d_Op):
def init_test_case(self): def init_test_case(self):
...@@ -227,5 +254,25 @@ class TestCUDNNCase6(TestCase5): ...@@ -227,5 +254,25 @@ class TestCUDNNCase6(TestCase5):
self.op_type = "pool3d" self.op_type = "pool3d"
class TestCeilModeCase1(TestCUDNNCase1):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCeilModeCase2(TestCUDNNCase2):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCeilModeCase3(TestCase1):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCeilModeCase4(TestCase2):
def init_ceil_mode(self):
self.ceil_mode = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -37,7 +37,9 @@ class TestProfiler(unittest.TestCase): ...@@ -37,7 +37,9 @@ class TestProfiler(unittest.TestCase):
label = fluid.layers.data(name='y', shape=[1], dtype='int64') label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label) batch_size = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(
input=predict, label=label, total=batch_size)
optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
opts = optimizer.minimize(avg_cost, startup_program=startup_program) opts = optimizer.minimize(avg_cost, startup_program=startup_program)
...@@ -46,7 +48,7 @@ class TestProfiler(unittest.TestCase): ...@@ -46,7 +48,7 @@ class TestProfiler(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program) exe.run(startup_program)
accuracy.reset(exe) pass_acc_calculator = fluid.average.WeightedAverage()
with profiler.profiler(state, 'total', profile_path) as prof: with profiler.profiler(state, 'total', profile_path) as prof:
for iter in range(10): for iter in range(10):
if iter == 2: if iter == 2:
...@@ -57,9 +59,11 @@ class TestProfiler(unittest.TestCase): ...@@ -57,9 +59,11 @@ class TestProfiler(unittest.TestCase):
outs = exe.run(main_program, outs = exe.run(main_program,
feed={'x': x, feed={'x': x,
'y': y}, 'y': y},
fetch_list=[avg_cost] + accuracy.metrics) fetch_list=[avg_cost, batch_acc, batch_size])
acc = np.array(outs[1]) acc = np.array(outs[1])
pass_acc = accuracy.eval(exe) b_size = np.array(outs[2])
pass_acc_calculator.add(value=acc, weight=b_size)
pass_acc = pass_acc_calculator.eval()
def test_cpu_profiler(self): def test_cpu_profiler(self):
self.net_profiler('CPU') self.net_profiler('CPU')
......