提交 38d32c98 编写于 作者: Z Zeng Jinle 提交者: sneaxiy

merge develop

test=develop
...@@ -28,3 +28,4 @@ third_party/ ...@@ -28,3 +28,4 @@ third_party/
build_* build_*
# clion workspace. # clion workspace.
cmake-build-* cmake-build-*
model_test
...@@ -41,6 +41,7 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F ...@@ -41,6 +41,7 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF) option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND}) option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
option(WITH_NGRAPH "Compile PaddlePaddle with nGraph support." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF) option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
...@@ -62,13 +63,12 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF) ...@@ -62,13 +63,12 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF) option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF) option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
option(WITH_FAST_BUNDLE_TEST "Bundle tests that can be run in a single process together to reduce launch overhead" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(WITH_INFERENCE "Compile fluid inference library" ON) option(ON_INFER "Turn on inference optimization." OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF)
option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
...@@ -104,6 +104,8 @@ if(ANDROID OR IOS) ...@@ -104,6 +104,8 @@ if(ANDROID OR IOS)
"Disable RDMA when cross-compiling for Android and iOS" FORCE) "Disable RDMA when cross-compiling for Android and iOS" FORCE)
set(WITH_MKL OFF CACHE STRING set(WITH_MKL OFF CACHE STRING
"Disable MKL when cross-compiling for Android and iOS" FORCE) "Disable MKL when cross-compiling for Android and iOS" FORCE)
set(WITH_NGRAPH OFF CACHE STRING
"Disable nGraph when cross-compiling for Android and iOS" FORCE)
set(WITH_GOLANG OFF CACHE STRING set(WITH_GOLANG OFF CACHE STRING
"Disable golang when cross-compiling for Android and iOS" FORCE) "Disable golang when cross-compiling for Android and iOS" FORCE)
...@@ -127,6 +129,9 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING ...@@ -127,6 +129,9 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
"A path setting fluid shared and static libraries") "A path setting fluid shared and static libraries")
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries")
if (WITH_C_API AND WITH_PYTHON) if (WITH_C_API AND WITH_PYTHON)
message(WARNING "It is suggest not embedded a python interpreter in Paddle " message(WARNING "It is suggest not embedded a python interpreter in Paddle "
"when using C-API. It will give an unpredictable behavior when using a " "when using C-API. It will give an unpredictable behavior when using a "
...@@ -169,6 +174,7 @@ include(external/protobuf) # download, build, install protobuf ...@@ -169,6 +174,7 @@ include(external/protobuf) # download, build, install protobuf
include(external/python) # download, build, install python include(external/python) # download, build, install python
include(external/openblas) # download, build, install openblas include(external/openblas) # download, build, install openblas
include(external/mkldnn) # download, build, install mkldnn include(external/mkldnn) # download, build, install mkldnn
include(external/ngraph) # download, build, install nGraph
include(external/swig) # download, build, install swig include(external/swig) # download, build, install swig
include(external/boost) # download boost include(external/boost) # download boost
include(external/any) # download libn::any include(external/any) # download libn::any
...@@ -176,6 +182,7 @@ include(external/eigen) # download eigen3 ...@@ -176,6 +182,7 @@ include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11 include(external/pybind11) # download pybind11
include(external/cares) include(external/cares)
include(external/cub) include(external/cub)
include(external/xxhash) # download xxhash
if (NOT WIN32) if (NOT WIN32)
# there is no official support of snappystream, warpctc, nccl, cupti in windows # there is no official support of snappystream, warpctc, nccl, cupti in windows
...@@ -298,3 +305,11 @@ if(WITH_DOC) ...@@ -298,3 +305,11 @@ if(WITH_DOC)
find_python_module(recommonmark REQUIRED) find_python_module(recommonmark REQUIRED)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
if (ON_INFER)
message(STATUS "On inference mode, will take place some specific optimization.")
add_definitions(-DPADDLE_ON_INFERENCE)
else()
#TODO(luotao), combine this warning with `make inference_lib_dist` command.
message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.")
endif()
...@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \ ...@@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \
pip3 install -U docopt PyYAML sphinx==1.5.6 && \ pip3 install -U docopt PyYAML sphinx==1.5.6 && \
pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \ pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \
easy_install -U pip && \ easy_install -U pip && \
pip install -U wheel && \ pip install -U pip setuptools wheel && \
pip install -U docopt PyYAML sphinx==1.5.6 && \ pip install -U docopt PyYAML sphinx==1.5.6 && \
pip install sphinx-rtd-theme==0.1.9 recommonmark pip install sphinx-rtd-theme==0.1.9 recommonmark
RUN pip3 install pre-commit 'ipython==5.3.0' && \ RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip3 install opencv-python && \ pip3 install opencv-python && \
pip install pre-commit 'ipython==5.3.0' && \ pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python pip install opencv-python
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/index_en.html) [![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/docs/develop/documentation/zh/getstarted/index_cn.html) [![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
...@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle. ...@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0) ### Latest PaddlePaddle Release: [Fluid 1.1.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.1)
### Install Latest Stable Release: ### Install Latest Stable Release:
``` ```
# Linux CPU # Linux CPU
...@@ -27,9 +27,9 @@ pip install paddlepaddle ...@@ -27,9 +27,9 @@ pip install paddlepaddle
# Linux GPU cuda9cudnn7 # Linux GPU cuda9cudnn7
pip install paddlepaddle-gpu pip install paddlepaddle-gpu
# Linux GPU cuda8cudnn7 # Linux GPU cuda8cudnn7
pip install paddlepaddle-gpu==0.15.0.post87 pip install paddlepaddle-gpu==1.1.0.post87
# Linux GPU cuda8cudnn5 # Linux GPU cuda8cudnn5
pip install paddlepaddle-gpu==0.15.0.post85 pip install paddlepaddle-gpu==1.1.0.post85
# For installation on other platform, refer to http://paddlepaddle.org/ # For installation on other platform, refer to http://paddlepaddle.org/
``` ```
...@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.15.0.post85 ...@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.15.0.post85
## Installation ## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website. It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) on our website.
## Documentation ## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and We provide [English](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation. [Chinese](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book) - [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook. You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html) - [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.1/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters. You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html) - [Python API](http://paddlepaddle.org/documentation/api/zh/1.1/fluid.html)
Our new API enables much shorter programs. Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html) - [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.1/advanced_usage/development/contribute_to_paddle.html)
We appreciate your contributions! We appreciate your contributions!
......
...@@ -142,5 +142,10 @@ def parse_args(): ...@@ -142,5 +142,10 @@ def parse_args():
choices=['reduce', 'all_reduce'], choices=['reduce', 'all_reduce'],
default='all_reduce', default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce') help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--fuse_broadcast_op',
action='store_true',
help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.'
)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, ...@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
else: else:
build_strategy.reduce_strategy = fluid.BuildStrategy( build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.AllReduce ).ReduceStrategy.AllReduce
build_strategy.fuse_broadcast_op = args.fuse_broadcast_op
avg_loss = train_args[0] avg_loss = train_args[0]
...@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, ...@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
if args.use_fake_data or args.use_reader_op: if args.use_fake_data or args.use_reader_op:
try: try:
fetch_ret = exe.run(fetch_list) fetch_ret = exe.run(fetch_list)
except fluid.core.EOFException as eof: except fluid.core.EOFException as eof:
break break
......
文件模式从 100644 更改为 100755
...@@ -50,11 +50,7 @@ if(NOT WITH_PROFILER) ...@@ -50,11 +50,7 @@ if(NOT WITH_PROFILER)
endif(NOT WITH_PROFILER) endif(NOT WITH_PROFILER)
if(NOT CMAKE_CROSSCOMPILING) if(NOT CMAKE_CROSSCOMPILING)
if(WITH_AVX AND AVX512F_FOUND) if(WITH_AVX AND AVX_FOUND)
set(SIMD_FLAG ${AVX512F_FLAG})
elseif(WITH_AVX AND AVX2_FOUND)
set(SIMD_FLAG ${AVX2_FLAG})
elseif(WITH_AVX AND AVX_FOUND)
set(SIMD_FLAG ${AVX_FLAG}) set(SIMD_FLAG ${AVX_FLAG})
elseif(SSE3_FOUND) elseif(SSE3_FOUND)
set(SIMD_FLAG ${SSE3_FLAG}) set(SIMD_FLAG ${SSE3_FLAG})
......
...@@ -37,7 +37,6 @@ SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) ...@@ -37,7 +37,6 @@ SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib")
INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers. INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) # For MKLDNN code to include internal headers.
INCLUDE_DIRECTORIES(${THIRD_PARTY_PATH}/install) # For Paddle code to include mkldnn.h
IF(${CBLAS_PROVIDER} STREQUAL "MKLML") IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
SET(MKLDNN_DEPENDS ${MKLML_PROJECT}) SET(MKLDNN_DEPENDS ${MKLML_PROJECT})
...@@ -45,7 +44,7 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML") ...@@ -45,7 +44,7 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
ELSE() ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN") MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF() ENDIF()
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result") SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value") SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}") SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
...@@ -54,7 +53,7 @@ ExternalProject_Add( ...@@ -54,7 +53,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS} DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944" GIT_TAG "21fb5f2af1dd14e132af4f1b79160977ee487818"
PREFIX ${MKLDNN_SOURCES_DIR} PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
add_library(ngraph INTERFACE)
IF(WIN32 OR APPLE)
MESSAGE(WARNING
"Windows or Mac is not supported with nGraph in Paddle yet."
"Force WITH_NGRAPH=OFF")
SET(WITH_NGRAPH OFF CACHE STRING "Disable nGraph in Windows and MacOS" FORCE)
ENDIF()
IF(${WITH_NGRAPH} AND NOT ${WITH_MKLDNN})
MESSAGE(WARNING
"nGraph needs mkl-dnn to be enabled."
"Force WITH_NGRAPH=OFF")
SET(WITH_NGRAPH OFF CACHE STRING "Disable nGraph if mkl-dnn is disabled" FORCE)
ENDIF()
IF(NOT ${WITH_NGRAPH})
return()
ENDIF()
INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph")
SET(NGRAPH_VERSION "0.9")
SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION})
SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so)
SET(NGRAPH_TBB_LIB_NAME libtbb.so.2)
SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git")
ExternalProject_Add(
${NGRAPH_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_PROJECT} ${MKLML_PROJECT}
GIT_REPOSITORY ${NGRAPH_GIT_REPO}
GIT_TAG ${NGRAPH_GIT_TAG}
PREFIX ${NGRAPH_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${NGRAPH_INSTALL_DIR}
CMAKE_ARGS -DNGRAPH_UNIT_TEST_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_TOOLS_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_INTERPRETER_ENABLE=FALSE
CMAKE_ARGS -DNGRAPH_DEX_ONLY=TRUE
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DMKLDNN_INCLUDE_DIR=${MKLDNN_INC_DIR}
CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib
)
if(UNIX AND NOT APPLE)
include(GNUInstallDirs)
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
else()
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/lib)
endif()
MESSAGE(STATUS "nGraph lib will be installed at: ${NGRAPH_LIB_DIR}")
SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
# Workaround for nGraph expecting mklml to be in mkldnn install directory.
ExternalProject_Add_Step(
${NGRAPH_PROJECT}
PrepareMKL
COMMAND ${CMAKE_COMMAND} -E create_symlink ${MKLML_LIB} ${MKLDNN_INSTALL_DIR}/lib/libmklml_intel.so
COMMAND ${CMAKE_COMMAND} -E create_symlink ${MKLML_IOMP_LIB} ${MKLDNN_INSTALL_DIR}/lib/libiomp5.so
DEPENDEES download
DEPENDERS configure
)
add_dependencies(ngraph ${NGRAPH_PROJECT})
target_compile_definitions(ngraph INTERFACE -DPADDLE_WITH_NGRAPH)
target_include_directories(ngraph INTERFACE ${NGRAPH_INC_DIR})
target_link_libraries(ngraph INTERFACE ${NGRAPH_SHARED_LIB})
LIST(APPEND external_project_dependencies ngraph)
...@@ -30,66 +30,61 @@ UNSET_VAR(PROTOBUF_LITE_LIBRARY) ...@@ -30,66 +30,61 @@ UNSET_VAR(PROTOBUF_LITE_LIBRARY)
UNSET_VAR(PROTOBUF_LIBRARY) UNSET_VAR(PROTOBUF_LIBRARY)
UNSET_VAR(PROTOBUF_INCLUDE_DIR) UNSET_VAR(PROTOBUF_INCLUDE_DIR)
UNSET_VAR(Protobuf_PROTOC_EXECUTABLE) UNSET_VAR(Protobuf_PROTOC_EXECUTABLE)
function(protobuf_generate_python SRCS)
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
if(NOT ARGN)
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
return()
endif()
if(NOT COMMAND protobuf_generate_python) # before cmake 3.4, protobuf_genrerate_python is not defined. if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
function(protobuf_generate_python SRCS) # Create an include path for each file specified
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
if(NOT ARGN)
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
return()
endif()
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
# Create an include path for each file specified
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
else()
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
endif()
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
endif()
if(DEFINED Protobuf_IMPORT_DIRS)
foreach(DIR ${Protobuf_IMPORT_DIRS})
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
endif()
set(${SRCS})
foreach(FIL ${ARGN}) foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE) get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE) get_filename_component(ABS_PATH ${ABS_FIL} PATH)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH) list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
get_filename_component(FIL_DIR ${FIL} DIRECTORY) if(${_contains_already} EQUAL -1)
if(FIL_DIR) list(APPEND _protobuf_include_path -I ${ABS_PATH})
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif() endif()
endforeach()
else()
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
endif()
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
endif()
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py") if(DEFINED Protobuf_IMPORT_DIRS)
add_custom_command( foreach(DIR ${Protobuf_IMPORT_DIRS})
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py" get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL} list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
DEPENDS ${ABS_FIL} ${Protobuf_PROTOC_EXECUTABLE} if(${_contains_already} EQUAL -1)
COMMENT "Running Python protocol buffer compiler on ${FIL}" list(APPEND _protobuf_include_path -I ${ABS_PATH})
VERBATIM ) endif()
endforeach() endforeach()
endif()
set(${SRCS} ${${SRCS}} PARENT_SCOPE) set(${SRCS})
endfunction() foreach(FIL ${ARGN})
endif() get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif()
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py"
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${PROTOBUF_PROTOC_EXECUTABLE}
COMMENT "Running Python protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
# Print and set the protobuf library information, # Print and set the protobuf library information,
# finish this cmake process and exit from this file. # finish this cmake process and exit from this file.
...@@ -126,6 +121,7 @@ macro(PROMPT_PROTOBUF_LIB) ...@@ -126,6 +121,7 @@ macro(PROMPT_PROTOBUF_LIB)
# FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`. # FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`.
# make `protobuf_generate_cpp` happy. # make `protobuf_generate_cpp` happy.
SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE}) SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE})
FOREACH(dep ${protobuf_DEPS}) FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep}) ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep}) ADD_DEPENDENCIES(protobuf_lite ${dep})
......
INCLUDE(ExternalProject)
set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash)
set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash)
set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include")
IF(WITH_STATIC_LIB)
SET(BUILD_CMD make lib)
ELSE()
IF(APPLE)
SET(BUILD_CMD sed -i \"\" "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib)
ELSE(APPLE)
SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib)
ENDIF(APPLE)
ENDIF()
ExternalProject_Add(
extern_xxhash
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/Cyan4973/xxHash"
GIT_TAG "v0.6.5"
PREFIX ${XXHASH_SOURCE_DIR}
DOWNLOAD_NAME "xxhash"
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
PATCH_COMMAND
BUILD_COMMAND ${BUILD_CMD}
INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install
TEST_COMMAND ""
)
set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a")
INCLUDE_DIRECTORIES(${XXHASH_INCLUDE_DIR})
add_library(xxhash STATIC IMPORTED GLOBAL)
set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES})
include_directories(${XXHASH_INCLUDE_DIR})
add_dependencies(xxhash extern_xxhash)
LIST(APPEND external_project_dependencies xxhash)
IF(WITH_C_API)
INSTALL(DIRECTORY ${XXHASH_INCLUDE_DIR} DESTINATION third_party/xxhash)
IF(ANDROID)
INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib/${ANDROID_ABI})
ELSE()
INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib)
ENDIF()
ENDIF()
...@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME) ...@@ -261,6 +261,13 @@ function(cc_library TARGET_NAME)
add_dependencies(${TARGET_NAME} mklml) add_dependencies(${TARGET_NAME} mklml)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed") target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif() endif()
# remove link to python, see notes at:
# https://github.com/pybind/pybind11/blob/master/docs/compiling.rst#building-manually
if("${cc_library_DEPS};" MATCHES "python;")
list(REMOVE_ITEM cc_library_DEPS python)
add_dependencies(${TARGET_NAME} python)
target_link_libraries(${TARGET_NAME} "-Wl,-undefined,dynamic_lookup")
endif()
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
endif() endif()
...@@ -311,6 +318,8 @@ function(cc_test TARGET_NAME) ...@@ -311,6 +318,8 @@ function(cc_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true) set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
endfunction(cc_test) endfunction(cc_test)
...@@ -629,6 +638,8 @@ function(py_test TARGET_NAME) ...@@ -629,6 +638,8 @@ function(py_test TARGET_NAME)
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS} PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS} ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
endif() endif()
endfunction() endfunction()
......
...@@ -18,7 +18,7 @@ function(copy TARGET) ...@@ -18,7 +18,7 @@ function(copy TARGET)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DSTS DEPS) set(multiValueArgs SRCS DSTS DEPS)
cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(inference_lib_dist_dep ${TARGET} ${inference_lib_dist_dep} PARENT_SCOPE) set(fluid_lib_dist_dep ${TARGET} ${fluid_lib_dist_dep} PARENT_SCOPE)
list(LENGTH copy_lib_SRCS copy_lib_SRCS_len) list(LENGTH copy_lib_SRCS copy_lib_SRCS_len)
list(LENGTH copy_lib_DSTS copy_lib_DSTS_len) list(LENGTH copy_lib_DSTS copy_lib_DSTS_len)
...@@ -31,7 +31,7 @@ function(copy TARGET) ...@@ -31,7 +31,7 @@ function(copy TARGET)
foreach(index RANGE ${len}) foreach(index RANGE ${len})
list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_SRCS ${index} src)
list(GET copy_lib_DSTS ${index} dst) list(GET copy_lib_DSTS ${index} dst)
add_custom_command(TARGET ${TARGET} PRE_BUILD add_custom_command(TARGET ${TARGET} PRE_BUILD
COMMAND mkdir -p "${dst}" COMMAND mkdir -p "${dst}"
COMMAND cp -r "${src}" "${dst}" COMMAND cp -r "${src}" "${dst}"
COMMENT "copying ${src} -> ${dst}") COMMENT "copying ${src} -> ${dst}")
...@@ -67,6 +67,13 @@ copy(boost_lib ...@@ -67,6 +67,13 @@ copy(boost_lib
DEPS boost DEPS boost
) )
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/xxhash")
copy(xxhash_lib
SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib
DEPS xxhash
)
if(NOT PROTOBUF_FOUND) if(NOT PROTOBUF_FOUND)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf") set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf")
copy(protobuf_lib copy(protobuf_lib
...@@ -150,16 +157,16 @@ if (WITH_ANAKIN AND WITH_MKL) ...@@ -150,16 +157,16 @@ if (WITH_ANAKIN AND WITH_MKL)
SRCS SRCS
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/libinference_anakin_api* # compiled anakin api ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/libinference_anakin_api* # compiled anakin api
${ANAKIN_INSTALL_DIR} # anakin release ${ANAKIN_INSTALL_DIR} # anakin release
DSTS ${dst_dir}/inference/anakin ${FLUID_INSTALL_DIR}/third_party/install/anakin) DSTS ${FLUID_INSTALL_DIR}/third_party/install/anakin ${FLUID_INSTALL_DIR}/third_party/install/anakin)
list(APPEND inference_deps anakin_inference_lib) list(APPEND inference_deps anakin_inference_lib)
endif() endif()
set(module "inference") set(module "inference")
copy(inference_lib DEPS ${inference_deps} copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.* SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${src_dir}/${module}/api/paddle_inference_api.h ${src_dir}/${module}/api/demo_ci ${src_dir}/${module}/api/paddle_inference_api.h
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
) )
set(module "platform") set(module "platform")
...@@ -185,20 +192,41 @@ copy(cmake_cache ...@@ -185,20 +192,41 @@ copy(cmake_cache
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
DSTS ${FLUID_INSTALL_DIR}) DSTS ${FLUID_INSTALL_DIR})
add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep}) # This command generates a complete fluid library for both train and inference
add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep})
# Following commands generate a inference-only fluid library
# third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR}
copy(third_party DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/third_party ${FLUID_INSTALL_DIR}/CMakeCache.txt
DSTS ${FLUID_INFERENCE_INSTALL_DIR} ${FLUID_INFERENCE_INSTALL_DIR}
)
# only need libpaddle_fluid.so/a and paddle_inference_api.h for inference-only library
copy(inference_api_lib DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_inference_api.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
)
add_custom_target(inference_lib_dist DEPENDS third_party inference_api_lib)
# paddle fluid version # paddle fluid version
execute_process( function(version version_file)
COMMAND ${GIT_EXECUTABLE} log --pretty=format:%H -1 execute_process(
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR} COMMAND ${GIT_EXECUTABLE} log --pretty=format:%H -1
OUTPUT_VARIABLE PADDLE_GIT_COMMIT) WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
set(version_file ${FLUID_INSTALL_DIR}/version.txt) OUTPUT_VARIABLE PADDLE_GIT_COMMIT)
file(WRITE ${version_file} file(WRITE ${version_file}
"GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n" "GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n"
"WITH_MKL: ${WITH_MKL}\n" "WITH_MKL: ${WITH_MKL}\n"
"WITH_GPU: ${WITH_GPU}\n") "WITH_MKLDNN: ${WITH_MKLDNN}\n"
if(WITH_GPU) "WITH_GPU: ${WITH_GPU}\n")
file(APPEND ${version_file} if(WITH_GPU)
"CUDA version: ${CUDA_VERSION}\n" file(APPEND ${version_file}
"CUDNN version: v${CUDNN_MAJOR_VERSION}\n") "CUDA version: ${CUDA_VERSION}\n"
endif() "CUDNN version: v${CUDNN_MAJOR_VERSION}\n")
endif()
endfunction()
version(${FLUID_INSTALL_DIR}/version.txt)
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
...@@ -89,7 +89,9 @@ CHECK_CXX_SOURCE_RUNS(" ...@@ -89,7 +89,9 @@ CHECK_CXX_SOURCE_RUNS("
#include <immintrin.h> #include <immintrin.h>
int main() int main()
{ {
__m512i a = _mm512_undefined_epi32(); __m512i a = _mm512_set_epi32 (-1, 2, -3, 4, -1, 2, -3, 4,
13, -5, 6, -7, 9, 2, -6, 3);
__m512i result = _mm512_abs_epi32 (a);
return 0; return 0;
}" AVX512F_FOUND) }" AVX512F_FOUND)
......
...@@ -24,6 +24,7 @@ if(NOT WITH_FLUID_ONLY) ...@@ -24,6 +24,7 @@ if(NOT WITH_FLUID_ONLY)
endif() endif()
add_subdirectory(testing) add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
if(NOT MOBILE_INFERENCE AND NOT RPI AND NOT WITH_C_API) if(NOT MOBILE_INFERENCE AND NOT RPI AND NOT WITH_C_API)
add_subdirectory(fluid) add_subdirectory(fluid)
endif() endif()
此差异已折叠。
...@@ -9,9 +9,6 @@ add_subdirectory(pybind) ...@@ -9,9 +9,6 @@ add_subdirectory(pybind)
add_subdirectory(recordio) add_subdirectory(recordio)
endif(NOT WIN32) endif(NOT WIN32)
if(WITH_INFERENCE) # NOTE: please add subdirectory inference at last.
# NOTE: please add subdirectory inference at last. add_subdirectory(inference)
add_subdirectory(inference)
endif()
add_subdirectory(train) add_subdirectory(train)
...@@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { ...@@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
case proto::AttrType::LONG: { case proto::AttrType::LONG: {
return attr_desc.l(); return attr_desc.l();
} }
case proto::AttrType::LONGS: {
std::vector<int64_t> val(attr_desc.longs_size());
for (int i = 0; i < attr_desc.longs_size(); ++i) {
val[i] = attr_desc.longs(i);
}
return val;
}
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
} }
......
...@@ -26,6 +26,113 @@ limitations under the License. */ ...@@ -26,6 +26,113 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<std::vector<int64_t>> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
std::vector<int64_t>* operator()(Attribute& attr) const {
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
std::vector<int> val = boost::get<std::vector<int>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
std::vector<float> val = boost::get<std::vector<float>>(attr);
std::vector<int64_t> vec(val.begin(), val.end());
attr = vec;
}
std::vector<int64_t>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
...@@ -42,7 +149,11 @@ class AttrReader { ...@@ -42,7 +149,11 @@ class AttrReader {
inline const T& Get(const std::string& name) const { inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name); name);
return boost::get<T>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
} }
private: private:
...@@ -82,7 +193,7 @@ class DefaultValueSetter { ...@@ -82,7 +193,7 @@ class DefaultValueSetter {
public: public:
explicit DefaultValueSetter(T default_value) explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {} : default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; } void operator()(T& value) const { value = default_value_; } // NOLINT
private: private:
T default_value_; T default_value_;
...@@ -117,84 +228,6 @@ class EnumInContainer { ...@@ -117,84 +228,6 @@ class EnumInContainer {
std::unordered_set<T> container_; std::unordered_set<T> container_;
}; };
template <typename T>
struct ExtractAttribute {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
T* operator()(Attribute& attr) const {
T* attr_value = nullptr;
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// special handle bool
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
// hard to change the logic there. In another way, we should correct handle
// if the user set `some_flag=1`.
//
// FIX ME anytime if there is a better solution.
template <>
struct ExtractAttribute<bool> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
bool* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<bool>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
float val = boost::get<float>(attr);
attr = static_cast<bool>(val);
}
bool* attr_value = nullptr;
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<int64_t> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
int64_t* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<int64_t>(val);
} else if (attr.type() == typeid(float)) { // NOLINT
int val = boost::get<float>(attr);
attr = static_cast<int64_t>(val);
}
int64_t* attr_value = nullptr;
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
// check whether a certain attribute fit its limits // check whether a certain attribute fit its limits
// an attribute can have more than one limits // an attribute can have more than one limits
template <typename T> template <typename T>
...@@ -235,7 +268,7 @@ class TypedAttrChecker { ...@@ -235,7 +268,7 @@ class TypedAttrChecker {
return *this; return *this;
} }
void operator()(AttributeMap& attr_map) const { void operator()(AttributeMap& attr_map) const { // NOLINT
if (!attr_map.count(attr_name_)) { if (!attr_map.count(attr_name_)) {
// user do not set this attr // user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(), PADDLE_ENFORCE(!default_value_setter_.empty(),
...@@ -271,7 +304,7 @@ class OpAttrChecker { ...@@ -271,7 +304,7 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>()); return *(checker.target<TypedAttrChecker<T>>());
} }
void Check(AttributeMap& attr_map) const { void Check(AttributeMap& attr_map) const { // NOLINT
for (const auto& checker : attr_checkers_) { for (const auto& checker : attr_checkers_) {
checker(attr_map); checker(attr_map);
} }
......
...@@ -18,8 +18,8 @@ namespace framework { ...@@ -18,8 +18,8 @@ namespace framework {
void TransDataDevice(const Tensor &in, const platform::Place &dst_place, void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
Tensor *out) { Tensor *out) {
VLOG(3) << "DeviceTransform in, src_place " << in.place() VLOG(30) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place; << " dst_place: " << dst_place;
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
in.place().which(), dst_place.which(), in.place().which(), dst_place.which(),
......
...@@ -49,10 +49,10 @@ class TestOpWithKernel : public OperatorWithKernel { ...@@ -49,10 +49,10 @@ class TestOpWithKernel : public OperatorWithKernel {
OpKernelType GetExpectedKernelType( OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
if (Attr<bool>("use_gpu")) { if (Attr<bool>("use_gpu")) {
VLOG(3) << "force use gpu kernel"; VLOG(30) << "force use gpu kernel";
return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0)); return OpKernelType(proto::VarType::FP32, platform::CUDAPlace(0));
} else { } else {
VLOG(3) << "use default kernel"; VLOG(30) << "use default kernel";
return OpKernelType(proto::VarType::FP32, return OpKernelType(proto::VarType::FP32,
ctx.Input<Tensor>("input")->place()); ctx.Input<Tensor>("input")->place());
} }
...@@ -148,7 +148,7 @@ TEST(Operator, CPUtoGPU) { ...@@ -148,7 +148,7 @@ TEST(Operator, CPUtoGPU) {
// get output // get output
auto* output2 = scope.Var("OUT2"); auto* output2 = scope.Var("OUT2");
gpu_op->Run(scope, cuda_place); gpu_op->Run(scope, cuda_place);
VLOG(3) << "after gpu_op run"; VLOG(30) << "after gpu_op run";
// auto* output2_ptr = output2->Get<LoDTensor>().data<float>(); // auto* output2_ptr = output2->Get<LoDTensor>().data<float>();
paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool& pool =
......
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node) cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
...@@ -16,32 +17,39 @@ if(WITH_GPU) ...@@ -16,32 +17,39 @@ if(WITH_GPU)
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else() else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor) variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif() endif()
cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope)
if(WITH_GPU) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
if (WITH_GPU)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif() endif()
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU) set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) if (WITH_GPU)
else() list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
endif() endif()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context) simple_threadpool device_context)
...@@ -54,8 +62,9 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu ...@@ -54,8 +62,9 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
# device_context reduce_op_handle ) # device_context reduce_op_handle )
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context) DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
cc_library(build_strategy SRCS build_strategy.cc DEPS cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass) fuse_elewise_add_act_pass multi_batch_merge_pass)
...@@ -34,7 +34,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, ...@@ -34,7 +34,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
nccl_ctxs_(ctxs) { nccl_ctxs_(ctxs) {
if (nccl_ctxs_) { if (nccl_ctxs_) {
for (auto &p : places_) { for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p); this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
} }
} }
} }
...@@ -46,7 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node, ...@@ -46,7 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif #endif
void AllReduceOpHandle::RunImpl() { void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
if (NoDummyInputSize() == 1) { if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
...@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() { ...@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>(); *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i]; auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name_); auto *var = scope.FindVar(out_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_[p]; auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>(); auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
......
...@@ -48,16 +48,27 @@ void BroadcastOpHandle::RunImpl() { ...@@ -48,16 +48,27 @@ void BroadcastOpHandle::RunImpl() {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>()); var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
} }
BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
}
void BroadcastOpHandle::BroadcastOneVar(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) {
auto *in_var = auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
if (UNLIKELY(!in_tensor.IsInitialized())) {
VLOG(30) << "in var " << in_var_handle.name_ << "not inited, return!";
return;
}
InitOutputValue(*in_var_handle, out_var_handles); InitOutputValue(in_var_handle, out_var_handles);
if (platform::is_cpu_place(in_tensor.place())) { if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) { for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) { if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue; continue;
} }
auto &out_p = out_var_handle->place_; auto &out_p = out_var_handle->place_;
...@@ -114,12 +125,12 @@ void BroadcastOpHandle::RunImpl() { ...@@ -114,12 +125,12 @@ void BroadcastOpHandle::RunImpl() {
} }
} }
if (!out_handle->IsTheSameVar(*in_var_handle)) { if (!out_handle->IsTheSameVar(in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_) auto out_var = var_scopes.at(in_var_handle.scope_idx_)
->FindVar(out_var_handles[0]->name_); ->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy( paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_, in_tensor, in_var_handle.place_,
*(dev_ctxes_.at(in_var_handle->place_)), *(dev_ctxes_.at(in_var_handle.place_)),
&VariableVisitor::GetMutableTensor(out_var)); &VariableVisitor::GetMutableTensor(out_var));
} }
}); });
......
...@@ -44,7 +44,8 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -44,7 +44,8 @@ struct BroadcastOpHandle : public OpHandleBase {
nccl_ctxs_(nccl_ctxs) { nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) { if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) { for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); this->SetDeviceContext(platform::CUDAPlace(p_ctx.first),
p_ctx.second.ctx_.get());
} }
} }
} }
...@@ -61,7 +62,10 @@ struct BroadcastOpHandle : public OpHandleBase { ...@@ -61,7 +62,10 @@ struct BroadcastOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
private: void BroadcastOneVar(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes);
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
......
...@@ -12,232 +12,12 @@ ...@@ -12,232 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle_test.h"
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
namespace f = paddle::framework;
namespace p = paddle::platform;
// test data amount
const f::DDim kDims = {20, 20};
struct TestBroadcastOpHandle {
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_;
Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_;
std::vector<p::Place> gpu_list_;
bool use_gpu_;
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif
void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) {
ctxs_[j]->Wait();
}
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_) {
nccl_ctxs_->WaitAll();
}
#endif
}
void InitCtxOnGpu(bool use_gpu) {
use_gpu_ = use_gpu;
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
int count = p::GetCUDADeviceCount();
if (count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
"device count is "
<< count;
exit(0);
}
for (int i = 0; i < count; ++i) {
auto p = p::CUDAPlace(i);
gpu_list_.push_back(p);
ctxs_.emplace_back(new p::CUDADeviceContext(p));
}
nccl_ctxs_.reset(new platform::NCCLContextMap(gpu_list_));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
int count = 8;
for (int i = 0; i < count; ++i) {
auto p = p::CPUPlace();
gpu_list_.push_back(p);
ctxs_.emplace_back(new p::CPUDeviceContext(p));
}
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_.reset(nullptr);
#endif
}
}
void InitBroadcastOp(size_t input_scope_idx) {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("out");
param_scopes_.emplace_back(&local_scope);
}
param_scopes_[input_scope_idx]->Var("input");
std::unique_ptr<ir::Node> n =
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation);
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_.reset(new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_,
nccl_ctxs_.get()));
#else
op_handle_.reset(
new BroadcastOpHandle(n.get(), local_scopes_, gpu_list_));
#endif
}
std::unique_ptr<ir::Node> v =
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable);
auto* in_var_handle = new VarHandle(v.get(), 1, input_scope_idx, "input",
gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
std::unique_ptr<ir::Node> v2 =
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v2.get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < gpu_list_.size(); ++j) {
if (!use_gpu_) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
}
std::unique_ptr<ir::Node> v3 =
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable);
VarHandle* out_var_handle =
new VarHandle(v3.get(), 2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
// add dummy var
std::unique_ptr<ir::Node> v4 =
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable);
vars_.emplace_back(new DummyVarHandle(v4.get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get());
out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle);
}
void TestBroadcastLodTensor(size_t input_scope_idx) {
auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_lod_tensor = in_var->GetMutable<f::LoDTensor>();
in_lod_tensor->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
std::vector<float> send_vector(static_cast<size_t>(f::product(kDims)));
for (size_t k = 0; k < send_vector.size(); ++k) {
send_vector[k] = k;
}
f::LoD lod{{0, 10, 20}};
paddle::framework::TensorFromVector<float>(
send_vector, *(ctxs_[input_scope_idx]), in_lod_tensor);
in_lod_tensor->set_lod(lod);
in_lod_tensor->Resize(kDims);
op_handle_->Run(false);
WaitAll();
p::CPUPlace cpu_place;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
auto out_var = param_scopes_[j]->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_tensor = out_var->Get<f::LoDTensor>();
PADDLE_ENFORCE_EQ(out_tensor.lod(), lod, "lod is not equal.");
f::Tensor result_tensor;
f::TensorCopySync(out_tensor, cpu_place, &result_tensor);
float* ct = result_tensor.mutable_data<float>(cpu_place);
for (int64_t i = 0; i < f::product(kDims); ++i) {
ASSERT_NEAR(ct[i], send_vector[i], 1e-5);
}
}
}
void TestBroadcastSelectedRows(size_t input_scope_idx) {
auto in_var = param_scopes_[input_scope_idx]->FindVar("input");
PADDLE_ENFORCE_NOT_NULL(in_var);
auto in_selected_rows = in_var->GetMutable<f::SelectedRows>();
auto value = in_selected_rows->mutable_value();
value->mutable_data<float>(kDims, gpu_list_[input_scope_idx]);
int height = static_cast<int>(kDims[0]) * 2;
std::vector<int64_t> rows{0, 1, 2, 3, 3, 0, 14, 7, 3, 1,
2, 4, 6, 3, 1, 1, 1, 1, 3, 7};
in_selected_rows->set_height(height);
in_selected_rows->set_rows(rows);
std::vector<float> send_vector(static_cast<size_t>(f::product(kDims)));
for (size_t k = 0; k < send_vector.size(); ++k) {
send_vector[k] = k;
}
paddle::framework::TensorFromVector<float>(
send_vector, *(ctxs_[input_scope_idx]), value);
op_handle_->Run(false);
WaitAll();
p::CPUPlace cpu_place;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
auto out_var = param_scopes_[j]->FindVar("out");
PADDLE_ENFORCE_NOT_NULL(out_var);
auto& out_select_rows = out_var->Get<f::SelectedRows>();
auto rt = out_select_rows.value();
PADDLE_ENFORCE_EQ(out_select_rows.height(), height,
"height is not equal.");
for (size_t k = 0; k < out_select_rows.rows().size(); ++k) {
PADDLE_ENFORCE_EQ(out_select_rows.rows()[k], rows[k]);
}
f::Tensor result_tensor;
f::TensorCopySync(rt, cpu_place, &result_tensor);
float* ct = result_tensor.data<float>();
for (int64_t i = 0; i < f::product(kDims); ++i) {
ASSERT_NEAR(ct[i], send_vector[i], 1e-5);
}
}
}
};
TEST(BroadcastTester, TestCPUBroadcastTestLodTensor) { TEST(BroadcastTester, TestCPUBroadcastTestLodTensor) {
TestBroadcastOpHandle test_op; TestBroadcastOpHandle test_op;
size_t input_scope_idx = 0; size_t input_scope_idx = 0;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
namespace f = paddle::framework;
namespace p = paddle::platform;
// test data amount
const f::DDim kDims = {20, 20};
struct TestBroadcastOpHandle {
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_;
Scope g_scope_;
OpHandleBase* op_handle_;
std::vector<VarHandleBase*> vars_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
std::vector<p::Place> place_list_;
bool use_gpu_;
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif
void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) {
ctxs_[j]->Wait();
}
#ifdef PADDLE_WITH_CUDA
if (nccl_ctxs_) {
nccl_ctxs_->WaitAll();
}
#endif
}
void InitCtxOnGpu(bool use_gpu) {
use_gpu_ = use_gpu;
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
int count = p::GetCUDADeviceCount();
if (count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu Broadcast, because the CUDA "
"device count is "
<< count;
exit(0);
}
for (int i = 0; i < count; ++i) {
auto p = p::CUDAPlace(i);
place_list_.push_back(p);
ctxs_.emplace_back(new p::CUDADeviceContext(p));
}
nccl_ctxs_.reset(new platform::NCCLContextMap(place_list_));
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
int count = 8;
for (int i = 0; i < count; ++i) {
auto p = p::CPUPlace();
place_list_.push_back(p);
ctxs_.emplace_back(new p::CPUDeviceContext(p));
}
#ifdef PADDLE_WITH_CUDA
nccl_ctxs_.reset(nullptr);
#endif
}
}
void InitBroadcastOp(size_t input_scope_idx) {
nodes_.clear();
for (size_t j = 0; j < place_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("out");
param_scopes_.emplace_back(&local_scope);
}
param_scopes_[input_scope_idx]->Var("input");
nodes_.emplace_back(
ir::CreateNodeForTest("node0", ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get());
#else
PADDLE_THROW("CUDA is not support.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_, nccl_ctxs_.get());
#else
op_handle_ = new BroadcastOpHandle(nodes_.back().get(), local_scopes_,
place_list_);
#endif
}
nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
"input", place_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back());
dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(dummy_var_handle);
for (size_t j = 0; j < place_list_.size(); ++j) {
if (!use_gpu_) {
op_handle_->SetDeviceContext(place_list_[j], ctxs_[j].get());
}
nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable));
VarHandle* out_var_handle =
new VarHandle(nodes_.back().get(), 2, j, "out", place_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
// add dummy var
nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable));
vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* out_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back());
out_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddOutput(out_dummy_var_handle);
}
std::vector<float> InitLoDTensor(const std::string& varname,
size_t input_scope_idx, const f::LoD& lod,
float val_scalar = 0.0) {
auto var = param_scopes_[input_scope_idx]->FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var);
auto lod_tensor = var->GetMutable<f::LoDTensor>();
std::vector<float> send_vector(static_cast<size_t>(f::product(kDims)));
for (size_t k = 0; k < send_vector.size(); ++k) {
send_vector[k] = k + val_scalar;
}
paddle::framework::TensorFromVector<float>(
send_vector, *(ctxs_[input_scope_idx]), lod_tensor);
lod_tensor->set_lod(lod);
lod_tensor->Resize(kDims);
return send_vector;
}
std::vector<float> InitSelectedRows(const std::string& varname,
size_t input_scope_idx,
const std::vector<int64_t>& rows,
int height, float value_scalar = 0.0) {
std::vector<float> send_vector(static_cast<size_t>(f::product(kDims)));
for (size_t k = 0; k < send_vector.size(); ++k) {
send_vector[k] = k + value_scalar;
}
auto var = param_scopes_[input_scope_idx]->FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var);
auto selected_rows = var->GetMutable<f::SelectedRows>();
auto value = selected_rows->mutable_value();
value->mutable_data<float>(kDims, place_list_[input_scope_idx]);
selected_rows->set_height(height);
selected_rows->set_rows(rows);
paddle::framework::TensorFromVector<float>(
send_vector, *(ctxs_[input_scope_idx]), value);
return send_vector;
}
void SelectedRowsEqual(const std::string& varname, int input_scope_idx,
const std::vector<float>& send_vector,
const std::vector<int64_t>& rows, int height) {
auto var = param_scopes_[input_scope_idx]->FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var);
auto& selected_rows = var->Get<f::SelectedRows>();
auto rt = selected_rows.value();
PADDLE_ENFORCE_EQ(selected_rows.height(), height, "height is not equal.");
for (size_t k = 0; k < selected_rows.rows().size(); ++k) {
PADDLE_ENFORCE_EQ(selected_rows.rows()[k], rows[k]);
}
p::CPUPlace cpu_place;
f::Tensor result_tensor;
f::TensorCopySync(rt, cpu_place, &result_tensor);
float* ct = result_tensor.data<float>();
for (int64_t i = 0; i < f::product(kDims); ++i) {
ASSERT_NEAR(ct[i], send_vector[i], 1e-5);
}
}
void LoDTensorEqual(const std::string& varname,
const std::vector<float>& send_vec, const f::LoD& lod,
framework::Scope* scope) {
p::CPUPlace cpu_place;
auto var = scope->FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<f::LoDTensor>();
PADDLE_ENFORCE_EQ(tensor.lod(), lod, "lod is not equal.");
f::Tensor result_tensor;
f::TensorCopySync(tensor, cpu_place, &result_tensor);
float* ct = result_tensor.mutable_data<float>(cpu_place);
for (int64_t k = 0; k < f::product(kDims); ++k) {
ASSERT_NEAR(ct[k], send_vec[k], 1e-5);
}
}
void TestBroadcastLodTensor(size_t input_scope_idx) {
f::LoD lod{{0, 10, 20}};
auto send_vector = InitLoDTensor("input", input_scope_idx, lod);
op_handle_->Run(false);
WaitAll();
for (size_t j = 0; j < place_list_.size(); ++j) {
LoDTensorEqual("out", send_vector, lod, param_scopes_[j]);
}
}
void TestBroadcastSelectedRows(size_t input_scope_idx) {
std::vector<int64_t> rows{0, 1, 2, 3, 3, 0, 14, 7, 3, 1,
2, 4, 6, 3, 1, 1, 1, 1, 3, 7};
int height = static_cast<int>(kDims[0] * 2);
auto send_vector = InitSelectedRows("input", input_scope_idx, rows, height);
op_handle_->Run(false);
WaitAll();
for (size_t j = 0; j < place_list_.size(); ++j) {
SelectedRowsEqual("out", input_scope_idx, send_vector, rows, height);
}
}
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
...@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public: public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) { : ir::PassBuilder(), strategy_(strategy) {
if (strategy_.enable_sequential_execution_) {
AppendPass("sequential_execution_pass");
}
// Add a graph viz pass to record a graph. // Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass"); auto viz_pass = AppendPass("graph_viz_pass");
...@@ -64,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -64,6 +69,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Verify that the graph is correct for multi-device executor. // Verify that the graph is correct for multi-device executor.
AppendPass("multi_devices_check_pass"); AppendPass("multi_devices_check_pass");
if (strategy_.remove_unnecessary_lock_) {
AppendPass("modify_op_lock_and_record_event_pass");
}
} }
private: private:
...@@ -110,6 +119,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -110,6 +119,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("nccl_ctxs"); pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx); pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif #endif
} else if (pass->Type() == "sequential_execution_pass") {
pass->Erase(kAllOpDescs);
pass->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} }
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
} }
...@@ -121,6 +135,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -121,6 +135,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS(fuse_elewise_add_act_pass); USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass); USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass); USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass);
USE_PASS(modify_op_lock_and_record_event_pass);
...@@ -69,6 +69,12 @@ struct BuildStrategy { ...@@ -69,6 +69,12 @@ struct BuildStrategy {
bool enable_data_balance_{false}; bool enable_data_balance_{false};
bool enable_sequential_execution_{false};
bool fuse_broadcast_op_{false};
bool remove_unnecessary_lock_{false};
// User normally doesn't need to call this API. // User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes // The PassBuilder allows for more customized insert, remove of passes
// from python side. // from python side.
......
...@@ -29,15 +29,21 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ...@@ -29,15 +29,21 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); WaitInputVarGenerated(place_);
this->RunAndRecordEvent([this] { auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}); };
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
} }
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) { bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait = bool need_wait =
in_var && in_var->GeneratedOp() && in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_]; in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_.at(place_);
return need_wait; return need_wait;
} }
......
...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
const platform::Place &GetPlace() const { return place_; } const platform::Place &GetPlace() const { return place_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
protected: protected:
void RunImpl() override; void RunImpl() override;
...@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase { ...@@ -45,6 +47,7 @@ struct ComputationOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
bool is_lock_and_record_event_free_{false};
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -28,7 +28,7 @@ DataBalanceOpHandle::DataBalanceOpHandle( ...@@ -28,7 +28,7 @@ DataBalanceOpHandle::DataBalanceOpHandle(
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) { : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) { if (ctxs) {
for (auto &p : places_) { for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p); this->SetDeviceContext(p, ctxs->DevCtx(p));
} }
} }
} }
...@@ -89,8 +89,8 @@ void DataBalanceOpHandle::RunImpl() { ...@@ -89,8 +89,8 @@ void DataBalanceOpHandle::RunImpl() {
PADDLE_ENFORCE_GT(places_.size(), 1, PADDLE_ENFORCE_GT(places_.size(), 1,
"Data balance can only be enabled when the number of " "Data balance can only be enabled when the number of "
"places to run larger than 1."); "places to run larger than 1.");
auto in_var_handles = DynamicCast<VarHandle>(inputs_); auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(outputs_); auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0); PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(), in_var_handles.size(), out_var_handles.size(),
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <cstddef> // for size_t
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -26,6 +27,7 @@ struct ExecutionStrategy { ...@@ -26,6 +27,7 @@ struct ExecutionStrategy {
bool allow_op_delay_{false}; bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100}; size_t num_iteration_per_drop_scope_{100};
ExecutorType type_{kDefault}; ExecutorType type_{kDefault};
bool dry_run_{false};
}; };
} // namespace details } // namespace details
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -32,13 +33,11 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
pool_(strategy.num_threads_), pool_(strategy.num_threads_),
prepare_pool_(1), // add one more thread for generate op_deps prepare_pool_(1), // add one more thread for generate op_deps
fetch_ctxs_(places) { fetch_ctxs_(places) {
auto &ops = graph_->Get<details::GraphOps>("ops"); for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
for (auto &op : ops) {
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op.get(), dep); op_deps_.emplace(op, dep);
if (dep == 0) { if (dep == 0) {
bootstrap_ops_.emplace_back(op.get()); bootstrap_ops_.emplace_back(op);
} }
} }
...@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -54,13 +53,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
paddle::framework::FeedFetchList fetches; paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size()); fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
...@@ -92,13 +91,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -92,13 +91,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
size_t num_complete = 0; size_t num_complete = 0;
remaining_ = 0; remaining_ = 0;
BlockingQueue<size_t> complete_q; auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) { for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, &complete_q); RunOpAsync(op_deps.get(), op, complete_q);
} }
while (num_complete != op_deps->size()) { while (num_complete != op_deps->size()) {
size_t num_comp = complete_q.Pop(); size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) { if (num_comp == -1UL) {
int remaining = 0; int remaining = 0;
while (true) { while (true) {
...@@ -107,10 +106,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -107,10 +106,13 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
break; break;
} }
for (int i = 0; i < remaining; ++i) { for (int i = 0; i < remaining; ++i) {
complete_q.Pop(); complete_q->Pop();
} }
} }
exception_.ReThrow(); if (exception_.IsCaught()) {
ClearFetchOp(graph_.get(), &fetch_ops);
exception_.ReThrow();
}
} }
num_complete += num_comp; num_complete += num_comp;
} }
...@@ -120,14 +122,17 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run( ...@@ -120,14 +122,17 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
} }
void FastThreadedSSAGraphExecutor::RunOpAsync( void FastThreadedSSAGraphExecutor::RunOpAsync(
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps, std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, BlockingQueue<size_t> *complete_q) { OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
++remaining_; ++remaining_;
this->pool_.enqueue([=] { this->pool_.enqueue([=] {
OpHandleBase *op_to_run = op; OpHandleBase *op_to_run = op;
size_t complete = 0; size_t complete = 0;
while (op_to_run != nullptr) { while (op_to_run != nullptr) {
try { try {
op_to_run->Run(strategy_.use_cuda_); if (LIKELY(!strategy_.dry_run_)) {
op_to_run->Run(strategy_.use_cuda_);
}
++complete; ++complete;
} catch (...) { } catch (...) {
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
...@@ -144,7 +149,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -144,7 +149,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
if (op_to_run == nullptr) { if (op_to_run == nullptr) {
op_to_run = pending_op; op_to_run = pending_op;
} else { } else {
this->RunOpAsync(op_deps, pending_op, complete_q); RunOpAsync(op_deps, pending_op, complete_q);
} }
} }
} }
......
...@@ -51,7 +51,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,7 +51,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::atomic<int> remaining_; std::atomic<int> remaining_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps, void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, BlockingQueue<size_t> *complete_q); OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps(); void PrepareAtomicOpDeps();
......
...@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, ...@@ -28,11 +28,7 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
offset_(offset), offset_(offset),
local_scopes_(local_scopes) {} local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() { FetchOpHandle::~FetchOpHandle() {}
for (auto *input_var : inputs_) {
input_var->RemoveOutput(this, this->Node());
}
}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error"); PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
void FusedBroadcastOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
if (places_.size() == 1UL) return;
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
for (size_t i = 0; i < in_var_handles.size(); ++i) {
BroadcastOneVar(
*in_var_handles[i],
std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
out_var_handles.begin() + (i + 1) * place_num),
var_scopes);
}
}
std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
struct FusedBroadcastOpHandle : public BroadcastOpHandle {
public:
#ifdef PADDLE_WITH_CUDA
FusedBroadcastOpHandle(ir::Node *node,
const std::vector<Scope *> local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctx)
: BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {}
#else
FusedBroadcastOpHandle(ir::Node* node, const std::vector<Scope*> local_scopes,
const std::vector<platform::Place>& places)
: BroadcastOpHandle(node, local_scopes, places) {}
#endif
std::string Name() const override;
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/broadcast_op_handle_test.h"
namespace paddle {
namespace framework {
namespace details {
struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
std::vector<std::string> out_varnames_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
nodes_.clear();
// initialize scope and var
for (size_t i = 0; i < place_list_.size(); ++i) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
for (size_t j = 0; j < input_scope_idxes.size(); ++j) {
local_scope.Var("out_var" + j);
if (i == j) local_scope.Var("in_var" + j);
}
param_scopes_.emplace_back(&local_scope);
}
// create op handle node
nodes_.emplace_back(
ir::CreateNodeForTest("fused_broadcast", ir::Node::Type::kOperation));
if (use_gpu_) {
#ifdef PADDLE_WITH_CUDA
op_handle_ = new FusedBroadcastOpHandle(
nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else
PADDLE_THROW("CUDA is not supported.");
#endif
} else {
#ifdef PADDLE_WITH_CUDA
op_handle_ = new FusedBroadcastOpHandle(
nodes_.back().get(), local_scopes_, place_list_, nccl_ctxs_.get());
#else
op_handle_ = new FusedBroadcastOpHandle(nodes_.back().get(),
local_scopes_, place_list_);
#endif
}
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
// add input var handle
nodes_.emplace_back(
ir::CreateNodeForTest("in_node" + i, ir::Node::Type::kVariable));
VarHandle* in_var_handle =
new VarHandle(nodes_.back().get(), 1, input_scope_idxes[i],
"in_var" + i, place_list_[input_scope_idxes[i]]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add output var handle
for (size_t j = 0; j < place_list_.size(); ++j) {
nodes_.emplace_back(
ir::CreateNodeForTest("out_node" + i, ir::Node::Type::kVariable));
VarHandle* out_var_handle = new VarHandle(
nodes_.back().get(), 2, j, "out_var" + i, place_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
}
}
void TestFusedBroadcastLoDTensor(std::vector<size_t> input_scope_idxes) {
std::vector<std::vector<float>> send_vec;
f::LoD lod{{0, 10, 20}};
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
const std::string varname("in_var" + i);
float val_scalar = static_cast<float>(i);
send_vec.push_back(
InitLoDTensor(varname, input_scope_idxes[i], lod, val_scalar));
}
op_handle_->Run(false);
WaitAll();
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
const std::string& varname("out_var" + i);
for (size_t j = 0; j < place_list_.size(); ++j) {
LoDTensorEqual(varname, send_vec[i], lod, param_scopes_[j]);
}
}
}
void TestFusedBroadcastSelectedRows(std::vector<size_t> input_scope_idxes) {
std::vector<std::vector<float>> send_vector;
std::vector<int64_t> rows{0, 1, 2, 3, 3, 0, 14, 7, 3, 1,
2, 4, 6, 3, 1, 1, 1, 1, 3, 7};
int height = static_cast<int>(kDims[0] * 2);
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
const std::string varname("in_var" + i);
float val_scalar = static_cast<float>(i);
send_vector.push_back(InitSelectedRows(varname, input_scope_idxes[i],
rows, height, val_scalar));
}
op_handle_->Run(false);
WaitAll();
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
const std::string& varname("out_var" + i);
for (size_t j = 0; j < place_list_.size(); ++j) {
SelectedRowsEqual(varname, input_scope_idxes[i], send_vector[i], rows,
height);
}
}
}
};
TEST(FusedBroadcastTester, CPULodTensor) {
TestFusedBroadcastOpHandle test_op;
std::vector<size_t> input_scope_idxes = {0, 1};
test_op.InitCtxOnGpu(false);
test_op.InitFusedBroadcastOp(input_scope_idxes);
test_op.TestFusedBroadcastLoDTensor(input_scope_idxes);
}
TEST(FusedBroadcastTester, CPUSelectedRows) {
TestFusedBroadcastOpHandle test_op;
std::vector<size_t> input_scope_idxes = {0, 1};
test_op.InitCtxOnGpu(false);
test_op.InitFusedBroadcastOp(input_scope_idxes);
test_op.TestFusedBroadcastSelectedRows(input_scope_idxes);
}
#ifdef PADDLE_WITH_CUDA
TEST(FusedBroadcastTester, GPULodTensor) {
TestFusedBroadcastOpHandle test_op;
std::vector<size_t> input_scope_idxes = {0, 1};
test_op.InitCtxOnGpu(true);
test_op.InitFusedBroadcastOp(input_scope_idxes);
test_op.TestFusedBroadcastLoDTensor(input_scope_idxes);
}
TEST(FusedBroadcastTester, GPUSelectedRows) {
TestFusedBroadcastOpHandle test_op;
std::vector<size_t> input_scope_idxes = {0, 1};
test_op.InitCtxOnGpu(true);
test_op.InitFusedBroadcastOp(input_scope_idxes);
test_op.TestFusedBroadcastSelectedRows(input_scope_idxes);
}
#endif
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -36,7 +36,7 @@ void GatherOpHandle::RunImpl() { ...@@ -36,7 +36,7 @@ void GatherOpHandle::RunImpl() {
VarHandle *out_var_handle; VarHandle *out_var_handle;
{ {
auto out_var_handles = DynamicCast<VarHandle>(outputs_); auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1, PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one."); "The number of output should be one.");
out_var_handle = out_var_handles.front(); out_var_handle = out_var_handles.front();
...@@ -99,7 +99,7 @@ void GatherOpHandle::RunImpl() { ...@@ -99,7 +99,7 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out_value->mutable_value(); Tensor *out_tensor = out_value->mutable_value();
// copy // copy
auto dev_ctx = dev_ctxes_[out_var_handle->place_]; auto dev_ctx = dev_ctxes_.at(out_var_handle->place_);
RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx, RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
t_out_p] { t_out_p] {
int s = 0, e = 0; int s = 0, e = 0;
......
...@@ -31,9 +31,10 @@ struct TestGatherOpHandle { ...@@ -31,9 +31,10 @@ struct TestGatherOpHandle {
std::vector<Scope*> local_scopes_; std::vector<Scope*> local_scopes_;
std::vector<Scope*> param_scopes_; std::vector<Scope*> param_scopes_;
Scope g_scope_; Scope g_scope_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase* op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase*> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<ir::Node>> nodes_;
void WaitAll() { void WaitAll() {
for (size_t j = 0; j < ctxs_.size(); ++j) { for (size_t j = 0; j < ctxs_.size(); ++j) {
...@@ -70,7 +71,7 @@ struct TestGatherOpHandle { ...@@ -70,7 +71,7 @@ struct TestGatherOpHandle {
} }
void InitGatherOp(size_t input_scope_idx) { void InitGatherOp(size_t input_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes; nodes_.clear();
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope())); local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope(); Scope& local_scope = local_scopes_.back()->NewScope();
...@@ -82,44 +83,45 @@ struct TestGatherOpHandle { ...@@ -82,44 +83,45 @@ struct TestGatherOpHandle {
} }
param_scopes_[input_scope_idx]->Var("out"); param_scopes_[input_scope_idx]->Var("out");
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_.reset( op_handle_ =
new GatherOpHandle(nodes.back().get(), local_scopes_, gpu_list_)); new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node1", ir::Node::Type::kVariable).release());
auto* in_var_handle = auto* in_var_handle =
new VarHandle(nodes.back().get(), 1, j, "input", gpu_list_[j]); new VarHandle(nodes_.back().get(), 1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle); vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
// add dummy var // add dummy var
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node2", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* in_dummy_var_handle = DummyVarHandle* in_dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
in_dummy_var_handle->ClearGeneratedOp(); in_dummy_var_handle->ClearGeneratedOp();
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node3", ir::Node::Type::kVariable).release());
auto* out_var_handle = new VarHandle(nodes.back().get(), 2, input_scope_idx, auto* out_var_handle =
"out", gpu_list_[input_scope_idx]); new VarHandle(nodes_.back().get(), 2, input_scope_idx, "out",
gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle); vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
nodes.emplace_back( nodes_.emplace_back(
ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release()); ir::CreateNodeForTest("node4", ir::Node::Type::kVariable).release());
vars_.emplace_back(new DummyVarHandle(nodes.back().get())); vars_.emplace_back(new DummyVarHandle(nodes_.back().get()));
DummyVarHandle* dummy_var_handle = DummyVarHandle* dummy_var_handle =
static_cast<DummyVarHandle*>(vars_.back().get()); static_cast<DummyVarHandle*>(vars_.back());
op_handle_->AddOutput(dummy_var_handle); op_handle_->AddOutput(dummy_var_handle);
} }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_graph_view.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsLockAndRecordEventFreeComputationOpHandle(
ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace())) return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
return false;
}
}
return true;
}
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
std::unique_ptr<ir::Graph> ir_graph) const {
auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
OpGraphView graph_view(all_ops);
for (auto &op : all_ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr) continue;
bool is_lock_and_record_event_free =
IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
if (is_lock_and_record_event_free) {
VLOG(100) << "Set is_lock_and_record_event_free be true in op "
<< compute_op->DebugString();
}
}
return ir_graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(modify_op_lock_and_record_event_pass,
paddle::framework::details::ModifyOpLockAndRecordEventPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
class ModifyOpLockAndRecordEventPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -36,20 +37,20 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair);
} }
} }
} }
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) { for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get()); insert_pending_var(var);
} }
for (auto &op : graph->Get<GraphOps>(kGraphOps)) { for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
pending_ops.insert({op.get(), op.get()->NoDupInputSize()}); pending_ops.insert({op, op->NoDupInputSize()});
} }
} }
...@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const { ...@@ -89,6 +90,4 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
REGISTER_PASS(multi_devices_check_pass, REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker) paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars) .RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars) .RequireGraphAttr(paddle::framework::details::kGraphDepVars);
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; int GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateRPCOp(
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const; const std::vector<ir::Node *> &nodes) const;
...@@ -61,14 +67,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -61,14 +67,17 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t num_places) const; size_t num_places) const;
void CreateScaleLossGradOp(ir::Graph *result, void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name) const; const std::string &loss_grad_name,
ir::Node *out_var_node) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; int dev_id) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; int GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
...@@ -78,6 +87,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -78,6 +87,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
void CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID( size_t GetAppropriateDeviceID(
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -62,7 +63,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>(kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;
......
...@@ -35,23 +35,14 @@ namespace details { ...@@ -35,23 +35,14 @@ namespace details {
// The outside vector is the device vector. Each element of this vector is a // The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name, // map from variable name to variables. The variables, who have the same name,
// will have a differsent version. The offset in the // will have a differsent version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles. // `std::vector<VarHandle*>` is the version of varaibles.
typedef std::vector< typedef std::vector<std::unordered_map<std::string, std::vector<VarHandle*>>>
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars"; const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<VarHandleBase*> GraphDepVars;
const char kGraphDepVars[] = "dep_vars"; const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_graph_view.h"
#include <queue>
#include <utility>
namespace paddle {
namespace framework {
namespace details {
OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
for (auto &op : ops) {
preceding_ops_[op];
pending_ops_[op];
for (auto &var : op->Outputs()) {
for (auto &pending_op : var->PendingOps()) {
preceding_ops_[pending_op].insert(op);
pending_ops_[op].insert(pending_op);
}
}
}
PADDLE_ENFORCE(
preceding_ops_.size() == ops.size() && pending_ops_.size() == ops.size(),
"There are duplicate ops in graph.");
}
std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
std::unordered_set<OpHandleBase *> ret;
for (auto &pair : preceding_ops_) {
ret.insert(pair.first);
}
return ret;
}
bool OpGraphView::HasOp(OpHandleBase *op) const {
return preceding_ops_.count(op) != 0;
}
void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
op == nullptr ? "nullptr" : op->DebugString());
}
const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
OpHandleBase *op) const {
EnforceHasOp(op);
return pending_ops_.at(op);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
namespace details {
class OpGraphView {
public:
explicit OpGraphView(const std::vector<OpHandleBase *> &ops);
std::unordered_set<OpHandleBase *> AllOps() const;
const std::unordered_set<OpHandleBase *> &PendingOps(OpHandleBase *op) const;
bool HasOp(OpHandleBase *op) const;
private:
void Build(const std::vector<OpHandleBase *> &ops);
void EnforceHasOp(OpHandleBase *op) const;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
preceding_ops_;
std::unordered_map<OpHandleBase *, std::unordered_set<OpHandleBase *>>
pending_ops_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -103,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() { ...@@ -103,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) { void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) { for (auto *in : inputs_) {
if (NeedWait(in)) { if (NeedWait(in)) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]); in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(place));
} }
} }
} }
......
...@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@"; ...@@ -31,7 +31,10 @@ constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
// It's responsible for populating necessary fields of ir::Node. // It's responsible for populating necessary fields of ir::Node.
class OpHandleBase { class OpHandleBase {
public: public:
explicit OpHandleBase(ir::Node *node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit OpHandleBase(ir::Node *node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~OpHandleBase(); virtual ~OpHandleBase();
...@@ -64,7 +67,8 @@ class OpHandleBase { ...@@ -64,7 +67,8 @@ class OpHandleBase {
virtual bool IsMultiDeviceTransfer() { return false; } virtual bool IsMultiDeviceTransfer() { return false; }
const platform::DeviceContext *DeviceContext(platform::Place place) { const platform::DeviceContext *DeviceContext(platform::Place place) {
return dev_ctxes_[place]; auto it = dev_ctxes_.find(place);
return it != dev_ctxes_.end() ? it->second : nullptr;
} }
void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) { void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
......
...@@ -27,7 +27,7 @@ namespace framework { ...@@ -27,7 +27,7 @@ namespace framework {
namespace details { namespace details {
void ReduceOpHandle::RunImpl() { void ReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
if (places_.size() == 1) return; if (places_.size() == 1) return;
// the input and output may have dummy var. // the input and output may have dummy var.
......
...@@ -46,7 +46,8 @@ struct ReduceOpHandle : public OpHandleBase { ...@@ -46,7 +46,8 @@ struct ReduceOpHandle : public OpHandleBase {
nccl_ctxs_(nccl_ctxs) { nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) { if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) { for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get(); this->SetDeviceContext(platform::CUDAPlace(p_ctx.first),
p_ctx.second.ctx_.get());
} }
} }
} }
......
...@@ -30,8 +30,8 @@ struct TestReduceOpHandle { ...@@ -30,8 +30,8 @@ struct TestReduceOpHandle {
Scope g_scope_; Scope g_scope_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<Scope *> param_scopes_; std::vector<Scope *> param_scopes_;
std::unique_ptr<OpHandleBase> op_handle_; OpHandleBase *op_handle_;
std::vector<std::unique_ptr<VarHandleBase>> vars_; std::vector<VarHandleBase *> vars_;
std::vector<p::Place> gpu_list_; std::vector<p::Place> gpu_list_;
std::vector<std::unique_ptr<p::DeviceContext>> ctxs_; std::vector<std::unique_ptr<p::DeviceContext>> ctxs_;
......
...@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
dev_ctx_ = static_cast<platform::CUDADeviceContext *>( dev_ctx_ = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
PADDLE_ENFORCE(cudaSetDevice(place.device)); platform::SetDeviceId(place.device);
PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
} }
...@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase { ...@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
~ReferenceCountOpHandle() { ~ReferenceCountOpHandle() {
if (IsStreamGarabageCollector()) { if (IsStreamGarabageCollector()) {
auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace()); auto gpu_place = boost::get<platform::CUDAPlace>(dev_ctx_->GetPlace());
PADDLE_ENFORCE(cudaSetDevice(gpu_place.device)); platform::SetDeviceId(gpu_place.device);
PADDLE_ENFORCE(cudaEventDestroy(event_)); PADDLE_ENFORCE(cudaEventDestroy(event_));
} }
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h" #include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -43,6 +44,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { ...@@ -43,6 +44,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
return nullptr; return nullptr;
} }
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
ir::Graph *graph) {
auto it = std::find_if(
in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) {
return dynamic_cast<DummyVarHandle *>(var) != nullptr;
});
if (it != in->Outputs().end()) {
out->AddInput(*it);
} else {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
in->AddOutput(dep_var);
out->AddInput(dep_var);
}
}
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount); auto &ref_cnts = Get<DeviceReferenceCountMap>(kGlobalReferenceCount);
...@@ -54,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -54,14 +72,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables // Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops // in computation ops
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
std::unordered_map<OpHandleBase *, std::unique_ptr<ReferenceCountOpHandle>> std::unordered_map<OpHandleBase *, ReferenceCountOpHandle *>
compute_ref_cnt_map; compute_ref_cnt_map;
auto get_ref_cnts_from_compute_op = [&]( auto get_ref_cnts_from_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) {
std::vector<std::string> var_names_in_op; std::vector<std::string> var_names_in_op;
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op == nullptr || if (compute_op == nullptr ||
!platform::is_gpu_place(compute_op->GetPlace())) !platform::is_gpu_place(compute_op->GetPlace()))
return var_names_in_op; return var_names_in_op;
...@@ -104,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -104,9 +121,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
}; };
auto update_ref_cnts_from_non_compute_op = [&]( auto update_ref_cnts_from_non_compute_op = [&](
const std::unique_ptr<OpHandleBase> &op, OpHandleBase *op, const std::vector<VarHandleBase *> &vars) {
const std::vector<VarHandleBase *> &vars) { if (dynamic_cast<ComputationOpHandle *>(op) != nullptr) return;
if (dynamic_cast<ComputationOpHandle *>(op.get()) != nullptr) return;
for (VarHandleBase *var_handle_base : vars) { for (VarHandleBase *var_handle_base : vars) {
auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base); auto *var_handle = dynamic_cast<VarHandle *>(var_handle_base);
if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue;
...@@ -124,8 +140,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -124,8 +140,8 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if (next_compute_op != nullptr) { if (next_compute_op != nullptr) {
if (compute_ref_cnt_map.count(next_compute_op)) { if (compute_ref_cnt_map.count(next_compute_op)) {
compute_ref_cnt_map[next_compute_op]->AddVar(var_name); compute_ref_cnt_map[next_compute_op]->AddVar(var_name);
VLOG(5) << "Add reference count of " << var_name << " to Operator " VLOG(50) << "Add reference count of " << var_name << " to Operator "
<< next_compute_op->Name(); << next_compute_op->Name();
} else { } else {
// Create new reference_count_op_handle // Create new reference_count_op_handle
ir::Node *ref_cnt_node = graph->CreateEmptyNode( ir::Node *ref_cnt_node = graph->CreateEmptyNode(
...@@ -133,40 +149,30 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -133,40 +149,30 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, ref_cnt_node, next_compute_op->GetScope(), place, {var_name},
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (next_compute_op->Outputs().empty()) { AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); compute_ref_cnt_map[next_compute_op] = ref_cnt_handle;
next_compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(next_compute_op->Outputs().front());
compute_ref_cnt_map[next_compute_op].reset(ref_cnt_handle);
} }
} }
} }
} }
}; };
auto &all_ops = graph->Get<GraphOps>(kGraphOps); auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
for (auto &op : all_ops) { for (auto &op : all_ops) {
auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs());
auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs());
if (in_var_names.empty() && out_var_names.empty()) continue; if (in_var_names.empty() && out_var_names.empty()) continue;
in_var_names.insert(in_var_names.end(), out_var_names.begin(), in_var_names.insert(in_var_names.end(), out_var_names.begin(),
out_var_names.end()); out_var_names.end());
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op.get()); auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace()); auto place = boost::get<platform::CUDAPlace>(compute_op->GetPlace());
ir::Node *ref_cnt_node = ir::Node *ref_cnt_node =
graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation);
auto *ref_cnt_handle = new ReferenceCountOpHandle( auto *ref_cnt_handle = new ReferenceCountOpHandle(
ref_cnt_node, compute_op->GetScope(), place, in_var_names, ref_cnt_node, compute_op->GetScope(), place, in_var_names,
gcs[place.device].get(), cur_ref_cnts[place.device].get()); gcs[place.device].get(), cur_ref_cnts[place.device].get());
if (compute_op->Outputs().empty()) { AddDependencyBetween(compute_op, ref_cnt_handle, graph.get());
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); compute_ref_cnt_map[compute_op] = ref_cnt_handle;
compute_op->AddOutput(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
ref_cnt_handle->AddInput(compute_op->Outputs().front());
compute_ref_cnt_map[compute_op].reset(ref_cnt_handle);
} }
for (auto &op : all_ops) { for (auto &op : all_ops) {
...@@ -174,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -174,11 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
update_ref_cnts_from_non_compute_op(op, op->Outputs()); update_ref_cnts_from_non_compute_op(op, op->Outputs());
} }
std::vector<std::unique_ptr<OpHandleBase>> new_all_ops; std::vector<OpHandleBase *> new_all_ops;
new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size());
for (auto &op : all_ops) { for (auto &op : all_ops) {
new_all_ops.emplace_back(std::move(op)); new_all_ops.emplace_back(std::move(op));
auto it = compute_ref_cnt_map.find(new_all_ops.back().get()); auto it = compute_ref_cnt_map.find(new_all_ops.back());
if (it != compute_ref_cnt_map.end()) { if (it != compute_ref_cnt_map.end()) {
// Add LeafNode to ReferenceCountOpHandle // Add LeafNode to ReferenceCountOpHandle
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
......
...@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, ...@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
place_(place) {} place_(place) {}
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString() if (ir::IsControlDepVar(*in->Node())) {
if (ir::IsControlDepVar(*in->Node())) { // HACK
continue; continue;
} }
if (in->GeneratedOp()) { if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]); in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
} }
} }
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); this->RunAndRecordEvent([this] {
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
// lock. place_);
op_->Run(*tmp_scope, place_); });
} }
std::string RPCOpHandle::Name() const { return name_; } std::string RPCOpHandle::Name() const { return name_; }
......
...@@ -27,7 +27,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, ...@@ -27,7 +27,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
coeff_(static_cast<float>(1.0 / num_dev)), coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope), scope_(scope),
place_(place) { place_(place) {
dev_ctxes_[place_] = dev_ctx; this->SetDeviceContext(place_, dev_ctx);
} }
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
...@@ -46,12 +46,12 @@ void ScaleLossGradOpHandle::RunImpl() { ...@@ -46,12 +46,12 @@ void ScaleLossGradOpHandle::RunImpl() {
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
auto stream = auto stream = static_cast<platform::CUDADeviceContext *>(
static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_]) this->dev_ctxes_.at(place_))
->stream(); ->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp, memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream); platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(10) << place_ << "RUN Scale loss grad op"; VLOG(100) << place_ << "RUN Scale loss grad op";
}); });
#endif #endif
} }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
// FIXME(zjl): Insert dependencies between some distributed ops may cause
// the multi_devices_graph_pass fails. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, which may cause deadlock easily.
// We should add more skipped distributed ops when found errors in
// multi_devices_graph_pass
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(100) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -19,14 +19,16 @@ namespace framework { ...@@ -19,14 +19,16 @@ namespace framework {
namespace details { namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops) {
if (fetch_ops->empty()) return; if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) { for (auto& op : *fetch_ops) {
for (auto& out_var : op->Node()->outputs) { for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var); graph->RemoveNode(out_var);
} }
for (auto& in_var : op->Inputs()) {
in_var->RemoveOutput(op, op->Node());
}
graph->RemoveNode(op->Node()); graph->RemoveNode(op->Node());
} }
fetch_ops->clear(); fetch_ops->clear();
......
...@@ -38,8 +38,7 @@ class SSAGraphExecutor { ...@@ -38,8 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
}; };
void ClearFetchOp(ir::Graph* graph, void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
std::vector<std::unique_ptr<FetchOpHandle>>* fetch_ops);
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -39,7 +40,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -39,7 +40,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr)); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare", nullptr));
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_set<VarHandleBase *> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars; auto ready_vars = std::make_shared<BlockingQueue<VarHandleBase *>>();
std::unordered_set<OpHandleBase *> ready_ops; std::unordered_set<OpHandleBase *> ready_ops;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule // streams from multiple GPUs, it's faster to buffer them and schedule
...@@ -51,34 +52,34 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -51,34 +52,34 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); InsertPendingVar(&pending_vars, ready_vars.get(), version_pair);
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get()); InsertPendingVar(&pending_vars, ready_vars.get(), var);
} }
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op);
} else { } else {
InsertPendingOp(&pending_ops, op.get()); InsertPendingOp(&pending_ops, op);
} }
} }
// Step 2. Insert FetchOps // Step 2. Insert FetchOps
std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops; std::vector<FetchOpHandle *> fetch_ops;
std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies; std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size()); FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops, InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
&pending_vars, &ready_vars, &fetch_data); &pending_vars, ready_vars.get(), &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) { auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) { for (auto *op : set) {
running_ops_++; running_ops_++;
RunOp(&ready_vars, op); RunOp(ready_vars, op);
} }
set.clear(); set.clear();
}; };
...@@ -87,7 +88,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -87,7 +88,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
run_op_futures_.clear(); run_op_futures_.clear();
exception_holder_.Clear(); exception_holder_.Clear();
event.reset(nullptr); event.reset(nullptr);
// Step 3. Execution // Step 3. Execution
while (!pending_vars.empty()) { while (!pending_vars.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
...@@ -103,13 +103,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -103,13 +103,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// 2. Find ready variable // 2. Find ready variable
bool timeout; bool timeout;
auto cur_ready_vars = ready_vars.PopAll(1, &timeout); auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
ClearFetchOp(graph_.get(), &fetch_ops);
exception_holder_.ReThrow(); exception_holder_.ReThrow();
} else { } else {
continue; continue;
...@@ -133,7 +134,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -133,7 +134,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
} }
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps. // Wait FetchOps.
ClearFetchOp(graph_.get(), &fetch_ops); ClearFetchOp(graph_.get(), &fetch_ops);
...@@ -142,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -142,8 +142,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_set<VarHandleBase *> *pending_vars, std::unordered_set<VarHandleBase *> *pending_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) { BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
...@@ -153,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -153,7 +153,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
} }
} }
} }
...@@ -206,17 +206,20 @@ void ThreadedSSAGraphExecutor::InsertPendingVar( ...@@ -206,17 +206,20 @@ void ThreadedSSAGraphExecutor::InsertPendingVar(
} }
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(100)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); VLOG(100) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
} }
op->Run(strategy_.use_cuda_); VLOG(100) << op << " " << op->Name() << " Done ";
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
ready_var_q->Extend(op->Outputs()); ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted"; VLOG(100) << op << " " << op->Name() << "Signal posted";
} catch (...) { } catch (...) {
exception_holder_.Catch(std::current_exception()); exception_holder_.Catch(std::current_exception());
} }
......
...@@ -48,10 +48,10 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -48,10 +48,10 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
// Use topological sort algorithm // Use topological sort algorithm
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override; FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
~ThreadedSSAGraphExecutor() {} ~ThreadedSSAGraphExecutor() final = default;
private: private:
void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q, void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -70,13 +70,13 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
BlockingQueue<VarHandleBase *> *ready_vars, BlockingQueue<VarHandleBase *> *ready_vars,
VarHandleBase *var) const; VarHandleBase *var) const;
void InsertFetchOps( void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
const std::vector<std::string> &fetch_tensors, std::vector<FetchOpHandle *> *fetch_ops,
std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops, std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies, std::unordered_map<OpHandleBase *, size_t> *pending_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops, std::unordered_set<VarHandleBase *> *pending_vars,
std::unordered_set<VarHandleBase *> *pending_vars, BlockingQueue<VarHandleBase *> *ready_vars,
BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data); FeedFetchList *fetch_data);
private: private:
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
......
...@@ -20,6 +20,8 @@ namespace details { ...@@ -20,6 +20,8 @@ namespace details {
VarHandleBase::~VarHandleBase() {} VarHandleBase::~VarHandleBase() {}
VarHandle::~VarHandle() { VLOG(4) << "deleting var handle " << DebugString(); }
std::string VarHandle::DebugString() const { std::string VarHandle::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << name_ << ":" << place_; ss << name_ << ":" << place_;
...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const { ...@@ -27,6 +29,10 @@ std::string VarHandle::DebugString() const {
} }
std::string DummyVarHandle::DebugString() const { return node_->Name(); } std::string DummyVarHandle::DebugString() const { return node_->Name(); }
DummyVarHandle::~DummyVarHandle() {
VLOG(4) << "deleting dummy var handle " << DebugString();
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -35,7 +35,10 @@ class OpHandleBase; ...@@ -35,7 +35,10 @@ class OpHandleBase;
// A variable can only be generated by a single operator. i.e. // A variable can only be generated by a single operator. i.e.
// This is a single assignment graph. // This is a single assignment graph.
struct VarHandleBase { struct VarHandleBase {
explicit VarHandleBase(ir::Node* node) : node_(node) {} // Owned by `node`. No need to be deleted explicitly.
explicit VarHandleBase(ir::Node* node) : node_(node) {
node_->WrappedBy(this);
}
virtual ~VarHandleBase(); virtual ~VarHandleBase();
...@@ -49,6 +52,8 @@ struct VarHandleBase { ...@@ -49,6 +52,8 @@ struct VarHandleBase {
void AddOutput(OpHandleBase* out, ir::Node* node) { void AddOutput(OpHandleBase* out, ir::Node* node) {
if (pending_ops_.find(out) == pending_ops_.end()) { if (pending_ops_.find(out) == pending_ops_.end()) {
PADDLE_ENFORCE(out != nullptr, "The output of %s should not be nullptr",
this->Node()->Name());
pending_ops_.insert(out); pending_ops_.insert(out);
node_->outputs.push_back(node); node_->outputs.push_back(node);
} }
...@@ -92,6 +97,8 @@ struct VarHandleBase { ...@@ -92,6 +97,8 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
explicit VarHandle(ir::Node* node) : VarHandleBase(node) {} explicit VarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~VarHandle();
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(ir::Node* node, size_t version, size_t scope_index, VarHandle(ir::Node* node, size_t version, size_t scope_index,
...@@ -119,6 +126,8 @@ struct VarHandle : public VarHandleBase { ...@@ -119,6 +126,8 @@ struct VarHandle : public VarHandleBase {
struct DummyVarHandle : public VarHandleBase { struct DummyVarHandle : public VarHandleBase {
explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {} explicit DummyVarHandle(ir::Node* node) : VarHandleBase(node) {}
virtual ~DummyVarHandle();
std::string DebugString() const override; std::string DebugString() const override;
}; };
......
...@@ -43,15 +43,52 @@ ExecutorPrepareContext::ExecutorPrepareContext( ...@@ -43,15 +43,52 @@ ExecutorPrepareContext::ExecutorPrepareContext(
} }
ExecutorPrepareContext::~ExecutorPrepareContext() { ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext"; VLOG(50) << "destroy ExecutorPrepareContext";
}
template <typename RefCntMap>
static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
GarbageCollector<Tensor>* gc,
RefCntMap* ref_cnts) {
std::unordered_set<Tensor*> erase_tensors;
auto handler = [&](const VariableNameMap& name_map) {
for (auto& name_pair : name_map) {
for (auto& name : name_pair.second) {
auto it = ref_cnts->find(name);
if (it == ref_cnts->end()) continue;
if ((it->second)-- == 1) {
auto* var = scope.FindVar(name);
if (var != nullptr) {
VLOG(100) << "Erase tensor \'" << name << "\'";
if (var->IsType<LoDTensor>()) {
erase_tensors.insert(var->GetMutable<LoDTensor>());
} else if (var->IsType<SelectedRows>()) {
erase_tensors.insert(
var->GetMutable<SelectedRows>()->mutable_value());
}
}
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
if (!erase_tensors.empty()) {
gc->Add(erase_tensors);
}
} }
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
void Executor::Close() { void Executor::Close() {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
// TODO(typhoonzero): complete message will need to use real trainer_id,
// except 0.
::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>() ::paddle::operators::distributed::GRPCClient>(0)
->SendComplete(); ->SendComplete();
#endif #endif
} }
...@@ -66,7 +103,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) { ...@@ -66,7 +103,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
} else if (var_type == proto::VarType::FETCH_LIST) { } else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>(); var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) { } else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<framework::Scope>>(); var->GetMutable<std::vector<framework::Scope*>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) { } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
...@@ -104,21 +141,21 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, ...@@ -104,21 +141,21 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
if (var->Persistable()) { if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name()); auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name() VLOG(30) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr; << " global, which pointer is " << ptr;
} else { } else {
auto* ptr = scope->Var(var->Name()); auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name() VLOG(30) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr; << " locally, which pointer is " << ptr;
} }
} }
} else { } else {
for (auto& var : global_block.AllVars()) { for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name()); auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is " VLOG(30) << "Create variable " << var->Name() << ", which pointer is "
<< ptr; << ptr;
} }
} }
} }
...@@ -249,7 +286,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -249,7 +286,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
int i = 0; int i = 0;
for (auto& feed_target : (*feed_targets)) { for (auto& feed_target : (*feed_targets)) {
std::string var_name = feed_target.first; std::string var_name = feed_target.first;
VLOG(3) << "feed target's name: " << var_name; VLOG(30) << "feed target's name: " << var_name;
// prepend feed op // prepend feed op
auto* op = global_block->PrependOp(); auto* op = global_block->PrependOp();
...@@ -272,7 +309,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -272,7 +309,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
int i = 0; int i = 0;
for (auto& fetch_target : (*fetch_targets)) { for (auto& fetch_target : (*fetch_targets)) {
std::string var_name = fetch_target.first; std::string var_name = fetch_target.first;
VLOG(3) << "fetch target's name: " << var_name; VLOG(30) << "fetch target's name: " << var_name;
// append fetch op // append fetch op
auto* op = global_block->AppendOp(); auto* op = global_block->AppendOp();
...@@ -331,9 +368,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -331,9 +368,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
int64_t max_memory_size = GetEagerDeletionThreshold(); int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector<Tensor>> gc; std::unique_ptr<GarbageCollector<Tensor>> gc;
if (max_memory_size >= 0) { // WhileOp would set keep_kids to false
// WhileGradOp would need the scopes created in WhileOp
// Perhaps, we should not perform eager deletion in WhileOp
// The scopes and variables created by WhileOp would be deleted
// in WhileGradOp.
if (max_memory_size >= 0 && !keep_kids) {
ctx->ResetReferenceCount(); ctx->ResetReferenceCount();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_)) {
...@@ -352,50 +393,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -352,50 +393,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc != nullptr) { if (gc != nullptr) {
std::vector<std::string> erase_vars; DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
for (auto& input : op->Inputs()) { &(ctx->cur_ref_cnts_));
for (auto& input_name : input.second) {
auto it = ctx->cur_ref_cnts_.find(input_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) { // should delete it
erase_vars.emplace_back(input_name);
ctx->cur_ref_cnts_.erase(input_name);
} else {
--(it->second);
}
}
}
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
auto it = ctx->cur_ref_cnts_.find(output_name);
if (it == ctx->cur_ref_cnts_.end()) continue;
if (it->second == 1) {
erase_vars.emplace_back(output_name);
ctx->cur_ref_cnts_.erase(output_name);
} else {
--(it->second);
}
}
}
if (!erase_vars.empty()) {
std::vector<framework::LoDTensor*> erase_tensors;
for (auto& name : erase_vars) {
auto* var = local_scope->FindVar(name);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
erase_tensors.push_back(tensor);
}
}
if (!erase_tensors.empty()) gc->Add(erase_tensors);
}
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(20) << "Memory used after operator " + op->Type() + " running: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
} }
} }
...@@ -420,10 +424,10 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -420,10 +424,10 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "-------------------------------------------------------"; VLOG(20) << "-------------------------------------------------------";
VLOG(2) << "Memory used after deleting local scope: " VLOG(20) << "Memory used after deleting local scope: "
<< memory::memory_usage(place_); << memory::memory_usage(place_);
VLOG(2) << "-------------------------------------------------------"; VLOG(20) << "-------------------------------------------------------";
} }
} }
...@@ -467,7 +471,7 @@ void Executor::RunPreparedContext( ...@@ -467,7 +471,7 @@ void Executor::RunPreparedContext(
void Executor::EnableMKLDNN(const ProgramDesc& program) { void Executor::EnableMKLDNN(const ProgramDesc& program) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
VLOG(3) << "use_mkldnn=True"; VLOG(30) << "use_mkldnn=True";
for (size_t bid = 0; bid < program.Size(); ++bid) { for (size_t bid = 0; bid < program.Size(); ++bid) {
auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid); auto* block = const_cast<ProgramDesc&>(program).MutableBlock(bid);
for (auto* op : block->AllOps()) { for (auto* op : block->AllOps()) {
......
...@@ -32,38 +32,32 @@ template <typename T> ...@@ -32,38 +32,32 @@ template <typename T>
std::unordered_map<std::string, T> GetNonPersistableReferenceCount( std::unordered_map<std::string, T> GetNonPersistableReferenceCount(
const ProgramDesc& prog, size_t block_id) { const ProgramDesc& prog, size_t block_id) {
auto& block = prog.Block(block_id); auto& block = prog.Block(block_id);
std::unordered_set<std::string> ignored_vars;
std::unordered_map<std::string, T> ref_cnts; std::unordered_map<std::string, T> ref_cnts;
for (auto var_desc : block.AllVars()) { auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) {
auto type = var_desc->Proto()->type().type(); for (auto& name_pair : name_map) {
if (type != proto::VarType::LOD_TENSOR || var_desc->Persistable()) { for (auto& name : name_pair.second) {
ignored_vars.insert(var_desc->Name()); // ignore persistable vars auto* var_desc = block.FindVar(name);
} if (var_desc == nullptr || var_desc->Persistable()) continue;
} auto type = var_desc->Proto()->type().type();
if (type != proto::VarType::LOD_TENSOR &&
for (auto op_desc : block.AllOps()) { type != proto::VarType::SELECTED_ROWS) {
for (auto& input : op_desc->Inputs()) { continue;
for (auto& input_name : input.second) {
if (!ignored_vars.count(input_name)) {
if (ref_cnts.count(input_name))
++ref_cnts[input_name];
else
ref_cnts[input_name] = 1;
} }
}
}
for (auto& output : op_desc->Outputs()) { auto it = ref_cnts.find(name);
for (auto output_name : output.second) { if (it != ref_cnts.end()) {
if (!ignored_vars.count(output_name)) { ++it->second;
if (ref_cnts.count(output_name)) } else {
++ref_cnts[output_name]; ref_cnts[name] = 1;
else
ref_cnts[output_name] = 1;
} }
} }
} }
};
for (auto op_desc : block.AllOps()) {
update_ref_cnts(op_desc, op_desc->Inputs());
update_ref_cnts(op_desc, op_desc->Outputs());
} }
return ref_cnts; return ref_cnts;
} }
......
...@@ -25,10 +25,9 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input, ...@@ -25,10 +25,9 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
const std::string& var_name, size_t index) { const std::string& var_name, size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will // If var_name Variable is not found in GlobalScope, a new variable will
// be created. // be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; VLOG(30) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name); Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs = auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
if (index >= feed_inputs.size()) { if (index >= feed_inputs.size()) {
feed_inputs.resize(index + 1); feed_inputs.resize(index + 1);
} }
...@@ -48,8 +47,8 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, ...@@ -48,8 +47,8 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
typeid(FeedFetchList).name()); typeid(FeedFetchList).name());
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>(); auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
auto& tensor = fetch_outputs[index]; auto& tensor = fetch_outputs[index];
VLOG(3) << "Fetch " << var_name << " with index " << index VLOG(30) << "Fetch " << var_name << " with index " << index
<< " shape= " << tensor.dims(); << " shape= " << tensor.dims();
PADDLE_ENFORCE_LT(index, fetch_outputs.size()); PADDLE_ENFORCE_LT(index, fetch_outputs.size());
return tensor; return tensor;
} }
......
...@@ -35,6 +35,7 @@ enum AttrType { ...@@ -35,6 +35,7 @@ enum AttrType {
BLOCK = 8; BLOCK = 8;
LONG = 9; LONG = 9;
BLOCKS = 10; BLOCKS = 10;
LONGS = 11;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -55,6 +56,7 @@ message OpDesc { ...@@ -55,6 +56,7 @@ message OpDesc {
optional int32 block_idx = 12; optional int32 block_idx = 12;
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14; repeated int32 blocks_idx = 14;
repeated int64 longs = 15;
}; };
message Var { message Var {
...@@ -80,7 +82,6 @@ message OpProto { ...@@ -80,7 +82,6 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ]; optional bool dispensable = 5 [ default = false ];
optional string reuse = 6;
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
...@@ -10,7 +10,7 @@ function(pass_library TARGET DEST) ...@@ -10,7 +10,7 @@ function(pass_library TARGET DEST)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${op_library_DEPS}) cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically. # add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
message(STATUS "add pass ${TARGET} ${DEST}") message(STATUS "add pass ${TARGET} ${DEST}")
...@@ -25,19 +25,27 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph) ...@@ -25,19 +25,27 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
if (WITH_MKLDNN)
pass_library(conv_relu_mkldnn_fuse_pass inference)
endif ()
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference) pass_library(infer_clean_graph_pass inference)
pass_library(fc_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference) pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(depthwise_conv_mkldnn_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference)
pass_library(conv_relu_mkldnn_fuse_pass inference)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
...@@ -45,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -45,6 +53,7 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
...@@ -52,5 +61,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph ...@@ -52,5 +61,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if (WITH_MKLDNN) if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
endif () endif ()
...@@ -147,19 +147,19 @@ void PrepareParameters(Graph* graph, const Param& param) { ...@@ -147,19 +147,19 @@ void PrepareParameters(Graph* graph, const Param& param) {
scope->Var(param.LSTMX)->GetMutable<LoDTensor>(); scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>(); scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \ #define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \ auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \ auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \ auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \ CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \ VLOG(40) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_w1" \ VLOG(40) << #name__ "_w1" \
<< " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_b0" \ VLOG(40) << #name__ "_b0" \
<< " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \ << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \ auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \
auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \ auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \
auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>(); auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();
GATE_W(forget); GATE_W(forget);
...@@ -208,7 +208,7 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0, ...@@ -208,7 +208,7 @@ void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
int D = W_forget_w0.dims()[0]; int D = W_forget_w0.dims()[0];
int M = W_forget_w1.dims()[0]; int M = W_forget_w1.dims()[0];
out->Resize(make_ddim({D + M, 4 * D})); out->Resize(make_ddim({D + M, 4 * D}));
VLOG(3) << "LSTMWeight resized to " << out->dims(); VLOG(30) << "LSTMWeight resized to " << out->dims();
float* out_data = out->mutable_data<float>(platform::CPUPlace()); float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors( std::array<const float*, 4> tensors(
...@@ -262,7 +262,7 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl( ...@@ -262,7 +262,7 @@ std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unordered_set<std::string> specified_vars({"data_lod_attention", std::unordered_set<std::string> specified_vars({"data_lod_attention",
"cell_init", "hidden_init", "cell_init", "hidden_init",
"data", "week", "minute"}); "data", "week", "minute"});
int count = 0; size_t count = 0;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (node->IsVar() && specified_vars.count(node->Name())) { if (node->IsVar() && specified_vars.count(node->Name())) {
++count; ++count;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
template <typename BinaryOperation>
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
BinaryOperation f) {
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
LoDTensor vec_y;
vec_y.Resize(vec_a.dims());
const float* a = vec_a.data<float>();
const float* b = vec_b.data<float>();
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < vec_a.numel(); i++) {
y[i] = f(a[i], b[i]);
}
return vec_y;
}
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input);
int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(40) << "handle ConvBias fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_bias_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_bias_pattern); // tmp
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_bias_pattern); // CONV op
// bias
GET_IR_NODE_FROM_SUBGRAPH(eltwise_bias, eltwise_bias, conv_bias_pattern);
// output
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bias_pattern);
// elementwise_add op
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bias_pattern);
PADDLE_ENFORCE(subgraph.count(conv_input));
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
VLOG(30) << "do not perform conv+bias fuse";
return;
}
auto* eltwise_bias_tensor =
scope->FindVar(eltwise_bias->Name())->GetMutable<LoDTensor>();
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
auto conv_bias_names = conv->Op()->Input("Bias");
// add eltwise bias to existing conv bias
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), eltwise_bias_tensor->dims());
*conv_bias_tensor = tensor_apply_eltwise(
*conv_bias_tensor, *eltwise_bias_tensor, std::plus<float>());
conv->Op()->SetOutput("Output",
std::vector<std::string>({eltwise_out->Name()}));
GraphSafeRemoveNodes(graph.get(), {eltwise, conv_out});
IR_NODE_LINK_TO(conv, eltwise_out);
} else {
// take eltwise bias as conv bias
OpDesc desc;
desc.SetInput(
"Input", std::vector<std::string>({subgraph.at(conv_input)->Name()}));
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d");
for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
}
auto conv_bias_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_bias_node);
IR_NODE_LINK_TO(conv_weight, conv_bias_node);
IR_NODE_LINK_TO(eltwise_bias, conv_bias_node);
IR_NODE_LINK_TO(conv_bias_node, eltwise_out);
GraphSafeRemoveNodes(graph.get(), {conv, eltwise, conv_out});
}
found_conv_bias_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_bias_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
paddle::framework::ir::ConvBiasFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and Elementwise_add to a ConvBiasOp.
*/
class ConvBiasFusePass : public FusePassBase {
public:
virtual ~ConvBiasFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_CONV_BN_NODES(pattern_name) \
/* OPERATORS */ \
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name); \
/* CONV inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \
/* CONV outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \
/* BN inputs */ \
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name); \
/* BN outputs */ \
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */ \
GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
void recompute_bias_and_weights(const Scope* scope,
ir::Node* conv_weight, //
const ir::Node& bn_scale, //
const LoDTensor& bn_bias_tensor, //
const ir::Node& bn_mean, //
const ir::Node& bn_variance, //
LoDTensor* eltwise_y_in_tensor, //
float epsilon) {
using EigenVectorArrayMap =
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
using EigenMatrixArrayMap = Eigen::Map<
Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Re-compute bias of conv2d from BN
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
auto* variance_tensor =
scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();
ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
scale_tensor->numel(), 1);
EigenVectorArrayMap variance_array(
variance_tensor->mutable_data<float>(platform::CPUPlace()),
variance_tensor->numel(), 1);
ConstEigenVectorArrayMap mean_array(mean_tensor->data<float>(),
mean_tensor->numel(), 1);
ConstEigenVectorArrayMap bn_bias_array(bn_bias_tensor.data<float>(),
bn_bias_tensor.numel(), 1);
// variance will not be used anymore, so make it std_array and then tmp_array
variance_array += epsilon;
variance_array = variance_array.sqrt();
variance_array = scale_array / variance_array;
EigenVectorArrayMap eltwise_y_in_array(
eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 1);
eltwise_y_in_array =
((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
// Re-compute weight of conv2d from BN
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= variance_array;
}
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
conv_bn_pattern(conv_input, false /*with_eltwise_add*/);
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(40) << "handle ConvBN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
// bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
// bn_saved_variance
GET_CONV_BN_NODES(conv_bn_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
if (fuse_option == DO_NOT_FUSE) {
VLOG(30) << "do not perform conv+bn fuse";
return;
}
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
eltwise_y_in_desc.SetPersistable(true);
auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
auto* eltwise_y_in_tensor =
scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();
// Get batch norm bias
auto* bn_bias_tensor =
scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();
// Initialize eltwise_y
eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
eltwise_y_in_tensor->numel(), 0.0f);
// update weights and biases
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon);
// with MKL-DNN fuse conv+bn into conv with bias
// without MKL-DNN fuse conv+bn into conv+elementwise_add
if (fuse_option == FUSE_MKLDNN) {
auto input_names = conv->Op()->InputNames();
bool has_bias = std::find(input_names.begin(), input_names.end(),
"Bias") != input_names.end();
if (has_bias && conv->Op()->Input("Bias").size() > 0) {
// reuse existing conv bias node
auto conv_bias_names = conv->Op()->Input("Bias");
PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1);
auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
eltwise_y_in_tensor->dims());
auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
} else {
// add new conv_bias node
conv->Op()->SetInput(
"Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
IR_NODE_LINK_TO(eltwise_y_in_node, conv);
}
conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv, bn_out);
found_conv_bn_count++;
} else { // fuse_option == FUSE_NATIVE
// create an elementwise add node.
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance});
IR_NODE_LINK_TO(conv_out, eltwise_op);
IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
IR_NODE_LINK_TO(eltwise_op, bn_out);
found_conv_bn_count++;
}
};
gpd(graph.get(), handler);
AddStatis(found_conv_bn_count);
return graph;
}
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init(name_scope_, graph.get());
auto* scope = param_scope();
PADDLE_ENFORCE(scope);
GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
conv_bn_pattern(conv_input, true /*with_eltwise_add*/);
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(40) << "handle ConvBN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
// bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
GET_CONV_BN_NODES(conv_bn_pattern);
// OPERATORS
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
// BIAS inputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
// BIAS outputs
GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);
// Get eltwise_y (conv bias) variable
auto* eltwise_y_in_tensor =
scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();
// Get batch norm bias
auto* bn_bias_tensor =
scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();
// update weights and biases
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
*bn_mean, *bn_variance, eltwise_y_in_tensor,
epsilon);
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
GraphSafeRemoveNodes(
graph.get(),
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});
IR_NODE_LINK_TO(eltwise, bn_out);
found_conv_bn_count++;
};
gpd(graph.get(), handler);
AddStatis(found_conv_bn_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvEltwiseAddBNFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp.
*/
class ConvBNFusePass : public FusePassBase {
public:
virtual ~ConvBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bn_fuse"};
};
class ConvEltwiseAddBNFusePass : public FusePassBase {
public:
virtual ~ConvEltwiseAddBNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include <utility>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
namespace ir {
namespace {
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto from_in_inputs =
std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type;
std::for_each(std::begin(inputs), std::end(inputs),
[from, to, &node](const input_type& i) -> void {
auto param_names = i.second;
auto pi = std::find(std::begin(param_names),
std::end(param_names), from->Name());
if (pi != std::end(param_names)) {
node.Op()->SetInput(i.first, {to->Name()});
}
});
}
}
}
} // namespace
using graph_ptr = std::unique_ptr<ir::Graph>;
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_};
auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
elementwise_add_pattern(conv_output);
conv_output->AsIntermediate();
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
auto bias_input_names = conv_op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias");
if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty();
if (has_bias) {
auto conv_bias_names = bias_it->second;
auto conv_bias_names_it =
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs),
[&conv_bias_names](Node* n) -> bool {
return n->Name() == conv_bias_names[0];
});
return std::make_pair(has_bias, *conv_bias_names_it);
}
}
return std::make_pair(false, nullptr);
};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = g->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
public:
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"residual_connections_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl( ...@@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
int found_conv_relu_count = 0; int found_conv_relu_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle ConvReLU fuse"; VLOG(40) << "handle ConvReLU fuse";
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_relu_pattern); // Filter conv_relu_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
...@@ -46,6 +46,12 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl( ...@@ -46,6 +46,12 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
FuseOptions fuse_option = FindFuseOption(*conv, *relu);
if (fuse_option == DO_NOT_FUSE) {
VLOG(30) << "do not perform conv+relu fuse";
return;
}
// Transform Conv node into ConvReLU node. // Transform Conv node into ConvReLU node.
OpDesc* desc = conv->Op(); OpDesc* desc = conv->Op();
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()})); desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
......
...@@ -31,7 +31,8 @@ class ConvReLUFusePass : public FusePassBase { ...@@ -31,7 +31,8 @@ class ConvReLUFusePass : public FusePassBase {
virtual ~ConvReLUFusePass() {} virtual ~ConvReLUFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
}; };
} // namespace ir } // namespace ir
......
...@@ -15,25 +15,30 @@ ...@@ -15,25 +15,30 @@
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) { const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
if (type == "conv2d") { if (type == "conv2d") {
op->SetAttr("use_mkldnn", true); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
} else if (type == "relu") { } else if (type == "relu") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs); op->SetInput("X", inputs);
} }
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
} }
// a->OP0->b // a->OP0->b
...@@ -43,7 +48,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -43,7 +48,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
ProgramDesc BuildProgramDesc() { ProgramDesc BuildProgramDesc() {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : for (auto& v :
std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) { std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
"h", "weights2", "bias2", "k", "l"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias") {
...@@ -51,14 +57,24 @@ ProgramDesc BuildProgramDesc() { ...@@ -51,14 +57,24 @@ ProgramDesc BuildProgramDesc() {
} }
} }
SetOp(&prog, "OP0", std::vector<std::string>({"a"}), SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"})); std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"b"}), SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
std::vector<std::string>({"c"})); std::vector<std::string>({"c"}));
SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}), // conv+relu, both with MKL-DNN
std::vector<std::string>({"f"})); SetOp(&prog, "conv2d", "conv1",
SetOp(&prog, "relu", std::vector<std::string>({"f"}), std::vector<std::string>({"c", "weights", "bias"}),
std::vector<std::string>({"g"})); std::vector<std::string>({"f"}), true);
SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
std::vector<std::string>({"g"}), true);
SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
std::vector<std::string>({"h"}));
// conv+relu, only one with MKL-DNN
SetOp(&prog, "conv2d", "conv2",
std::vector<std::string>({"h", "weights2", "bias2"}),
std::vector<std::string>({"k"}), true);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
std::vector<std::string>({"l"}));
return prog; return prog;
} }
...@@ -88,10 +104,16 @@ TEST(ConvReLUFusePass, basic) { ...@@ -88,10 +104,16 @@ TEST(ConvReLUFusePass, basic) {
auto* op = node->Op(); auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn")); ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn"))); EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("fuse_relu")); // check if only "conv1" convolution is fused
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu")); auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (fuse_relu) { if (op_name == "conv1") {
++conv_relu_count; ASSERT_TRUE(op->HasAttr("fuse_relu"));
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
if (fuse_relu) {
++conv_relu_count;
}
} else if (op_name == "conv2") {
ASSERT_FALSE(op->HasAttr("fuse_relu"));
} }
} }
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
pattern->NewNode("depthwise_conv")
->assert_is_op("depthwise_conv2d")
->assert_op_attr("use_mkldnn", true);
int found_depthwise_conv_mkldnn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(30) << "handle DepthwiseConvMKLDNN fuse";
GET_NODE(depthwise_conv, (*pattern));
depthwise_conv->Op()->SetType("conv2d");
found_depthwise_conv_mkldnn_count++;
};
gpd(graph.get(), handler);
AddStatis(found_depthwise_conv_mkldnn_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(depthwise_conv_mkldnn_pass,
paddle::framework::ir::DepthwiseConvMKLDNNPass);
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册