提交 4801ee8f 编写于 作者: D Dang Qingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into refine_generate_proposals_op

......@@ -25,5 +25,6 @@ third_party/
bazel-*
third_party/
build_*
# clion workspace.
cmake-build-*
......@@ -72,6 +72,7 @@ option(WITH_INFERENCE "Compile fluid inference library" ON)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
# PY_VERSION
if(NOT PY_VERSION)
......@@ -126,6 +127,9 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
"A path setting fluid shared and static libraries")
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries")
if (WITH_C_API AND WITH_PYTHON)
message(WARNING "It is suggest not embedded a python interpreter in Paddle "
"when using C-API. It will give an unpredictable behavior when using a "
......
......@@ -24,6 +24,7 @@ COPY ./paddle/scripts/docker/root/ /root/
RUN apt-get update && \
apt-get install -y --allow-downgrades patchelf \
python3 python3-dev python3-pip \
git python-pip python-dev python-opencv openssh-server bison \
libnccl2=2.1.2-1+cuda8.0 libnccl-dev=2.1.2-1+cuda8.0 \
wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \
......@@ -70,24 +71,33 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# specify sphinx version as 1.5.6 and remove -U option for [pip install -U
# sphinx-rtd-theme] since -U option will cause sphinx being updated to newest
# version(1.7.1 for now), which causes building documentation failed.
RUN easy_install -U pip && \
RUN pip3 install -U wheel && \
pip3 install -U docopt PyYAML sphinx==1.5.6 && \
pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
easy_install -U pip && \
pip install -U wheel && \
pip install -U docopt PyYAML sphinx==1.5.6 && \
pip install sphinx-rtd-theme==0.1.9 recommonmark
RUN pip install pre-commit 'ipython==5.3.0' && \
RUN pip3 install pre-commit 'ipython==5.3.0' && \
pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip3 install opencv-python && \
pip install pre-commit 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python
#For docstring checker
RUN pip3 install pylint pytest astroid isort
RUN pip install pylint pytest astroid isort LinkChecker
COPY ./python/requirements.txt /root/
RUN pip3 install -r /root/requirements.txt
RUN pip install -r /root/requirements.txt
# To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
# the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2
RUN apt-get install -y libssl-dev libffi-dev
RUN pip3 install certifi urllib3[secure]
RUN pip install certifi urllib3[secure]
......
......@@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.0/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
......@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0)
### Latest PaddlePaddle Release: [Fluid 1.0.1](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0)
### Install Latest Stable Release:
```
# Linux CPU
......@@ -27,9 +27,9 @@ pip install paddlepaddle
# Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==0.15.0.post87
pip install paddlepaddle-gpu==1.0.1.post87
# Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==0.15.0.post85
pip install paddlepaddle-gpu==1.0.1.post85
# For installation on other platform, refer to http://paddlepaddle.org/
```
......@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.15.0.post85
## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website.
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) on our website.
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation.
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.0.0/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html)
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.0/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html)
- [Python API](http://paddlepaddle.org/documentation/api/zh/1.0/fluid.html)
Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html)
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.0/advanced_usage/development/contribute_to_paddle.html)
We appreciate your contributions!
......
文件模式从 100644 更改为 100755
......@@ -40,7 +40,7 @@ set(OPENBLAS_LIB_SEARCH_PATHS
/usr/local/opt/openblas/lib)
find_path(OPENBLAS_INC_DIR NAMES cblas.h
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS} NO_DEFAULT_PATH)
find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h
PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
find_library(OPENBLAS_LIB NAMES openblas
......
......@@ -175,7 +175,10 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11")
list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
endif(NOT WIN32)
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
if(WITH_FAST_MATH)
# Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
list(APPEND CUDA_NVCC_FLAGS "--use_fast_math")
endif()
# in cuda9, suppress cuda warning on eigen
list(APPEND CUDA_NVCC_FLAGS "-w")
# Set :expt-relaxed-constexpr to suppress Eigen warnings
......
......@@ -3,6 +3,14 @@ INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
if(NOT WITH_FAST_MATH)
# EIGEN_FAST_MATH: https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html
# enables some optimizations which might affect the accuracy of the result.
# This currently enables the SSE vectorization of sin() and cos(),
# and speedups sqrt() for single precision.
# Defined to 1 by default. Define it to 0 to disable.
add_definitions(-DEIGEN_FAST_MATH=0)
endif()
if(WITH_AMD_GPU)
ExternalProject_Add(
......
......@@ -27,7 +27,7 @@ IF(NOT ${CBLAS_FOUND})
SET(CBLAS_SOURCES_DIR ${THIRD_PARTY_PATH}/openblas)
SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_INCLUDE_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE)
SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE)
SET(CBLAS_LIBRARIES
"${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}"
......@@ -96,7 +96,7 @@ IF(NOT ${CBLAS_FOUND})
ENDIF(NOT WIN32)
SET(CBLAS_PROVIDER openblas)
IF(WITH_C_API)
INSTALL(DIRECTORY ${CBLAS_INCLUDE_DIR} DESTINATION third_party/openblas)
INSTALL(DIRECTORY ${CBLAS_INC_DIR} DESTINATION third_party/openblas)
# Because libopenblas.a is a symbolic link of another library, thus need to
# install the whole directory.
IF(ANDROID)
......@@ -117,8 +117,8 @@ IF(NOT ${CBLAS_FOUND})
ENDIF(NOT ${CBLAS_FOUND})
MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}")
MESSAGE(STATUS "BLAS Include: ${CBLAS_INCLUDE_DIR}")
INCLUDE_DIRECTORIES(${CBLAS_INCLUDE_DIR})
MESSAGE(STATUS "BLAS Include: ${CBLAS_INC_DIR}")
INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
# FIXME(gangliao): generate cblas target to track all high performance
# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
......
......@@ -157,6 +157,8 @@ if (APPLE)
# On Mac OS X build fat binaries with x86_64 architectures by default.
set (CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architectures for OSX" FORCE)
endif()
# On Mac OS X register class specifier is deprecated and will cause warning error on latest clang 10.0
set (COMMON_FLAGS -Wno-deprecated-register)
endif(APPLE)
if(LINUX)
......
......@@ -311,6 +311,8 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif()
endfunction(cc_test)
......@@ -629,6 +631,8 @@ function(py_test TARGET_NAME)
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif()
endfunction()
......
......@@ -18,7 +18,7 @@ function(copy TARGET)
set(oneValueArgs "")
set(multiValueArgs SRCS DSTS DEPS)
cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(inference_lib_dist_dep ${TARGET} ${inference_lib_dist_dep} PARENT_SCOPE)
set(fluid_lib_dist_dep ${TARGET} ${fluid_lib_dist_dep} PARENT_SCOPE)
list(LENGTH copy_lib_SRCS copy_lib_SRCS_len)
list(LENGTH copy_lib_DSTS copy_lib_DSTS_len)
......@@ -150,16 +150,16 @@ if (WITH_ANAKIN AND WITH_MKL)
SRCS
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/libinference_anakin_api* # compiled anakin api
${ANAKIN_INSTALL_DIR} # anakin release
DSTS ${dst_dir}/inference/anakin ${FLUID_INSTALL_DIR}/third_party/install/anakin)
DSTS ${FLUID_INSTALL_DIR}/third_party/install/anakin ${FLUID_INSTALL_DIR}/third_party/install/anakin)
list(APPEND inference_deps anakin_inference_lib)
endif()
set(module "inference")
copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${src_dir}/${module}/api/paddle_inference_api.h ${src_dir}/${module}/api/demo_ci
${src_dir}/${module}/api/paddle_inference_api.h
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
)
set(module "platform")
......@@ -185,20 +185,41 @@ copy(cmake_cache
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
DSTS ${FLUID_INSTALL_DIR})
add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep})
# This command generates a complete fluid library for both train and inference
add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep})
# Following commands generate a inference-only fluid library
# third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR}
copy(third_party DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/third_party ${FLUID_INSTALL_DIR}/CMakeCache.txt
DSTS ${FLUID_INFERENCE_INSTALL_DIR} ${FLUID_INFERENCE_INSTALL_DIR}
)
# only need libpaddle_fluid.so/a and paddle_inference_api.h for inference-only library
copy(inference_api_lib DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_inference_api.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
)
add_custom_target(inference_lib_dist DEPENDS third_party inference_api_lib)
# paddle fluid version
execute_process(
COMMAND ${GIT_EXECUTABLE} log --pretty=format:%H -1
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_GIT_COMMIT)
set(version_file ${FLUID_INSTALL_DIR}/version.txt)
file(WRITE ${version_file}
"GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n"
"WITH_MKL: ${WITH_MKL}\n"
"WITH_GPU: ${WITH_GPU}\n")
if(WITH_GPU)
file(APPEND ${version_file}
"CUDA version: ${CUDA_VERSION}\n"
"CUDNN version: v${CUDNN_MAJOR_VERSION}\n")
endif()
function(version version_file)
execute_process(
COMMAND ${GIT_EXECUTABLE} log --pretty=format:%H -1
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_GIT_COMMIT)
file(WRITE ${version_file}
"GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n"
"WITH_MKL: ${WITH_MKL}\n"
"WITH_MKLDNN: ${WITH_MKLDNN}\n"
"WITH_GPU: ${WITH_GPU}\n")
if(WITH_GPU)
file(APPEND ${version_file}
"CUDA version: ${CUDA_VERSION}\n"
"CUDNN version: v${CUDNN_MAJOR_VERSION}\n")
endif()
endfunction()
version(${FLUID_INSTALL_DIR}/version.txt)
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
此差异已折叠。
......@@ -12,6 +12,5 @@ endif(NOT WIN32)
if(WITH_INFERENCE)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
add_subdirectory(train)
endif()
add_subdirectory(train)
# windows treat symbolic file as a real file, which is different with unix
# We create a hidden file and compile it instead of origin source file.
function(windows_symbolic TARGET)
......@@ -9,11 +10,23 @@ function(windows_symbolic TARGET)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu)
message(FATAL " ${src}.cc and ${src}.cu must exsits, and ${src}.cu must be symbolic file.")
endif()
add_custom_command(OUTPUT .${src}.cu
# only copy the xx.cu to .xx.cu when the content are modified
set(copy_flag 1)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc SOURCE_STR)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu TARGET_STR)
if (SOURCE_STR STREQUAL TARGET_STR)
set(copy_flag 0)
endif()
endif()
if (copy_flag)
add_custom_command(OUTPUT .${src}.cu
COMMAND ${CMAKE_COMMAND} -E remove ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu
COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc" "${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu"
COMMENT "create hidden file of ${src}.cu")
add_custom_target(${TARGET} ALL DEPENDS .${src}.cu)
endif(copy_flag)
add_custom_target(${TARGET} ALL DEPENDS .${src}.cu)
endforeach()
endfunction()
......@@ -81,6 +94,8 @@ nv_test(data_device_transform_test SRCS data_device_transform_test.cu
if(WITH_GPU)
if (WIN32)
# windows treat symbolic file as a real file, which is different with unix
# We create a hidden file and compile it instead of origin source file.
windows_symbolic(hidden_file SRCS data_type_transform.cu)
nv_library(data_type_transform SRCS .data_type_transform.cu DEPS tensor)
add_dependencies(data_type_transform hidden_file)
......@@ -149,7 +164,7 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass elementwise_add_op)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
if (NOT WIN32)
......@@ -169,15 +184,8 @@ cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
# cc_test(channel_test SRCS channel_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
if (NOT WIN32)
cc_test(rw_lock_test SRCS rw_lock_test.cc)
endif (NOT WIN32)
# disable test temporarily.
# TODO https://github.com/PaddlePaddle/Paddle/issues/11971
# cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
# channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
# conditional_block_op while_op assign_op print_op executor proto_desc)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <condition_variable> // NOLINT
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
enum class ChannelAction {
SEND = 0,
RECEIVE = 1,
CLOSE = 2,
};
// Channel is the abstract class of buffered and un-buffered channels.
template <typename T>
class Channel {
public:
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual void Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
virtual bool IsClosed() = 0;
virtual void Close() = 0;
virtual ~Channel() {}
virtual void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) = 0;
virtual void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) = 0;
virtual void RemoveFromSendQ(const void* referrer) = 0;
virtual void RemoveFromReceiveQ(const void* referrer) = 0;
};
// Forward declaration of channel implementations.
template <typename T>
class ChannelImpl;
template <typename T>
Channel<T>* MakeChannel(size_t buffer_size) {
return new ChannelImpl<T>(buffer_size);
}
template <typename T>
void CloseChannel(Channel<T>* ch) {
ch->Close();
}
/*
* The ChannelHolder class serves two main purposes:
* 1. It acts as a unified wrapper for the different kinds of
* channels, i.e. Buffered and Unbuffered channels. This is
* similar to the ReaderHolder class.
* 2. It also helps us in TypeHiding. This is similar to the
* PlaceHolder implementations in variable.h and tensor.h.
*/
class ChannelHolder {
public:
template <typename T>
void Reset(size_t buffer_size) {
holder_.reset(new PlaceholderImpl<T>(buffer_size));
}
template <typename T>
void Send(T* data) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
// Static cast should be safe because we have ensured that types are same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
channel->Send(data);
}
template <typename T>
bool Receive(T* data) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
return channel->Receive(data);
}
bool IsClosed() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->IsClosed();
}
bool CanSend() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->CanSend();
}
bool CanReceive() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->CanReceive();
}
void close() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Close();
}
size_t Cap() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Cap();
}
void Lock() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Lock();
}
void Unlock() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Unlock();
}
template <typename T>
void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
}
template <typename T>
void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
}
void RemoveFromSendQ(const void* referrer) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromSendQ(referrer);
}
void RemoveFromReceiveQ(const void* referrer) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromReceiveQ(referrer);
}
inline bool IsInitialized() const { return holder_ != nullptr; }
inline const std::type_index Type() {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Type();
}
private:
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of ChannelHolder.
*/
struct Placeholder {
virtual ~Placeholder() {}
virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0;
virtual bool IsClosed() = 0;
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual void RemoveFromSendQ(const void* referrer) = 0;
virtual void RemoveFromReceiveQ(const void* referrer) = 0;
virtual void Close() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
virtual size_t Cap() = 0;
};
template <typename T>
struct PlaceholderImpl : public Placeholder {
explicit PlaceholderImpl(size_t buffer_size)
: type_(std::type_index(typeid(T))) {
channel_.reset(MakeChannel<T>(buffer_size));
}
virtual const std::type_index Type() const { return type_; }
virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
virtual bool IsClosed() {
if (channel_) {
return channel_->IsClosed();
}
return false;
}
virtual bool CanSend() {
if (channel_) {
return channel_->CanSend();
}
return false;
}
virtual bool CanReceive() {
if (channel_) {
return channel_->CanReceive();
}
return false;
}
virtual void RemoveFromSendQ(const void* referrer) {
if (channel_) {
channel_->RemoveFromSendQ(referrer);
}
}
virtual void RemoveFromReceiveQ(const void* referrer) {
if (channel_) {
channel_->RemoveFromReceiveQ(referrer);
}
}
virtual void Close() {
if (channel_) channel_->Close();
}
virtual size_t Cap() {
if (channel_)
return channel_->Cap();
else
return -1;
}
virtual void Lock() {
if (channel_) channel_->Lock();
}
virtual void Unlock() {
if (channel_) channel_->Unlock();
}
std::unique_ptr<Channel<T>> channel_;
const std::type_index type_;
};
// Pointer to a PlaceholderImpl object
std::unique_ptr<Placeholder> holder_;
};
} // namespace framework
} // namespace paddle
#include "paddle/fluid/framework/channel_impl.h"
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
#include <atomic>
#include <condition_variable> // NOLINT
#include <deque>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
template <typename T>
class ChannelImpl : public paddle::framework::Channel<T> {
friend Channel<T> *paddle::framework::MakeChannel<T>(size_t);
friend void paddle::framework::CloseChannel<T>(Channel<T> *);
public:
virtual bool CanSend();
virtual bool CanReceive();
virtual void Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
virtual void Unlock();
virtual bool IsClosed();
virtual void Close();
explicit ChannelImpl(size_t);
virtual ~ChannelImpl();
virtual void AddToSendQ(const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb);
virtual void AddToReceiveQ(const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb);
virtual void RemoveFromSendQ(const void *referrer);
virtual void RemoveFromReceiveQ(const void *referrer);
private:
struct QueueMessage {
T *data;
std::shared_ptr<std::condition_variable_any> cond;
bool chan_closed = false;
bool completed = false;
const void *referrer; // TODO(thuan): figure out better way to do this
std::function<bool(ChannelAction)> callback;
explicit QueueMessage(T *item)
: data(item), cond(std::make_shared<std::condition_variable_any>()) {}
QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
: data(item), cond(cond) {}
void Wait(std::unique_lock<std::recursive_mutex> &lock) {
cond->wait(lock, [this]() { return completed; });
}
void Notify() {
completed = true;
cond->notify_all();
}
};
void send_return() {
send_ctr--;
destructor_cond_.notify_all();
}
bool recv_return(bool value) {
recv_ctr--;
destructor_cond_.notify_all();
return value;
}
std::shared_ptr<QueueMessage> get_first_message(
std::deque<std::shared_ptr<QueueMessage>> *queue, ChannelAction action) {
while (!queue->empty()) {
// Check whether this message was added by Select
// If this was added by Select then execute the callback
// to check if you can execute this message. The callback
// can return false if some other case was executed in Select.
// In that case just discard this QueueMessage and process next.
std::shared_ptr<QueueMessage> m = queue->front();
queue->pop_front();
if (m->callback == nullptr || m->callback(action)) return m;
}
return nullptr;
}
size_t cap_;
std::recursive_mutex mu_;
bool closed_;
std::deque<T> buf_;
std::deque<std::shared_ptr<QueueMessage>> recvq;
std::deque<std::shared_ptr<QueueMessage>> sendq;
std::atomic<unsigned> send_ctr{0};
std::atomic<unsigned> recv_ctr{0};
std::condition_variable_any destructor_cond_;
};
template <typename T>
ChannelImpl<T>::ChannelImpl(size_t capacity)
: cap_(capacity), closed_(false), send_ctr(0), recv_ctr(0) {
PADDLE_ENFORCE_GE(capacity, 0);
}
template <typename T>
bool ChannelImpl<T>::CanSend() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return !closed_ && (!recvq.empty() || buf_.size() < cap_);
}
template <typename T>
bool ChannelImpl<T>::CanReceive() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return !(closed_ && buf_.empty()) && (!sendq.empty() || buf_.size() > 0);
}
template <typename T>
void ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed, throw exception
if (closed_) {
send_return();
lock.unlock();
PADDLE_THROW("Cannot send on closed channel");
}
// If there is a receiver, directly pass the value we want
// to send to the receiver, bypassing the channel buffer if any
if (!recvq.empty()) {
std::shared_ptr<QueueMessage> m =
get_first_message(&recvq, ChannelAction::SEND);
if (m != nullptr) {
*(m->data) = std::move(*item);
m->Notify();
send_return();
return;
} else {
Send(item);
send_return();
return;
}
}
// Unbuffered channel will always bypass this
// If buffered channel has space in buffer,
// write the element to the buffer.
if (buf_.size() < cap_) {
// Copy to buffer
buf_.push_back(std::move(*item));
send_return();
return;
}
// Block on channel, because some receiver will complete
// the operation for us
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
if (m->chan_closed) {
send_return();
lock.unlock();
PADDLE_THROW("Cannot send on closed channel");
}
send_return();
}
template <typename T>
bool ChannelImpl<T>::Receive(T *item) {
recv_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
// If channel is closed and buffer is empty or
// channel is unbuffered
if (closed_ && buf_.empty()) return recv_return(false);
// If there is a sender, directly receive the value we want
// from the sender. In case of a buffered channel, read from
// buffer and move front of send queue to the buffer
if (!sendq.empty()) {
std::shared_ptr<QueueMessage> m =
get_first_message(&sendq, ChannelAction::RECEIVE);
if (buf_.size() > 0) {
// Case 1 : Channel is Buffered
// Do Data transfer from front of buffer
// and add a QueueMessage to the buffer
*item = std::move(buf_.front());
buf_.pop_front();
// If first message from sendq is not null
// add it to the buffer and notify it
if (m != nullptr) {
// Copy to buffer
buf_.push_back(std::move(*(m->data)));
m->Notify();
} // Ignore if there is no first message
} else {
// Case 2: Channel is Unbuffered
// Do data transfer from front of SendQ
// If front is nullptr, then recursively call itself
if (m != nullptr) {
*item = std::move(*(m->data));
m->Notify();
} else {
return recv_return(Receive(item));
}
}
return recv_return(true);
}
// If this is a buffered channel and there are items in buffer
if (buf_.size() > 0) {
// Directly read from buffer
*item = std::move(buf_.front());
buf_.pop_front();
// return true
return recv_return(true);
}
// No sender available, block on this channel
// Some receiver will complete the option for us
auto m = std::make_shared<QueueMessage>(item);
recvq.push_back(m);
m->Wait(lock);
return recv_return(!m->chan_closed);
}
template <typename T>
void ChannelImpl<T>::Lock() {
mu_.lock();
}
template <typename T>
void ChannelImpl<T>::Unlock() {
mu_.unlock();
}
template <typename T>
bool ChannelImpl<T>::IsClosed() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return closed_;
}
template <typename T>
void ChannelImpl<T>::Close() {
std::unique_lock<std::recursive_mutex> lock{mu_};
if (closed_) {
// TODO(abhinavarora): closing an already closed channel should panic
lock.unlock();
return;
}
closed_ = true;
// Empty the readers
while (!recvq.empty()) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
m->chan_closed = true;
// Execute callback function (if any)
if (m->callback != nullptr) {
m->callback(ChannelAction::CLOSE);
}
m->Notify();
}
// Empty the senders
while (!sendq.empty()) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
m->chan_closed = true;
// Execute callback function (if any)
if (m->callback != nullptr) {
m->callback(ChannelAction::CLOSE);
}
m->Notify();
}
}
template <typename T>
void ChannelImpl<T>::AddToSendQ(
const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
std::lock_guard<std::recursive_mutex> lock{mu_};
auto m = std::make_shared<QueueMessage>(data, cond);
m->referrer = referrer;
m->callback = cb;
sendq.push_back(m);
}
template <typename T>
void ChannelImpl<T>::AddToReceiveQ(
const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
std::lock_guard<std::recursive_mutex> lock{mu_};
auto m = std::make_shared<QueueMessage>(data, cond);
m->referrer = referrer;
m->callback = cb;
recvq.push_back(m);
}
template <typename T>
void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
std::lock_guard<std::recursive_mutex> lock{mu_};
for (auto it = sendq.begin(); it != sendq.end();) {
std::shared_ptr<QueueMessage> sendMsg = (std::shared_ptr<QueueMessage>)*it;
if (sendMsg->referrer == referrer) {
it = sendq.erase(it);
} else {
++it;
}
}
}
template <typename T>
void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
std::lock_guard<std::recursive_mutex> lock{mu_};
for (auto it = recvq.begin(); it != recvq.end();) {
std::shared_ptr<QueueMessage> recvMsg = (std::shared_ptr<QueueMessage>)*it;
if (recvMsg->referrer == referrer) {
it = recvq.erase(it);
} else {
++it;
}
}
}
template <typename T>
ChannelImpl<T>::~ChannelImpl() {
Close();
// The destructor must wait for all readers and writers to complete their task
// The channel has been closed, so we will not accept new readers and writers
std::unique_lock<std::recursive_mutex> lock{mu_};
destructor_cond_.wait(lock,
[this]() { return send_ctr == 0 && recv_ctr == 0; });
}
} // namespace framework
} // namespace paddle
此差异已折叠。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
USE_NO_KERNEL_OP(go);
USE_NO_KERNEL_OP(channel_close);
USE_NO_KERNEL_OP(channel_create);
USE_NO_KERNEL_OP(channel_recv);
USE_NO_KERNEL_OP(channel_send);
USE_NO_KERNEL_OP(elementwise_add);
USE_NO_KERNEL_OP(select);
USE_NO_KERNEL_OP(conditional_block);
USE_NO_KERNEL_OP(equal);
USE_NO_KERNEL_OP(assign);
USE_NO_KERNEL_OP(while);
USE_NO_KERNEL_OP(print);
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace paddle {
namespace framework {
template <typename T>
LoDTensor *CreateVariable(Scope *scope, const p::CPUPlace &place,
std::string name, T value) {
// Create LoDTensor<int> of dim [1]
auto var = scope->Var(name);
auto tensor = var->GetMutable<LoDTensor>();
tensor->Resize({1});
T *expect = tensor->mutable_data<T>(place);
expect[0] = value;
return tensor;
}
void AddOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, AttributeMap attrs,
BlockDesc *block) {
// insert op
auto op = block->AppendOp();
op->SetType(type);
for (auto &kv : inputs) {
op->SetInput(kv.first, kv.second);
}
for (auto &kv : outputs) {
op->SetOutput(kv.first, kv.second);
}
op->SetAttrMap(attrs);
}
void AddCase(ProgramDesc *program, Scope *scope, p::CPUPlace *place,
BlockDesc *casesBlock, int caseId, int caseType,
std::string caseChannel, std::string caseVarName,
std::function<void(BlockDesc *, Scope *)> func) {
std::string caseCondName = std::string("caseCond") + std::to_string(caseId);
std::string caseCondXVarName =
std::string("caseCondX") + std::to_string(caseId);
BlockDesc *caseBlock = program->AppendBlock(*casesBlock);
func(caseBlock, scope);
CreateVariable(scope, *place, caseCondName, false);
CreateVariable(scope, *place, caseCondXVarName, caseId);
CreateVariable(scope, *place, caseVarName, caseId);
scope->Var("step_scope");
AddOp("equal", {{"X", {caseCondXVarName}}, {"Y", {"caseToExecute"}}},
{{"Out", {caseCondName}}}, {}, casesBlock);
AddOp("conditional_block", {{"X", {caseCondName}}, {"Params", {}}},
{{"Out", {}}, {"Scope", {"step_scope"}}},
{{"sub_block", caseBlock}, {"is_scalar_condition", true}}, casesBlock);
}
void AddFibonacciSelect(Scope *scope, p::CPUPlace *place, ProgramDesc *program,
BlockDesc *parentBlock, std::string dataChanName,
std::string quitChanName) {
BlockDesc *whileBlock = program->AppendBlock(*parentBlock);
CreateVariable(scope, *place, "whileExitCond", true);
CreateVariable(scope, *place, "caseToExecute", -1);
CreateVariable(scope, *place, "case1var", 0);
CreateVariable(scope, *place, "xtemp", 0);
// TODO(thuan): Need to create fibXToSend, since channel send moves the actual
// data,
// which causes the data to be no longer accessible to do the fib calculation
// TODO(abhinav): Change channel send to do a copy instead of a move!
CreateVariable(scope, *place, "fibXToSend", 0);
CreateVariable(scope, *place, "fibX", 0);
CreateVariable(scope, *place, "fibY", 1);
CreateVariable(scope, *place, "quitVar", 0);
BlockDesc *casesBlock = program->AppendBlock(*whileBlock);
std::function<void(BlockDesc * caseBlock)> f = [](BlockDesc *caseBlock) {};
// TODO(thuan): Remove this once we change channel send to do a copy instead
// of move
AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"fibXToSend"}}}, {}, whileBlock);
// Case 0: Send to dataChanName
std::function<void(BlockDesc * caseBlock, Scope * scope)> case0Func = [&](
BlockDesc *caseBlock, Scope *scope) {
AddOp("assign", {{"X", {"fibX"}}}, {{"Out", {"xtemp"}}}, {}, caseBlock);
AddOp("assign", {{"X", {"fibY"}}}, {{"Out", {"fibX"}}}, {}, caseBlock);
AddOp("elementwise_add", {{"X", {"xtemp"}}, {"Y", {"fibY"}}},
{{"Out", {"fibY"}}}, {}, caseBlock);
};
AddCase(program, scope, place, casesBlock, 0, 1, dataChanName, "fibXToSend",
case0Func);
std::string case0Config =
std::string("0,1,") + dataChanName + std::string(",fibXToSend");
// Case 1: Receive from quitChanName
std::function<void(BlockDesc * caseBlock, Scope * scope)> case2Func = [&](
BlockDesc *caseBlock, Scope *scope) {
// Exit the while loop after we receive from quit channel.
// We assign a false to "whileExitCond" variable, which will
// break out of while_op loop
CreateVariable(scope, *place, "whileFalse", false);
AddOp("assign", {{"X", {"whileFalse"}}}, {{"Out", {"whileExitCond"}}}, {},
caseBlock);
};
AddCase(program, scope, place, casesBlock, 1, 2, quitChanName, "quitVar",
case2Func);
std::string case1Config =
std::string("1,2,") + quitChanName + std::string(",quitVar");
// Select block
AddOp("select", {{"X", {dataChanName, quitChanName}},
{"case_to_execute", {"caseToExecute"}}},
{{"Out", {}}},
{{"sub_block", casesBlock},
{"cases", std::vector<std::string>{case0Config, case1Config}}},
whileBlock);
scope->Var("stepScopes");
AddOp("while",
{{"X", {dataChanName, quitChanName}}, {"Condition", {"whileExitCond"}}},
{{"Out", {}}, {"StepScopes", {"stepScopes"}}},
{{"sub_block", whileBlock}}, parentBlock);
}
TEST(Concurrency, Go_Op) {
Scope scope;
p::CPUPlace place;
// Initialize scope variables
p::CPUDeviceContext ctx(place);
// Create channel variable
scope.Var("Channel");
// Create Variables, x0 will be put into channel,
// result will be pulled from channel
CreateVariable(&scope, place, "Status", false);
CreateVariable(&scope, place, "x0", 99);
CreateVariable(&scope, place, "result", 0);
framework::Executor executor(place);
ProgramDesc program;
BlockDesc *block = program.MutableBlock(0);
// Create channel OP
AddOp("channel_create", {}, {{"Out", {"Channel"}}},
{{"capacity", 10}, {"data_type", f::proto::VarType::LOD_TENSOR}},
block);
// Create Go Op routine
BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
AddOp("channel_send", {{"Channel", {"Channel"}}, {"X", {"x0"}}},
{{"Status", {"Status"}}}, {}, goOpBlock);
// Create Go Op
AddOp("go", {{"X", {"Channel", "x0"}}}, {}, {{"sub_block", goOpBlock}},
block);
// Create Channel Receive Op
AddOp("channel_recv", {{"Channel", {"Channel"}}},
{{"Status", {"Status"}}, {"Out", {"result"}}}, {}, block);
// Create Channel Close Op
AddOp("channel_close", {{"Channel", {"Channel"}}}, {}, {}, block);
// Check the result tensor to make sure it is set to 0
const LoDTensor &tensor = (scope.FindVar("result"))->Get<LoDTensor>();
auto *initialData = tensor.data<int>();
EXPECT_EQ(initialData[0], 0);
executor.Run(program, &scope, 0, true, true);
// After we call executor.run, the Go operator should do a channel_send to
// set the "result" variable to 99.
auto *finalData = tensor.data<int>();
EXPECT_EQ(finalData[0], 99);
}
/**
* This test implements the fibonacci function using go_op and select_op
*/
TEST(Concurrency, Select) {
Scope scope;
p::CPUPlace place;
// Initialize scope variables
p::CPUDeviceContext ctx(place);
CreateVariable(&scope, place, "Status", false);
CreateVariable(&scope, place, "result", 0);
CreateVariable(&scope, place, "currentXFib", 0);
framework::Executor executor(place);
ProgramDesc program;
BlockDesc *block = program.MutableBlock(0);
// Create channel OP
std::string dataChanName = "Channel";
scope.Var(dataChanName);
AddOp("channel_create", {}, {{"Out", {dataChanName}}},
{{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
std::string quitChanName = "Quit";
scope.Var(quitChanName);
AddOp("channel_create", {}, {{"Out", {quitChanName}}},
{{"capacity", 0}, {"data_type", f::proto::VarType::LOD_TENSOR}}, block);
// Create Go Op routine, which loops 10 times over fibonacci sequence
CreateVariable(&scope, place, "xReceiveVar", 0);
BlockDesc *goOpBlock = program.AppendBlock(program.Block(0));
for (int i = 0; i < 10; ++i) {
AddOp("channel_recv", {{"Channel", {dataChanName}}},
{{"Status", {"Status"}}, {"Out", {"currentXFib"}}}, {}, goOpBlock);
AddOp("print", {{"In", {"currentXFib"}}}, {{"Out", {"currentXFib"}}},
{{"first_n", 100},
{"summarize", -1},
{"print_tensor_name", false},
{"print_tensor_type", true},
{"print_tensor_shape", false},
{"print_tensor_lod", false},
{"print_phase", std::string("FORWARD")},
{"message", std::string("X: ")}},
goOpBlock);
}
CreateVariable(&scope, place, "quitSignal", 0);
AddOp("channel_send", {{"Channel", {quitChanName}}, {"X", {"quitSignal"}}},
{{"Status", {"Status"}}}, {}, goOpBlock);
// Create Go Op
AddOp("go", {{"X", {dataChanName, quitChanName}}}, {},
{{"sub_block", goOpBlock}}, block);
AddFibonacciSelect(&scope, &place, &program, block, dataChanName,
quitChanName);
// Create Channel Close Op
AddOp("channel_close", {{"Channel", {dataChanName}}}, {}, {}, block);
AddOp("channel_close", {{"Channel", {quitChanName}}}, {}, {}, block);
executor.Run(program, &scope, 0, true, true);
// After we call executor.run, "result" variable should be equal to 34
// (which is 10 loops through fibonacci sequence)
const LoDTensor &tensor = (scope.FindVar("currentXFib"))->Get<LoDTensor>();
auto *finalData = tensor.data<int>();
EXPECT_EQ(finalData[0], 34);
}
} // namespace framework
} // namespace paddle
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <typeindex>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......
......@@ -64,7 +64,8 @@ class OpHandleBase {
virtual bool IsMultiDeviceTransfer() { return false; }
const platform::DeviceContext *DeviceContext(platform::Place place) {
return dev_ctxes_[place];
auto it = dev_ctxes_.find(place);
return it != dev_ctxes_.end() ? it->second : nullptr;
}
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
......
......@@ -80,15 +80,15 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// This is weird but there is really some variables without var_desc
// in computation_op
if (var_desc == nullptr) {
if (compute_op->Node()->Op()->Block()->FindVar(var_name) == nullptr)
continue;
} else {
if (var_desc->Persistable()) continue;
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS) {
continue;
}
var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name);
if (var_desc == nullptr) continue;
}
if (var_desc->Persistable()) continue;
auto var_type = var_desc->Proto()->type().type();
if (var_type != proto::VarType::LOD_TENSOR &&
var_type != proto::VarType::SELECTED_ROWS) {
continue;
}
// compute op only runs in one device
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
......@@ -47,6 +46,41 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
}
template <typename RefCntMap>
static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
GarbageCollector<Tensor>* gc,
RefCntMap* ref_cnts) {
std::unordered_set<Tensor*> erase_tensors;
auto handler = [&](const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue;
if ((it->second)-- == 1) {
auto* var = scope.FindVar(name);
if (var != nullptr) {
VLOG(10) << "Erase tensor \'" << name << "\'";
if (var->IsType<LoDTensor>()) {
erase_tensors.insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
erase_tensors.insert(
var->GetMutable<SelectedRows>()->mutable_value());
}
}
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
if (!erase_tensors.empty()) {
gc->Add(erase_tensors);
}
}
Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() {
......@@ -67,7 +101,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
var->GetMutable<std::vector<framework::Scope*>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
......@@ -76,15 +110,13 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::CHANNEL) {
var->GetMutable<ChannelHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
"LOD_RANK_TABLE, PLACE_LIST, READER, RAW]",
var_type);
}
}
......@@ -334,9 +366,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) {
// WhileOp would set keep_kids to false
// WhileGradOp would need the scopes created in WhileOp
// Perhaps, we should not perform eager deletion in WhileOp
// The scopes and variables created by WhileOp would be deleted
// in WhileGradOp.
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) {
......@@ -355,45 +391,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
op->Run(*local_scope, place_);
if (gc != nullptr) {
std::vector<std::string> erase_vars;
for (auto& input : op->Inputs()) {
for (auto& input_name : input.second) {
auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->cur_ref_cnts_.erase(input_name);
} else {
--(it->second);
}
}
}
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->cur_ref_cnts_.erase(output_name);
} else {
--(it->second);
}
}
}
if (!erase_vars.empty()) {
std::vector<framework::LoDTensor*> erase_tensors;
for (auto& name : erase_vars) {
auto* var = local_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
erase_tensors.push_back(tensor);
}
}
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
&(ctx->cur_ref_cnts_));
}
if (FLAGS_benchmark) {
......
......@@ -32,38 +32,32 @@ template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id);
std::unordered_set<std::string> ignored_vars;
std::unordered_map<std::string, T> ref_cnts;
for (auto var_desc : block.AllVars()) {
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) {
ignored_vars.insert(var_desc->Name()); // ignore persistable vars
}
}
for (auto op_desc : block.AllOps()) {
for (auto& input : op_desc->Inputs()) {
for (auto& input_name : input.second) {
if (!ignored_vars.count(input_name)) {
if (ref_cnts.count(input_name))
++ref_cnts[input_name];
else
ref_cnts[input_name] = 1;
auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto* var_desc = block.FindVar(name);
if (var_desc == nullptr || var_desc->Persistable()) continue;
auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
type != proto::VarType::SELECTED_ROWS) {
continue;
}
}
}
for (auto& output : op_desc->Outputs()) {
for (auto output_name : output.second) {
if (!ignored_vars.count(output_name)) {
if (ref_cnts.count(output_name))
++ref_cnts[output_name];
else
ref_cnts[output_name] = 1;
auto it = ref_cnts.find(name);
if (it != ref_cnts.end()) {
++it->second;
} else {
ref_cnts[name] = 1;
}
}
}
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
}
return ref_cnts;
}
......
......@@ -27,8 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
// be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
if (index >= feed_inputs.size()) {
feed_inputs.resize(index + 1);
}
......
......@@ -126,7 +126,6 @@ message VarType {
LOD_TENSOR_ARRAY = 13;
PLACE_LIST = 14;
READER = 15;
CHANNEL = 16;
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
......@@ -158,12 +157,6 @@ message VarType {
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
message ChannelDesc {
required Type data_type = 1;
required int64 capacity = 2;
}
optional ChannelDesc channel = 6;
message Tuple { repeated Type element_type = 1; }
optional Tuple tuple = 7;
}
......
set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
......@@ -9,7 +10,7 @@ function(pass_library TARGET DEST)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${op_library_DEPS})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
message(STATUS "add pass ${TARGET} ${DEST}")
......@@ -24,18 +25,23 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(fc_fuse_pass inference)
if (WITH_MKLDNN)
pass_library(conv_relu_mkldnn_fuse_pass inference)
endif ()
pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(conv_bn_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference)
pass_library(conv_relu_mkldnn_fuse_pass inference)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
......
......@@ -262,7 +262,7 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unordered_set<std::string> specified_vars({"data_lod_attention",
"cell_init", "hidden_init",
"data", "week", "minute"});
int count = 0;
size_t count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsVar() && specified_vars.count(node->Name())) {
++count;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
LoDTensor vec_y;
vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>();
const float* b = vec_b.data<float>();
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < vec_a.numel(); i++) {
y[i] = f(a[i], b[i]);
}
return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input);
int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBias fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_bias_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern); // CONV op
// bias
GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
// output
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
// elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
PADDLE_ENFORCE(subgraph.count(conv_input));
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
VLOG(3) << "do not perform conv+bias fuse";
return;
}
auto* eltwise_bias_tensor =
scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
*conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
// take eltwise bias as conv bias
OpDesc desc;
desc.SetInput(
"Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d");
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
auto conv_bias_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
IR_NODE_LINK_TO(conv_weight, conv_bias_node);
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
}
found_conv_bias_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_bias_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
paddle::framework::ir::ConvBiasFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
*/
class ConvBiasFusePass : public FusePassBase {
public:
virtual ~ConvBiasFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_CONV_BN_NODES(pattern_name) \
/* OPERATORS */ \
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name); \
/* CONV inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \
/* CONV outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \
/* BN inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name); \
/* BN outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */ \
GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
void recompute_bias_and_weights(const Scope* scope,
ir::Node* conv_weight, //
const ir::Node& bn_scale, //
const LoDTensor& bn_bias_tensor, //
const ir::Node& bn_mean, //
const ir::Node& bn_variance, //
LoDTensor* eltwise_y_in_tensor, //
float epsilon) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixArrayMap = Eigen::Map<
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from BN
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
auto* variance_tensor =
scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();
ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
scale_tensor->numel(), 1);
EigenVectorArrayMap variance_array(
variance_tensor->mutable_data<float>(platform::CPUPlace()),
variance_tensor->numel(), 1);
ConstEigenVectorArrayMap mean_array(mean_tensor->data<float>(),
mean_tensor->numel(), 1);
ConstEigenVectorArrayMap bn_bias_array(bn_bias_tensor.data<float>(),
bn_bias_tensor.numel(), 1);
// variance will not be used anymore, so make it std_array and then tmp_array
variance_array += epsilon;
variance_array = variance_array.sqrt();
variance_array = scale_array / variance_array;
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
eltwise_y_in_array =
((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
// Re-compute weight of conv2d from BN
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= variance_array;
}
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
conv_bn_pattern(conv_input, false /*with_eltwise_add*/);
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
// bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
// bn_saved_variance
GET_CONV_BN_NODES(conv_bn_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+bn fuse";
return;
}
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
eltwise_y_in_desc.SetPersistable(true);
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
// Get batch norm bias
auto* bn_bias_tensor =
scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();
// Initialize eltwise_y
eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 0.0f);
// update weights and biases
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon);
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) {
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(),
"Bias") != input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
// reuse existing conv bias node
auto conv_bias_names = conv->Op()->Input("Bias");
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
eltwise_y_in_tensor->dims());
auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
} else {
// add new conv_bias node
conv->Op()->SetInput(
"Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
IR_NODE_LINK_TO(eltwise_y_in_node, conv);
}
conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv, bn_out);
found_conv_bn_count++;
} else { // fuse_option == FUSE_NATIVE
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
IR_NODE_LINK_TO(eltwise_op, bn_out);
found_conv_bn_count++;
}
};
gpd(graph.get(), handler);
AddStatis(found_conv_bn_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
conv_bn_pattern(conv_input, true /*with_eltwise_add*/);
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle ConvBN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
// bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
GET_CONV_BN_NODES(conv_bn_pattern);
// OPERATORS
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
// BIAS inputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
// BIAS outputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);
// Get eltwise_y (conv bias) variable
auto* eltwise_y_in_tensor =
scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();
// Get batch norm bias
auto* bn_bias_tensor =
scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();
// update weights and biases
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon);
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
IR_NODE_LINK_TO(eltwise, bn_out);
found_conv_bn_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_bn_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvEltwiseAddBNFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp.
*/
class ConvBNFusePass : public FusePassBase {
public:
virtual ~ConvBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bn_fuse"};
};
class ConvEltwiseAddBNFusePass : public FusePassBase {
public:
virtual ~ConvEltwiseAddBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -46,6 +46,12 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
FuseOptions fuse_option = FindFuseOption(*conv, *relu);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+relu fuse";
return;
}
// Transform Conv node into ConvReLU node.
OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
......
......@@ -20,17 +20,19 @@ namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", true);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
......@@ -43,7 +45,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") {
......@@ -51,14 +54,24 @@ ProgramDesc BuildProgramDesc() {
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}));
SetOp(&prog, "relu", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}));
// conv+relu, both with MKL-DNN
SetOp(&prog, "conv2d", "conv1",
std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"f"}), true);
SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+relu, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}));
return prog;
}
......@@ -88,10 +101,16 @@ TEST(ConvReLUFusePass, basic) {
auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
// check if only "conv1" convolution is fused
auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv1") {
ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
}
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_relu"));
}
}
}
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace framework {
namespace ir {
static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Build pattern
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
->assert_is_op_input("lookup_table")
->assert_var_not_persistable();
patterns::Embedding embedding_pattern(pattern, name_scope);
// TODO(jczaja): Intermediate can only be for val that are not used anywhere
// but lookup table output may go into other LSTM (for reverse
// direction)
auto* embedding_out = embedding_pattern(x);
patterns::FC fc_pattern(pattern, name_scope);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out);
// Create New OpDesc
auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
Node* input, Node* weight_x, Node* weight_h,
Node* bias, Node* hidden, Node* cell,
Node* xx, Node* fc_bias) {
OpDesc op_desc;
op_desc.SetType("fused_embedding_fc_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN(Ids, input);
SET_IN(WeightH, weight_h);
// Neet to have this passed as We need Wc data for peephole connections
SET_IN(Bias, bias);
#undef SET_IN
// Multiply embeddings with Weights
PADDLE_ENFORCE(scope);
const std::string& embeddings = patterns::UniqueKey("Embeddings");
auto* embeddings_var = scope->Var(embeddings);
PADDLE_ENFORCE(embeddings_var);
auto* embeddings_tensor =
embeddings_var->GetMutable<framework::LoDTensor>();
// Get WeightX size: [single_embedding, fc_size]
// and embedding size: [dict_size, single_embedding]
// and create new size of embeddings eg. [dict_size , hidden_size]
auto* embedding_var = scope->FindVar(W->Name());
PADDLE_ENFORCE(embedding_var);
const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();
const auto& weightx_tensor =
scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
embeddings_tensor->Resize(
{embedding_tensor.dims()[0], weightx_tensor.dims()[1]});
// Multiplie embeddings via WeightsX and add bias
auto embedding_data = embedding_tensor.data<float>();
auto weightx_data = weightx_tensor.data<float>();
auto embeddings_data =
embeddings_tensor->mutable_data<float>(platform::CPUPlace());
// Adding biases to GEMM result to be
auto* lstm_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(lstm_bias_var);
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
auto alpha = 1.0f;
auto beta = 1.0f;
int m = embedding_tensor.dims()[0];
int n = weightx_tensor.dims()[1];
int k = embedding_tensor.dims()[1];
// Copy only gate biases values (only actual bias data, not peephole
// weights)
std::vector<float> combined_biases;
combined_biases.reserve(n);
std::copy_n(lstm_bias_tensor.data<float>(), n,
std::back_inserter(combined_biases));
if (with_fc_bias) {
// Add FC-bias with LSTM-bias (into GEMM result to be)
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
for (int i = 0; i < fc_bias_tensor.numel(); i++) {
combined_biases[i] += fc_bias_tensor.data<float>()[i];
}
}
// broadcast biases
std::vector<float> ones(m, 1.0f);
paddle::operators::math::CBlas<float>::GEMM(
CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
&combined_biases[0], n, 0.0f, embeddings_data, n);
// Wx*embeddings + biases
paddle::operators::math::CBlas<float>::GEMM(
CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
embedding_data, k, weightx_data, n, beta, embeddings_data, n);
op_desc.SetInput("Embeddings", {embeddings});
// Create temp variables.
const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
const std::string BatchedCellPreAct =
patterns::UniqueKey("BatchedCellPreAct");
const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetOutput("Cell", {cell->Name()});
op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0);
OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input, op);
IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden);
return op;
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
// TODO(jczaja): Add support for is_sparse / is_distributed
auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
auto is_distributed =
boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));
if (is_sparse == true || is_distributed == true) {
return;
}
if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
Bias, Hidden, Cell, fc_out, fc_bias);
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of lookup table
std::unordered_set<const Node*> marked_nodes(
//{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
{mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
Bias, Hidden, Cell, fc_out, nullptr);
// Remove unneeded nodes.
// TODO(jczaja): Proper removing of lookup table
// std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
// lstm});
std::unordered_set<const Node*> marked_nodes({mul, lstm});
GraphSafeRemoveNodes(graph, marked_nodes);
}
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(embedding_fc_lstm_fuse_pass,
paddle::framework::ir::EmbeddingFCLSTMFusePass);
......@@ -11,29 +11,30 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono> // NOLINT
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace inference {
namespace framework {
namespace ir {
// Fusing of Embedding , FC and LSTM op
// Timer for timer
class Timer {
// Just FC without bias
class EmbeddingFCLSTMFusePass : public FusePassBase {
public:
std::chrono::high_resolution_clock::time_point start;
std::chrono::high_resolution_clock::time_point startu;
void tic() { start = std::chrono::high_resolution_clock::now(); }
double toc() {
startu = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_span =
std::chrono::duration_cast<std::chrono::duration<double>>(startu -
start);
double used_time_ms = static_cast<double>(time_span.count()) * 1000.0;
return used_time_ms;
}
virtual ~EmbeddingFCLSTMFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"embedding_fc_lstm_fuse"};
};
} // namespace inference
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
void FusePassBase::Init(const std::string& repr, Graph* graph) const {
repr_ = repr;
graph_ = graph;
}
Scope* FusePassBase::param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
void FusePassBase::AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
auto& info =
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
info[repr_] = count_of_fused;
}
FuseOptions FusePassBase::FindFuseOption(const Node& node1,
const Node& node2) const {
#ifdef PADDLE_WITH_MKLDNN
bool node1_mkldnn = node1.Op()->HasAttr("use_mkldnn") &&
boost::get<bool>(node1.Op()->GetAttr("use_mkldnn"));
bool node2_mkldnn = node2.Op()->HasAttr("use_mkldnn") &&
boost::get<bool>(node2.Op()->GetAttr("use_mkldnn"));
if (node1_mkldnn && node2_mkldnn)
return FUSE_MKLDNN;
else if (!node1_mkldnn && !node2_mkldnn)
return FUSE_NATIVE;
else
return DO_NOT_FUSE;
#else
return FUSE_NATIVE;
#endif
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -25,32 +25,24 @@ namespace ir {
static const char kParamScopeAttr[] = "__param_scope__";
static const char kFuseStatisAttr[] = "__fuse_statis__";
enum FuseOptions {
DO_NOT_FUSE, // fusing will not be done
FUSE_NATIVE, // fusing will be done without MKL-DNN
FUSE_MKLDNN // fusing will be done with MKL-DNN
};
class FusePassBase : public Pass {
public:
void Init(const std::string& repr, Graph* graph) const {
repr_ = repr;
graph_ = graph;
}
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
void AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
auto& info =
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
info[repr_] = count_of_fused;
}
void Init(const std::string& repr, Graph* graph) const;
Scope* param_scope() const;
void AddStatis(int count_of_fused) const;
virtual ~FusePassBase() {}
protected:
virtual FuseOptions FindFuseOption(const Node& node1,
const Node& node2) const;
mutable Graph* graph_;
mutable std::string repr_;
};
......
......@@ -259,6 +259,8 @@ GraphPatternDetector::DetectPatterns() {
return result;
}
// TODO(Superjomn) enhance the function as it marks unique unique as duplicates
// see https://github.com/PaddlePaddle/Paddle/issues/13550
void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetector::subgraph_t> *subgraphs) {
if (subgraphs->empty()) return;
......@@ -626,6 +628,112 @@ bool VarLinksFromOp(Node *node, const std::string &op_type) {
return false;
}
PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
bool with_eltwise_add) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
PDNode *eltwise_op = nullptr;
if (with_eltwise_add) {
eltwise_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
}
auto *batch_norm_op =
pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm");
// Create variables
// Conv Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d");
PDNode *eltwise_y_in_var = nullptr;
PDNode *eltwise_out_var = nullptr;
if (with_eltwise_add) {
// Conv output as Bias input
conv_out_var->assert_is_op_input("elementwise_add", "X");
// Bias
eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr())
->assert_is_op_input("elementwise_add", "Y")
->AsInput();
eltwise_out_var = pattern->NewNode(eltwise_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("elementwise_add");
} else {
// Conv output as BN input
conv_out_var->assert_is_op_input("batch_norm", "X");
}
// BN Scale
auto *bn_scale_var = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale");
// BN Bias
auto *bn_bias_var = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias");
// BN Mean
auto *bn_mean_var = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean");
// BN Variance
auto *bn_variance_var = pattern->NewNode(bn_variance_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance");
// BN output
auto *bn_out_var = pattern->NewNode(bn_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm");
auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "MeanOut");
auto *bn_variance_out_var =
pattern->NewNode(bn_variance_out_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "VarianceOut");
auto *bn_saved_mean_var =
pattern->NewNode(bn_saved_mean_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "SavedMean");
auto *bn_saved_variance_var =
pattern->NewNode(bn_saved_variance_repr())
->AsOutput()
->assert_is_op_output("batch_norm", "SavedVariance");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
if (with_eltwise_add) {
eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var})
.LinksTo({eltwise_out_var});
batch_norm_op
->LinksFrom({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
bn_variance_var})
.LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
bn_saved_mean_var, bn_saved_variance_var});
} else {
batch_norm_op
->LinksFrom({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var,
bn_variance_var})
.LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var,
bn_saved_mean_var, bn_saved_variance_var});
}
return bn_out_var;
}
PDNode *patterns::ConvReLU::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
......@@ -692,6 +800,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
}
}
PDNode *patterns::Embedding::operator()(PDNode *x) {
x->assert_is_op_input("lookup_table", "Ids");
auto *lookup_table_op =
pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
#define NEW_NODE(arg__, io__) \
auto *arg__ = pattern->NewNode(arg__##_repr()) \
->assert_is_op_##io__("lookup_table", #arg__);
NEW_NODE(W, input);
NEW_NODE(Out, output);
#undef NEW_NODE
lookup_table_op->LinksFrom({x, W});
lookup_table_op->LinksTo({Out});
return Out;
}
PDNode *patterns::LSTM::operator()(PDNode *x) {
x->assert_is_op_input("lstm", "Input");
auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
......@@ -840,6 +966,39 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return ele_add_grad;
}
PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input) {
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("elementwise_add", "Y");
// output
auto *eltwise_out_var = pattern->NewNode(eltwise_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add");
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
eltiwse_op->LinksFrom({conv_out_var, eltwise_bias_var})
.LinksTo({eltwise_out_var});
return eltwise_out_var;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -375,6 +375,44 @@ struct PatternBase {
size_t id_;
};
// Conv with batch norm
// op: conv + (elementwise_add +) batch_norm
// named nodes:
// conv_weight, conv_out, conv,
// bn_x, bn_scale, bn_bias, bn_mean, bn_variance,
// bn_batch_norm, bn_y, bn_mean_out, bn_variance_out,
// bn_saved_mean, bn_saved_variance
struct ConvBN : public PatternBase {
ConvBN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bn") {}
PDNode* operator()(PDNode* conv_input, bool with_eltwise_add);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(batch_norm);
PATTERN_DECL_NODE(eltwise); // ELEMENTWISE_ADD
// CONV inputs
PATTERN_DECL_NODE(conv_weight); // Filter
// CONV outputs
PATTERN_DECL_NODE(conv_out); // tmp
// ELTWISE inputs
PATTERN_DECL_NODE(eltwise_y_in);
// ELTWISE outputs
PATTERN_DECL_NODE(eltwise_out); // tmp
// BN inputs
PATTERN_DECL_NODE(bn_scale);
PATTERN_DECL_NODE(bn_bias);
PATTERN_DECL_NODE(bn_mean);
PATTERN_DECL_NODE(bn_variance);
// BN outputs
PATTERN_DECL_NODE(bn_out); // Out
PATTERN_DECL_NODE(bn_mean_out);
PATTERN_DECL_NODE(bn_variance_out);
PATTERN_DECL_NODE(bn_saved_mean);
PATTERN_DECL_NODE(bn_saved_variance);
};
// CONV with ReLU
// op: conv + relu
// named nodes:
......@@ -418,6 +456,23 @@ struct FC : public PatternBase {
PATTERN_DECL_NODE(Out);
};
// Embedding
struct Embedding : public PatternBase {
Embedding(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "embedding") {}
PDNode* operator()(PDNode* x);
// declare operator node's name
PATTERN_DECL_NODE(lookup_table);
// Inputs
//
PATTERN_DECL_NODE(Ids);
PATTERN_DECL_NODE(W); // embeddings
// Outputs
PATTERN_DECL_NODE(Out);
};
struct LSTM : public PatternBase {
LSTM(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "lstm") {}
......@@ -523,6 +578,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
PATTERN_DECL_NODE(d_ele_y);
PATTERN_DECL_NODE(ele_y);
};
// Conv with Elementwise_add as bias
// op: conv + elementwise_add
// named nodes:
// conv_input, conv_weight,
// conv_out, conv,
// eltwise_bias, eltwise_out,
// elementwise_add
struct ConvBias : public PatternBase {
ConvBias(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bias") {}
PDNode* operator()(PDNode* conv_input);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(eltwise);
// declare variable node's name
PATTERN_DECL_NODE(conv_weight);
PATTERN_DECL_NODE(conv_out);
PATTERN_DECL_NODE(eltwise_bias);
PATTERN_DECL_NODE(eltwise_out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Aplies MKL-DNN placement strategy.";
for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) {
n->Op()->SetAttr("use_mkldnn", true);
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_placement_pass,
paddle::framework::ir::MKLDNNPlacementPass);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/cpu_lstm_compute.h"
#pragma once
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace operators {
namespace math {} // namespace math
} // namespace operators
namespace framework {
namespace ir {
class MKLDNNPlacementPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -12,11 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/channel.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/string/pretty_log.h"
......@@ -35,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>();
var->GetMutable<std::vector<framework::Scope *>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
......@@ -44,8 +46,6 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::CHANNEL) {
var->GetMutable<ChannelHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
......@@ -146,5 +146,22 @@ void NaiveExecutor::CleanFeedFetchOps() {
ops_.swap(ops);
}
void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) {
#ifdef PADDLE_WITH_MKLDNN
VLOG(3) << "use_mkldnn=True";
for (size_t block_id = 0; block_id < program.Size(); ++block_id) {
auto *block = const_cast<ProgramDesc &>(program).MutableBlock(block_id);
for (auto *op : block->AllOps()) {
if (op->HasAttr("use_mkldnn")) {
op->SetAttr("use_mkldnn", true);
}
}
}
#else
LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
#endif
}
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,8 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
......@@ -46,6 +48,8 @@ class NaiveExecutor {
void CleanFeedFetchOps();
void EnableMKLDNN(const ProgramDesc& program);
protected:
void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);
......
......@@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const std::vector<std::string> &Outputs(
const std::string &name) const override;
void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
const std::string &input_n = Inputs(in)[i];
const std::string &output_n = Outputs(out)[j];
PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
in, i);
PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(input_n);
auto *out_var = block_.FindVarRecursive(output_n);
PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
"The type of %s and %s is not the same.", input_n, output_n);
SetDim(output_n, GetDim(input_n));
}
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
......@@ -64,10 +85,6 @@ class CompileTimeInferShapeContext : public InferShapeContext {
VLOG(3) << "input " << in << " is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
out_var->SetLoDLevel(in_var->GetLoDLevel());
}
......
......@@ -100,16 +100,6 @@ class OpDesc {
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
void SetInputMap(const VariableNameMap &input) {
this->inputs_ = input;
this->need_update_ = true;
}
void SetOutputMap(const VariableNameMap &output) {
this->outputs_ = output;
this->need_update_ = true;
}
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
......
......@@ -132,9 +132,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.")
.SetDefault("");
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
Validate();
}
......
......@@ -46,7 +46,6 @@ class OpProtoAndCheckerMaker {
static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpNamescopeAttrName() { return "op_namescope"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
......@@ -14,17 +14,15 @@ limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL
#include "paddle/fluid/framework/operator.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -142,54 +140,27 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
try {
if (VLOG_IS_ON(4)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
}
if (platform::is_gpu_place(place)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
PADDLE_THROW("Cannot run operator on place %s", place);
#else
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
#endif
}
if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
}
}
// The profile has a process-wide mutex, results in serious performance issue
// in concurrency scenerio. Here use an `if` to fix this issue.
// Please not remove the `if`, ask @Superjomn if there are any concern.
if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
} else {
RunImpl(scope, place);
if (VLOG_IS_ON(3)) {
VLOG(3) << place << " " << DebugStringEx(&scope);
}
} catch (platform::EnforceNotMet exception) {
if (Attrs().count("sub_block") != 0) {
throw exception;
}
auto& callstack = Attr<std::vector<std::string>>(
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (callstack.empty()) {
throw exception;
}
std::ostringstream sout;
sout << "Invoke operator " << Type() << " error.\n";
sout << "Python Callstacks: \n";
for (auto& line : callstack) {
sout << line;
}
sout << "C++ Callstacks: \n";
sout << exception.err_str_;
exception.err_str_ = sout.str();
throw exception;
} catch (...) {
std::rethrow_exception(std::current_exception());
}
VLOG(3) << place << " " << DebugStringEx(&scope);
}
bool OperatorBase::HasInputs(const std::string& name) const {
......@@ -217,7 +188,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
}
bool OperatorBase::HasOutputs(const std::string& name) const {
if (outputs_.end() != outputs_.find(name)) {
if (outputs_.find(name) != outputs_.end()) {
return true;
} else {
return false;
......@@ -579,13 +550,45 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name);
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
const std::string& input_n = Inputs(in)[i];
const std::string& output_n = Outputs(out)[j];
Variable* in_var = scope_.FindVar(input_n);
Variable* out_var = scope_.FindVar(output_n);
PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
"The type of %s and %s is not the same.", output_n,
GetDim(input_n));
if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows.");
}
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
const std::vector<std::string>& inputs = Inputs(in);
const std::vector<std::string>& outputs = Outputs(out);
PADDLE_ENFORCE_LT(i, inputs.size());
PADDLE_ENFORCE_LT(j, outputs.size());
Variable* in_var = scope_.FindVar(inputs.at(i));
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = scope_.FindVar(outputs.at(j));
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
......@@ -613,20 +616,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout());
}
void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_layout(in_tensor.layout());
}
bool IsRuntime() const override { return true; }
protected:
......
......@@ -156,10 +156,12 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_);
#endif
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
if (VLOG_IS_ON(5)) {
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
"The number of graph should be only one");
}
}
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
......@@ -248,6 +250,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
#ifdef PADDLE_WITH_CUDA
if (!gcs_.empty()) {
ResetReferenceCount();
for (auto &pair : cur_ref_cnts_) {
auto &name_map = *(pair.second);
for (auto &fetch_name : fetch_tensors) {
name_map.erase(fetch_name);
}
name_map.erase(fetched_var_name);
}
}
#endif
auto fetch_data = member_->executor_->Run(fetch_tensors);
......@@ -298,6 +307,10 @@ ParallelExecutor::~ParallelExecutor() {
}
}
}
// member_ must be destructed before gcs_ since the destructor of
// ReferenceCountOpHandle use raw pointers of gcs_ inside.
member_.reset();
}
} // namespace framework
......
......@@ -75,7 +75,7 @@ class ParallelExecutor {
private:
void BCastParamsToDevices(const std::unordered_set<std::string> &vars) const;
ParallelExecutorPrivate *member_;
std::unique_ptr<ParallelExecutorPrivate> member_;
#ifdef PADDLE_WITH_CUDA
// ref_cnts_ is only initialized when ParallelExecutor constructs, and then
......
......@@ -126,7 +126,7 @@ const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {
std::vector<std::string> feed_target_names;
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
int col = boost::get<int>(op->GetAttr("col"));
size_t col = boost::get<int>(op->GetAttr("col"));
if (col >= feed_target_names.size()) {
feed_target_names.resize(col + 1);
}
......@@ -143,7 +143,7 @@ const std::vector<std::string> ProgramDesc::GetFetchTargetNames() {
std::vector<std::string> fetch_target_names;
for (auto *op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
int col = boost::get<int>(op->GetAttr("col"));
size_t col = boost::get<int>(op->GetAttr("col"));
if (col >= fetch_target_names.size()) {
fetch_target_names.resize(col + 1);
}
......
......@@ -39,7 +39,7 @@ TEST(READER, decorate_chain) {
{
auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0);
ASSERT_NE(endpoints.count(end_point1.get()), 0UL);
ASSERT_NE(endpoints.count(end_point2.get()), 0);
}
......
......@@ -46,6 +46,7 @@ struct RWLock {
private:
pthread_rwlock_t lock_;
};
// TODO(paddle-dev): Support RWLock for WIN32 for correctness.
#else
// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
// In windows, rw_lock seems like a hack. Use empty object and do nothing.
......
......@@ -20,13 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h"
// The mutex is not needed by training and inference, only for distribution.
#if PADDLE_WITH_DISTRIBUTE
#define WITH_LOCK 1
#else
#define WITH_LOCK 0
#endif
DEFINE_bool(benchmark, false,
"Doing memory benchmark. It will make deleting scope synchronized, "
"and add some memory usage logs."
......@@ -56,24 +49,18 @@ int64_t GetEagerDeletionThreshold() {
Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this));
return *kids_.back();
}
Variable* Scope::Var(const std::string& name) {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
return VarInternal(name);
}
Variable* Scope::Var(std::string* name) {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) {
*name = new_name;
......@@ -82,39 +69,34 @@ Variable* Scope::Var(std::string* name) {
}
Variable* Scope::FindVar(const std::string& name) const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
return FindVarInternal(name);
}
Variable* Scope::FindLocalVar(const std::string& name) const {
std::lock_guard<std::mutex> lock(mutex_);
return FindVarLocally(name);
}
const Scope* Scope::FindScope(const Variable* var) const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
return FindScopeInternal(var);
}
void Scope::DropKids() {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
for (Scope* s : kids_) delete s;
kids_.clear();
}
bool Scope::HasKid(const Scope* scope) const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end();
}
std::vector<std::string> Scope::LocalVarNames() const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size());
for (auto& p : vars_) {
......@@ -124,9 +106,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}
void Scope::DeleteScope(Scope* scope) const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
......@@ -139,9 +119,7 @@ void Scope::DeleteScope(Scope* scope) const {
}
void Scope::EraseVars(const std::vector<std::string>& var_names) {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
......@@ -154,16 +132,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
RenameInternal(origin_name, new_name);
}
std::string Scope::Rename(const std::string& origin_name) const {
#if WITH_LOCK
std::unique_lock<std::mutex> lock(mutex_);
#endif
std::lock_guard<std::mutex> lock(mutex_);
auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name);
return new_name;
......
......@@ -63,6 +63,11 @@ class Scope {
/// Caller doesn't own the returned Variable.
Variable* FindVar(const std::string& name) const;
/// Find a variable in the current scope.
/// Return nullptr if cannot find.
/// Caller doesn't own the returned Variable.
Variable* FindLocalVar(const std::string& name) const;
const Scope* parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given variable.
......
......@@ -91,7 +91,7 @@ TEST(SelectedRows, SparseTable) {
ASSERT_TRUE(table.HasKey(10));
ASSERT_TRUE(table.HasKey(8));
ASSERT_TRUE(table.HasKey(6));
ASSERT_EQ(table.rows().size(), 3);
ASSERT_EQ(table.rows().size(), 3UL);
framework::Tensor ids;
ids.Resize(framework::make_ddim({4}));
......
......@@ -46,16 +46,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]);
}
void InferShapeContext::ShareLoDs(const std::string &in,
const std::string &out) const {
PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(),
"The number of arguments in %s and %s is not equal.", in,
out);
for (size_t i = 0; i < in.size(); ++i) {
ShareLoD(in, out, i, i);
}
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
......
......@@ -56,7 +56,8 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;
void ShareLoDs(const std::string &in, const std::string &out) const;
virtual void ShareDim(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) = 0;
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
......
......@@ -36,6 +36,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
......@@ -71,6 +76,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
if (platform::is_same_place(src_place, dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
stream);
} else {
......@@ -114,6 +124,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
......@@ -130,6 +145,11 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr);
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
if (src_ptr == dst_ptr && platform::is_same_place(src_place, dst_place)) {
VLOG(3) << "Skip copy the same data from " << src_place << " to "
<< dst_place;
return;
}
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);
......@@ -165,10 +185,12 @@ inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
}
template <typename Predicate>
struct AnyVisitor : public boost::static_visitor<bool> {
class AnyVisitor : public boost::static_visitor<bool> {
private:
const framework::Tensor& tensor_;
Predicate predicate_;
public:
AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
: tensor_(tensor), predicate_(std::move(predicate)) {}
......@@ -206,6 +228,27 @@ struct AnyVisitor : public boost::static_visitor<bool> {
}
};
template <typename Predicate>
class AnyOutVisitor : public boost::static_visitor<> {
private:
const framework::Tensor& tensor_;
mutable framework::Tensor* out_;
Predicate predicate_;
public:
AnyOutVisitor(const framework::Tensor& tensor, Predicate predicate,
framework::Tensor* out)
: tensor_(tensor), out_(out), predicate_(std::move(predicate)) {}
template <typename Place>
void operator()(const Place& place) const {
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
out_->Resize({1});
out_->mutable_data<bool>(place);
AnyImpl(predicate_, tensor_, *ctx, out_);
}
};
template <typename Predicate>
inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
AnyVisitor<Predicate> visitor(tensor, predicate);
......@@ -213,6 +256,14 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
return platform::VisitPlace(place, visitor);
}
template <typename Predicate>
inline void Any(const framework::Tensor& tensor, Predicate predicate,
framework::Tensor* out) {
AnyOutVisitor<Predicate> visitor(tensor, predicate, out);
auto place = tensor.place();
platform::VisitPlace(place, visitor);
}
struct ContainsNANPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
......@@ -227,6 +278,12 @@ bool TensorContainsNAN(const framework::Tensor& tensor) {
return Any(tensor, predicate);
}
void TensorContainsNAN(const framework::Tensor& tensor,
framework::Tensor* out) {
ContainsNANPredicate predicate;
Any(tensor, predicate, out);
}
struct ContainsInfPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
......@@ -241,6 +298,71 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
return Any(tensor, predicate);
}
void TensorContainsInf(const framework::Tensor& tensor,
framework::Tensor* out) {
ContainsInfPredicate predicate;
Any(tensor, predicate, out);
}
// NOTE(dzhwinter):
// Isfinite need a AllVisitor to loop through all the elements.
// We choose two cuda call instead of one allvisitor. The AllVisitor
// should be implemented if the performance hurts.
bool TensorIsfinite(const framework::Tensor& tensor) {
ContainsInfPredicate pred_inf;
ContainsNANPredicate pred_nan;
return !Any(tensor, pred_inf) && !Any(tensor, pred_nan);
}
#ifdef PADDLE_WITH_CUDA
template <typename T>
static inline void __global__ BothFalse(const T* cmp, T* out) {
out[0] = (!cmp[0]) && (!out[0]);
}
#endif
struct BothFalseVisitor : public boost::static_visitor<> {
const framework::Tensor& in_;
mutable framework::Tensor* out_;
BothFalseVisitor(const framework::Tensor& in, framework::Tensor* out)
: in_(in), out_(out) {}
template <typename Place>
void operator()(const Place& place) const {
VisitorImpl(place);
}
void VisitorImpl(const platform::CUDAPlace& gpu) const {
#ifdef PADDLE_WITH_CUDA
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu);
BothFalse<bool><<<1, 1, 0, ctx->stream()>>>(in_.data<bool>(),
out_->mutable_data<bool>(gpu));
#endif
}
void VisitorImpl(const platform::CPUPlace& cpu) const {
bool lhs = !in_.data<bool>()[0];
bool rhs = !out_->mutable_data<bool>(cpu)[0];
out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
}
void VisitorImpl(
const platform::CUDAPinnedPlace& cpu /* equals to cpu*/) const {
bool lhs = !in_.data<bool>()[0];
bool rhs = !out_->mutable_data<bool>(cpu)[0];
out_->mutable_data<bool>(cpu)[0] = lhs && rhs;
}
};
void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) {
framework::Tensor tmp;
TensorContainsInf(tensor, &tmp);
TensorContainsNAN(tensor, out);
BothFalseVisitor visitor(tmp, out);
auto place = tensor.place();
platform::VisitPlace(place, visitor);
}
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version
......
......@@ -57,8 +57,15 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
template <typename T>
void TesnorToVector(const Tensor& src, std::vector<T>* dst);
// copy the result bool to cpu
bool TensorContainsNAN(const framework::Tensor& tensor);
bool TensorContainsInf(const framework::Tensor& tensor);
bool TensorIsfinite(const framework::Tensor& tensor);
// store the result bool in gpu tensor, async operation. Faster than above ones.
void TensorContainsNAN(const framework::Tensor& tensor, framework::Tensor* out);
void TensorContainsInf(const framework::Tensor& tensor, framework::Tensor* out);
void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out);
void TensorToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx);
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include <stdexcept>
#include <string>
#include <vector>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_desc.h"
......
......@@ -88,13 +88,7 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}
void VarDesc::SetDataType(proto::VarType::Type data_type) {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
mutable_channel_desc()->set_data_type(data_type);
break;
default:
mutable_tensor_desc()->set_data_type(data_type);
}
mutable_tensor_desc()->set_data_type(data_type);
}
void VarDesc::SetDataTypes(
......@@ -115,13 +109,7 @@ void VarDesc::SetDataTypes(
}
proto::VarType::Type VarDesc::GetDataType() const {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return channel_desc().data_type();
break;
default:
return tensor_desc().data_type();
}
return tensor_desc().data_type();
}
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
......@@ -134,17 +122,6 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
return res;
}
void VarDesc::SetCapacity(int64_t capacity) {
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
break;
default:
PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
this->Name());
}
}
void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type().type()) {
case proto::VarType::LOD_TENSOR:
......@@ -214,19 +191,6 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
}
}
const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.type().channel();
default:
PADDLE_THROW(
"Getting 'channel_desc' is not supported by the type of var %s.",
this->Name());
}
}
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
......@@ -262,20 +226,6 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
}
}
proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
switch (desc_.type().type()) {
case proto::VarType::CHANNEL:
return desc_.mutable_type()->mutable_channel();
default:
PADDLE_THROW(
"Getting 'mutable_channel_desc' is not supported by the type of var "
"%s.",
this->Name());
}
}
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
......
此差异已折叠。
此差异已折叠。
......@@ -38,8 +38,12 @@ class Variable {
template <typename T>
T* GetMutable() {
if (!IsType<T>()) {
if (!holder_) {
holder_.reset(new PlaceholderImpl<T>(new T()));
} else {
PADDLE_ENFORCE(IsType<T>(),
"Variable must be type %s, the holding type is %s",
typeid(T).name(), holder_->Type().name());
}
return static_cast<T*>(holder_->Ptr());
}
......
......@@ -31,7 +31,6 @@ function(inference_api_test TARGET_NAME)
set(multiValueArgs ARGS)
cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
cc_test(${TARGET_NAME}
SRCS ${inference_test_SRC}
DEPS "${inference_deps}"
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册