提交 25083de9 编写于 作者: C caoying03

Merge branch 'develop' into cross_entropy_over_beam

#!/bin/bash
set -e
readonly VERSION="3.8"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: local - repo: local
hooks: hooks:
- id: clang-format - id: clang-format-with-version-check
name: clang-format name: clang-format
description: Format files with ClangFormat. description: Format files with ClangFormat.
entry: clang-format -i entry: ./.clang_format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang - repo: https://github.com/PaddlePaddle/pre-commit-golang
......
...@@ -36,8 +36,8 @@ include(simd) ...@@ -36,8 +36,8 @@ include(simd)
################################ Configurations ####################################### ################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
...@@ -55,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) ...@@ -55,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON) option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -137,9 +138,9 @@ set(EXTERNAL_LIBS ...@@ -137,9 +138,9 @@ set(EXTERNAL_LIBS
) )
if(WITH_GPU) if(WITH_GPU)
list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY})
if(NOT WITH_DSO) if(NOT WITH_DSO)
list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY})
endif(NOT WITH_DSO) endif(NOT WITH_DSO)
endif(WITH_GPU) endif(WITH_GPU)
......
...@@ -34,9 +34,6 @@ RUN apt-get update && \ ...@@ -34,9 +34,6 @@ RUN apt-get update && \
net-tools && \ net-tools && \
apt-get clean -y apt-get clean -y
# paddle is using numpy.flip, which is introduced since 1.12.0
RUN pip --no-cache-dir install 'numpy>=1.12.0'
# Install Go and glide # Install Go and glide
RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \ RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \
tar -xz -C /usr/local && \ tar -xz -C /usr/local && \
...@@ -58,33 +55,22 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 ...@@ -58,33 +55,22 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter # FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter
# version util jupyter fixes this issue. # version util jupyter fixes this issue.
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \
pip install -U 'protobuf==3.1.0' && \ pip install -U wheel && \
pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \ pip install -U docopt PyYAML sphinx && \
pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \ pip install -U sphinx-rtd-theme==0.1.9 recommonmark
pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \
RUN pip install pre-commit 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python rarfile 'scipy>=0.19.0' 'nltk>=3.2.2' pip install opencv-python
COPY ./python/requirements.txt /root/
RUN pip install -r /root/requirements.txt
# To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
# the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2 # the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2
RUN apt-get install -y libssl-dev libffi-dev RUN apt-get install -y libssl-dev libffi-dev
RUN pip install certifi urllib3[secure] RUN pip install certifi urllib3[secure]
# TODO(qijun) The template library Eigen doesn't work well with GCC 5
# coming with the default Docker image, so we switch to use GCC 4.8
# by default. And I will check Eigen library later.
RUN ln -sf gcc-4.8 /usr/bin/gcc && \
ln -sf gcc-ar-4.8 /usr/bin/gcc-ar && \
ln -sf gcc-nm-4.8 /usr/bin/gcc-nm && \
ln -sf gcc-ranlib-4.8 /usr/bin/gcc-ranlib && \
ln -sf gcc-4.8 /usr/bin/x86_64-linux-gnu-gcc && \
ln -sf gcc-ar-4.8 /usr/bin/x86_64-linux-gnu-gcc-ar && \
ln -sf gcc-nm-4.8 /usr/bin/x86_64-linux-gnu-gcc-nm && \
ln -sf gcc-ranlib-4.8 /usr/bin/x86_64-linux-gnu-gcc-ranlib && \
ln -sf g++-4.8 /usr/bin/g++ && \
ln -sf g++-4.8 /usr/bin/x86_64-linux-gnu-g++
# Install woboq_codebrowser to /woboq # Install woboq_codebrowser to /woboq
RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \ RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \
......
...@@ -28,6 +28,10 @@ if(NOT WITH_TIMER) ...@@ -28,6 +28,10 @@ if(NOT WITH_TIMER)
add_definitions(-DPADDLE_DISABLE_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER)
endif(NOT WITH_TIMER) endif(NOT WITH_TIMER)
if(USE_EIGEN_FOR_BLAS)
add_definitions(-DPADDLE_USE_EIGEN_FOR_BLAS)
endif(USE_EIGEN_FOR_BLAS)
if(NOT WITH_PROFILER) if(NOT WITH_PROFILER)
add_definitions(-DPADDLE_DISABLE_PROFILER) add_definitions(-DPADDLE_DISABLE_PROFILER)
endif(NOT WITH_PROFILER) endif(NOT WITH_PROFILER)
......
...@@ -2,7 +2,7 @@ if(NOT WITH_GPU) ...@@ -2,7 +2,7 @@ if(NOT WITH_GPU)
return() return()
endif() endif()
set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT") set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT")
find_path(CUDNN_INCLUDE_DIR cudnn.h find_path(CUDNN_INCLUDE_DIR cudnn.h
PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include
$ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE}
......
...@@ -73,10 +73,18 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) ...@@ -73,10 +73,18 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
# linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas)
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c) SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
ADD_LIBRARY(cblas STATIC ${dummyfile}) IF(${CBLAS_PROVIDER} MATCHES MKL)
ADD_LIBRARY(cblas SHARED ${dummyfile})
ELSE()
ADD_LIBRARY(cblas STATIC ${dummyfile})
ENDIF()
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES}) TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
IF(NOT ${CBLAS_FOUND}) IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas) ADD_DEPENDENCIES(cblas extern_openblas)
LIST(APPEND external_project_dependencies cblas) LIST(APPEND external_project_dependencies cblas)
ELSE()
IF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
ADD_DEPENDENCIES(cblas mklml)
ENDIF()
ENDIF(NOT ${CBLAS_FOUND}) ENDIF(NOT ${CBLAS_FOUND})
...@@ -9,13 +9,6 @@ function(CheckCompilerCXX11Flag) ...@@ -9,13 +9,6 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif() endif()
if(NOT ANDROID)
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
endif()
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers. # Apple Clang is a different compiler than upstream Clang which havs different version numbers.
...@@ -160,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) ...@@ -160,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread) LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")
......
...@@ -362,6 +362,11 @@ trans ...@@ -362,6 +362,11 @@ trans
.. autoclass:: paddle.v2.layer.trans .. autoclass:: paddle.v2.layer.trans
:noindex: :noindex:
scale_shift
-----------
.. autoclass:: paddle.v2.layer.scale_shift
:noindex:
Sampling Layers Sampling Layers
=============== ===============
......
# Alalysis of large model distributed training in Paddle
***NOTE: This is only some note for how we implemeted this scheme in V1, not a new design.***
## What is it
We often encounter cases that the embedding layer parameters(sparse) are so large that we can not store it in the trainer's memory when training. So we need to put them to several servers, and fetch them row by row instead of fetch all of the parameters.
## How to use
Specify command-line argument like `--loadsave_parameters_in_pserver=true --ports_num_for_sparse=1 --use_old_updater=1` when starting the paddle trainer. And also add something like `--ports_num_for_sparse=1 --pserver_num_threads=5` when starting pserver processes.
Accrodingly, configure your embedding layers like:
```python
SPARSE_REMOTE=True
w1 = data_layer(name="w1", size=dict_size)
emb1 = embedding_layer(input=w1, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
w2 = data_layer(name="w2", size=dict_size)
emb2 = embedding_layer(input=w2, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
...
```
## Implementation details
```c++
enum MatType {
MAT_NORMAL,
MAT_NORMAL_SHARED,
MAT_VALUE_SHARED,
MAT_SPARSE_ROW_IDS,
MAT_SPARSE_ROW_AUTO_GROW,
MAT_CACHE_ROW,
MAT_SPARSE_ROW,
MAT_SPARSE_ROW_PREFETCH,
MAT_SPARSE_ROW_PREFETCH_FULL_SIZE,
};
```
`MAT_SPARSE_ROW_PREFETCH` is what we use when configured to fetch only row of matrix when training.
In `trainer_internal.cpp:L93 trainOneBatch`:
```c++
if (config_->getOptConfig().use_sparse_remote_updater()) {
REGISTER_TIMER("prefetch");
gradientMachine_->prefetch(inArgs);
parameterUpdater_->getParametersRemote();
}
```
When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver.
In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
```c++
if (fullSize) {
...
} else {
getParams = [&] {
parameterClient_->getParameterSparse(
/* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType);
};
applyL1 = [](Parameter& para, real decayRate) {
para.getMat(PARAMETER_VALUE)->applyL1(/*lr=*/1.0f, decayRate);
};
}
```
Calling `parameterClient_->getParameterSparse` will do remote call to pserver's `getParameterSparse`:
```c++
void ParameterServer2::getParameterSparse(const SendParameterRequest& request,
std::vector<Buffer>& inputBuffers,
SendParameterResponse* response,
std::vector<Buffer>* outputBuffers) {
(void)inputBuffers;
auto& buffer = *readWriteBuffer_;
size_t numReals = 0;
for (const auto& block : request.blocks()) {
numReals += getParameterConfig(block).dims(1);
}
buffer.resize(numReals);
VLOG(3) << "pserver: getParameterSparse, numReals=" << numReals;
ReadLockGuard guard(parameterMutex_);
size_t offset = 0;
for (const auto& block : request.blocks()) {
size_t width = getParameterConfig(block).dims(1);
Buffer buf = {buffer.data() + offset, width};
int type = request.send_back_parameter_type();
sendBackParameterSparse(block, type, response, &buf, width, outputBuffers);
offset += width;
}
}
```
`getParameterConfig(block).dims(1)` returns the width of the current "parameter block"(a shard of parameter object),
then `getParameterSparse` remote call returns only one row of data to the client.
...@@ -101,6 +101,7 @@ if use_mkldnn ...@@ -101,6 +101,7 @@ if use_mkldnn
5.**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue``mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。 5.**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue``mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。
6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。 6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。
7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。 7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。
8. 关于MKLDNN参数的保存。由于MKLDNN参数的格式与PaddlePaddle原有的格式存在不一样的情况,所以需要在保存参数时同时保存该格式信息。目前准备扩展[Header](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/Parameter.h#L247)里面的`int32_t version`。这个值不管是在v1还是在v2里面,一直保存的是0,所以可以充分利用这个信息,定义一个枚举处理所有MKLDNN的参数格式,从而`MKLDNNLayer`就可以从输入的参数中获取需要的格式信息。
## References ## References
......
...@@ -68,7 +68,7 @@ As a simple example, consider the following: ...@@ -68,7 +68,7 @@ As a simple example, consider the following:
1. **BLAS Dependencies(optional)** 1. **BLAS Dependencies(optional)**
CMake will search BLAS libraries from system. If not found, OpenBLAS will be downloaded, built and installed automatically. CMake will search BLAS libraries from the system. If not found, OpenBLAS will be downloaded, built and installed automatically.
To utilize preinstalled BLAS, you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`. To utilize preinstalled BLAS, you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`.
```bash ```bash
...@@ -131,9 +131,9 @@ As a simple example, consider the following: ...@@ -131,9 +131,9 @@ As a simple example, consider the following:
To build GPU version, you will need the following installed: To build GPU version, you will need the following installed:
1. a CUDA-capable GPU 1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain 2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment, The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on including the host compiler and C runtime libraries, and is therefore only supported on
...@@ -172,6 +172,7 @@ export PATH=<path to install>/bin:$PATH ...@@ -172,6 +172,7 @@ export PATH=<path to install>/bin:$PATH
# install PaddlePaddle Python modules. # install PaddlePaddle Python modules.
sudo pip install <path to install>/opt/paddle/share/wheels/*.whl sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
``` ```
## <span id="centos">Build on Centos 7</span> ## <span id="centos">Build on Centos 7</span>
### Install Dependencies ### Install Dependencies
...@@ -192,9 +193,9 @@ sudo pip install <path to install>/opt/paddle/share/wheels/*.whl ...@@ -192,9 +193,9 @@ sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
To build GPU version, you will need the following installed: To build GPU version, you will need the following installed:
1. a CUDA-capable GPU 1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain 2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment, The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on including the host compiler and C runtime libraries, and is therefore only supported on
...@@ -222,7 +223,7 @@ mkdir build && cd build ...@@ -222,7 +223,7 @@ mkdir build && cd build
``` ```
Finally, you can build and install PaddlePaddle: Finally, you can build and install PaddlePaddle:
```bash ```bash
# you can add build option here, such as: # you can add build option here, such as:
cmake3 .. -DCMAKE_INSTALL_PREFIX=<path to install> cmake3 .. -DCMAKE_INSTALL_PREFIX=<path to install>
......
...@@ -146,3 +146,19 @@ paddle_error paddle_gradient_machine_randomize_param( ...@@ -146,3 +146,19 @@ paddle_error paddle_gradient_machine_randomize_param(
m->machine->randParameters(); m->machine->randParameters();
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error paddle_gradient_machine_get_layer_output(
paddle_gradient_machine machine,
const char* layerName,
paddle_arguments args) {
auto m = cast(machine);
auto out = paddle::capi::cast<paddle::capi::CArguments>(args);
if (m == nullptr || layerName == nullptr || out == nullptr ||
m->machine == nullptr) {
return kPD_NULLPTR;
}
auto layerOutput = m->machine->getLayerOutput(layerName);
out->args.push_back(layerOutput);
return kPD_NO_ERROR;
}
...@@ -39,7 +39,11 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference( ...@@ -39,7 +39,11 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference(
/** /**
* @brief Create a gradient machine used for model inference, using config with * @brief Create a gradient machine used for model inference, using config with
* parameters which is generated by `paddle merge_model`. * parameters which is generated by `paddle merge_model`.
* @param [out] machine that used for model inference. * Example:
* paddle merge_model \
* --model_dir="pass-00000" \
* --model_file="merged_model.paddle"
* @param [out] machine that used for model inference
* @param [in] mergedModel * @param [in] mergedModel
* @param [in] size * @param [in] size
* @return paddle_error * @return paddle_error
...@@ -97,6 +101,18 @@ paddle_gradient_machine_randomize_param(paddle_gradient_machine machine); ...@@ -97,6 +101,18 @@ paddle_gradient_machine_randomize_param(paddle_gradient_machine machine);
PD_API paddle_error PD_API paddle_error
paddle_gradient_machine_destroy(paddle_gradient_machine machine); paddle_gradient_machine_destroy(paddle_gradient_machine machine);
/**
* @brief Get the output of the layer named `layerName`.
* @param [in] gradient machine that have run a inference
* @param [in] layerName name of specified layer
* @param [out] args output of the specified layer
* @return paddle_error
*/
PD_API paddle_error
paddle_gradient_machine_get_layer_output(paddle_gradient_machine machine,
const char* layerName,
paddle_arguments args);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc) ...@@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc) cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(attribute_proto SRCS attribute.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto)
proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(attribute SRCS attribute.cc DEPS op_desc op_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute) cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder) cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
...@@ -42,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -42,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
if(WITH_PYTHON) if(WITH_PYTHON)
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
......
...@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() { ...@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
return STRINGS; return STRINGS;
} }
Attribute GetAttrValue(const AttrDesc& attr_desc) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case paddle::framework::AttrType::INT: { case paddle::framework::AttrType::INT: {
return attr_desc.i(); return attr_desc.i();
......
...@@ -20,8 +20,7 @@ limitations under the License. */ ...@@ -20,8 +20,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/framework/attribute.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h" #include "paddle/platform/variant.h"
...@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap; ...@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
template <typename T> template <typename T>
AttrType AttrTypeID(); AttrType AttrTypeID();
Attribute GetAttrValue(const AttrDesc& attr_desc); Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
// check whether a value(attribute) fit a certain limit // check whether a value(attribute) fit a certain limit
template <typename T> template <typename T>
......
...@@ -15,31 +15,44 @@ ...@@ -15,31 +15,44 @@
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include <list> #include <list>
#include <memory>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static bool AllInSet(const std::vector<std::string>& names, template <typename Map, typename T>
const std::string& suffix, static void ForEachVarName(const Map& names, T callback) {
const std::unordered_set<std::string>& set) {
for (auto& name : names) { for (auto& name : names) {
if (set.find(name + suffix) == set.end()) { for (auto& n : name.second) {
return false; if (callback(n)) return;
} }
} }
return true;
} }
static std::shared_ptr<OperatorBase> NOP() { // return whether all the names + suffixes in the set
auto net_op = std::make_shared<operators::NetOp>(); static bool AllInSet(
net_op->type_ = "@NOP@"; const std::map<std::string, std::vector<std::string>>& names,
const std::string& suffix, const std::unordered_set<std::string>& set) {
bool all_in_set = true;
ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
all_in_set = set.find(n + suffix) != set.end();
return !all_in_set;
});
return all_in_set;
}
static std::unique_ptr<OperatorBase> NOP() {
auto net_op = new operators::NetOp();
net_op->SetType("@NOP@");
net_op->CompleteAddOp(); net_op->CompleteAddOp();
return net_op; return std::unique_ptr<OperatorBase>(net_op);
} }
// Get backward operator from a forward operator, recursively implementation. // Get backward operator from a forward operator, a recursive implementation.
// //
// no_grad_names the gradient variable names without gradient calculating. // no_grad_names the gradient variable names without gradient calculating.
// //
...@@ -47,122 +60,152 @@ static std::shared_ptr<OperatorBase> NOP() { ...@@ -47,122 +60,152 @@ static std::shared_ptr<OperatorBase> NOP() {
// BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and // BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and
// pass `uniq_id` through recursive calling. // pass `uniq_id` through recursive calling.
// //
// returns The backward operator. For simple situation, it is a simple // returns The backward operator. In a simple situation, it may be a simple
// operator. For complex situation, it is a NetOp. // operator, in a complex situation, it maybe a NetOp.
// //
// See Backward.h for details // See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive( static std::unique_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) { std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate, // If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take // just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic. // too much time for calculation, but it is useful for simplifying logic.
if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) { if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) {
return NOP(); return NOP();
} }
// All output gradients of forwarding operator do not need to calculate. // All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into // Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP. // `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) { if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
for (auto& name : forwardOp.inputs_) { no_grad_names /*set*/)) {
// Mark all input is not need ForEachVarName(forwardOp.Inputs(),
no_grad_names.insert(name + kGradVarSuffix); [&no_grad_names](const std::string& name) -> bool {
} no_grad_names.insert(GradVarName(name));
return false;
});
return NOP(); return NOP();
} }
// Returned gradient network // Returned gradient network
auto net = std::make_shared<operators::NetOp>(); auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
if (forwardOp.IsNetOp()) { if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast. // Because forwardOp is a net op, it can static_cast.
auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp); auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
// Map from output gradient variable name to operator's indices in // Map from output gradient variable name to operator's indices in
// backward net. That operator generates that variable. // backward net's ops_. That operator generates that variable.
std::unordered_map<std::string, std::vector<size_t>> dup_output_ops; std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
size_t local_op_id = 0; size_t local_op_id = 0;
// reversely travel forwardNet // reversely travel forwardNet and collect all duplicate outputs.
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) { ++it, ++local_op_id) {
auto fwd = *it; auto& fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd); ForEachVarName(bwd->Outputs(),
for (auto& out : bwd->outputs_) { [&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id); dup_output_ops[out].emplace_back(local_op_id);
} return false;
});
net->AppendOp(std::move(bwd));
} }
// Get unique ID for this method. // Get unique ID for this method.
auto uid = uniq_id++; auto uid = uniq_id++;
// TODO(dzh): more comment // TODO(dzh): more comment
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>; // multiple operators which have the same output (y for example) may
// overwrite the same y variable when backward, special operations are token
// to handle this case. For each duplicate output, rename it to an alias
// (original name with a offset), append an `add` op for its operator,
// and finally sum all the alias variable to the final output variable y.
using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
std::list<Pos> insert_position; std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) { for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first; const std::string& name = dup_output_op.first;
auto& dup_op = dup_output_op.second; auto& dup_op = dup_output_op.second;
// no duplicate output
if (dup_op.size() == 1) continue; if (dup_op.size() == 1) continue;
std::vector<std::string> dup_outputs;
// process the duplicate outputs
std::vector<std::string> dup_outputs;
for (size_t i = 0; i < dup_op.size(); ++i) { for (size_t i = 0; i < dup_op.size(); ++i) {
// rename each duplicate output to an alias
auto op_offset = dup_op[i]; auto op_offset = dup_op[i];
dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
std::to_string(i)); std::to_string(i));
net->ops_[op_offset]->Rename(name, dup_outputs.back()); net->ops_[op_offset]->Rename(name, dup_outputs.back());
} }
// collect all the offset to append `add` op for each alias
insert_position.push_back( insert_position.push_back(
{dup_op.back(), {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
OpRegistry::CreateOp( {{"Out", {name}}}, {})});
"add", {dup_outputs}, {name},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
} }
// make sure the inserted `add` ops follow the BFS order.
insert_position.sort( insert_position.sort(
[](const Pos& l, const Pos& r) { return l.first > r.first; }); [](const Pos& l, const Pos& r) { return l.first > r.first; });
for (auto& pos : insert_position) { for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second); net->InsertOp(pos.first + 1, std::move(pos.second));
} }
} else { } else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp); std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
for (std::string& grad_input : grad_op->inputs_) {
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) {
if (no_grad_names.count(grad_input)) { if (no_grad_names.count(grad_input)) {
// +1 for \0 // +1 for \0
std::string prefix = grad_input.substr( std::string prefix = grad_input.substr(
0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
grad_input = prefix + kZeroVarSuffix; grad_op->Rename(grad_input, prefix + kZeroVarSuffix);
// If part of input gradient of that operator is not calculated, fill // If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient. // zero variables to that input gradient.
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix}, net->AppendOp(OpRegistry::CreateOp("fill_zeros_like",
{grad_input}, {})); {{"Src", {prefix}}},
} {{"Dst", {grad_input}}}, {}));
}
for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) {
grad_output = kEmptyVarName;
} }
return false;
});
ForEachVarName(grad_op->Outputs(),
[&no_grad_names, &grad_op](const std::string& grad_output) {
if (no_grad_names.count(grad_output)) {
grad_op->Rename(grad_output, kEmptyVarName);
}
return false;
});
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent_op") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::RecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
rnn_grad_op->set_stepnet(
BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
} }
if (net->ops_.empty()) { // Current no aux op is added to network if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op; return grad_op;
} }
net->AddOp(grad_op); net->AppendOp(std::move(grad_op));
} }
net->type_ = "@GENERATED_BACKWARD@"; net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp(); net->CompleteAddOp();
return net; return std::unique_ptr<OperatorBase>(
static_cast<OperatorBase*>(net.release()));
} }
// See header for comments // See header for comments
std::shared_ptr<OperatorBase> Backward( std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names; std::unordered_set<std::string> no_grad_names;
......
...@@ -20,7 +20,7 @@ namespace framework { ...@@ -20,7 +20,7 @@ namespace framework {
// Create the backward operator from a forward operator. // Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment. // TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward( extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp, const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework } // namespace framework
......
...@@ -28,21 +28,13 @@ using OpAttrChecker = framework::OpAttrChecker; ...@@ -28,21 +28,13 @@ using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope; using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext; using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase)
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient(); AddInput("X", "Input X of Add").NotInGradient();
AddInput("b", "Bias of Add").IgnoreGradient(); AddInput("b", "Bias of Add").NotInGradient();
AddOutput("Out", "Out of Add").IgnoreGradient(); AddOutput("Out", "Out of Add").NotInGradient();
AddComment("Add Op"); AddComment("Add Op");
} }
}; };
...@@ -51,8 +43,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker { ...@@ -51,8 +43,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
public: public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("A", "A"); AddInput("X", "A");
AddInput("B", "B"); AddInput("Y", "B");
AddOutput("Out", "Out"); AddOutput("Out", "Out");
AddComment("Mul"); AddComment("Mul");
} }
...@@ -63,7 +55,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker { ...@@ -63,7 +55,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X"); AddInput("X", "X");
AddOutput("Y", "Y"); AddOutput("Out", "Y");
AddComment("Sigmoid"); AddComment("Sigmoid");
} }
}; };
...@@ -73,21 +65,25 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { ...@@ -73,21 +65,25 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker) NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X input"); AddInput("X", "X input");
AddOutput("Y", "Y output"); AddOutput("Out", "Y output");
AddComment("NoGradOp, same input output. no Grad"); AddComment("NoGradOp, same input output. no Grad");
} }
}; };
class FcOp : public operators::NetOp { class FcOp : public operators::NetOp {
public: public:
void Init() override { FcOp(const std::string &type, const VarNameMap &inputs,
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, const VarNameMap &outputs, const AttributeMap &attrs)
{Output("mul_result")}, {})); : NetOp(type, inputs, outputs, attrs) {
auto b_name = Input("b"); AppendOp(OpRegistry::CreateOp("mul",
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
{{"Out", {Output("mul_result")}}}, {}));
auto input_b = Inputs("b");
std::string before_act = "mul_result"; std::string before_act = "mul_result";
if (b_name != kEmptyVarName) { if (input_b.size() != 0) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name}, AppendOp(OpRegistry::CreateOp(
{Output("add_result")}, {})); "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
{{"Out", {Output("add_result")}}}, {}));
before_act = "add_result"; before_act = "add_result";
} else { } else {
auto out_varname = Output("add_result"); auto out_varname = Output("add_result");
...@@ -96,8 +92,8 @@ class FcOp : public operators::NetOp { ...@@ -96,8 +92,8 @@ class FcOp : public operators::NetOp {
} }
} }
AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")}, AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
{})); {{"Out", {Output("Out")}}}, {}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
...@@ -109,8 +105,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker { ...@@ -109,8 +105,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x"); AddInput("X", "x");
AddInput("W", "w"); AddInput("W", "w");
AddInput("b", "b"); AddInput("b", "b");
AddOutput("mul_result", "").SetTemporary(); AddOutput("mul_result", "").AsIntermediate();
AddOutput("add_result", "").SetTemporary(); AddOutput("add_result", "").AsIntermediate();
AddOutput("Out", ""); AddOutput("Out", "");
AddComment(""); AddComment("");
} }
...@@ -141,7 +137,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -141,7 +137,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public: public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetMultiple(); AddInput("X", "x").AsDuplicable();
AddOutput("Y", "y"); AddOutput("Y", "y");
AddComment(""); AddComment("");
} }
...@@ -152,51 +148,48 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -152,51 +148,48 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators; namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); f::NOP);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); f::NOP);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
TEST(Backward, simple_op_grad) { TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp(
"rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd); auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(4UL, gop->inputs_.size()); ASSERT_EQ(1UL, gop->Inputs().size());
ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->Type());
ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
ASSERT_EQ(f::GradVarName("X"), gop->outputs_[0]); ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
ASSERT_EQ(f::GradVarName("b"), gop->outputs_[1]);
ASSERT_EQ(f::GradVarName("X"), gop->Output(f::GradVarName("X")));
} }
TEST(Backward, simple_op_not_need_grad) { TEST(Backward, simple_op_not_need_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp(
"rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::Backward(*fwd, {"X"}); auto gop = f::Backward(*fwd, {"x"});
ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(), ASSERT_EQ(gop->Output(f::GradVarName("X")), f::kEmptyVarName);
f::GradVarName("X")),
gop->outputs_.end());
auto no_input_gop = f::Backward(*fwd, {"X", "b"}); auto no_input_gop = f::Backward(*fwd, {"x", "b"});
ASSERT_NE(no_input_gop, nullptr); ASSERT_NE(no_input_gop, nullptr);
ASSERT_TRUE(no_input_gop->IsNetOp()); ASSERT_TRUE(no_input_gop->IsNetOp());
ASSERT_EQ(0UL, ASSERT_EQ(0UL, static_cast<ops::NetOp *>(no_input_gop.get())->ops_.size());
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
} }
TEST(Backward, net_fc_backward_normal) { TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> fwd =
"fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {}); f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
{{"mul_result", {"mul_res"}},
{"add_result", {"add_re"}},
{"Out", {"out"}}},
{});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp()); ASSERT_TRUE(gop->IsNetOp());
...@@ -207,19 +200,22 @@ TEST(Backward, net_fc_backward_normal) { ...@@ -207,19 +200,22 @@ TEST(Backward, net_fc_backward_normal) {
ASSERT_EQ(3UL, net->ops_.size()); ASSERT_EQ(3UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0]; f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); ASSERT_EQ("sigmoid_grad", d_sigmoid.Type());
f::OperatorBase &d_add = *net->ops_[1]; f::OperatorBase &d_add = *net->ops_[1];
ASSERT_EQ("rowwise_add_grad", d_add.type_); ASSERT_EQ("rowwise_add_grad", d_add.Type());
f::OperatorBase &d_mul = *net->ops_[2]; f::OperatorBase &d_mul = *net->ops_[2];
ASSERT_EQ("mul_grad", d_mul.type_); ASSERT_EQ("mul_grad", d_mul.Type());
} }
TEST(Backward, net_fc_backward_not_have_b) { TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName}, f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {}}},
{"mul_result", "add_result", "tmp"}, {}); {{"mul_result", {"mul_res"}},
{"add_result", {"add_res"}},
{"Out", {"tmp"}}},
{});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {}); std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp()); ASSERT_TRUE(gop->IsNetOp());
...@@ -230,96 +226,113 @@ TEST(Backward, net_fc_backward_not_have_b) { ...@@ -230,96 +226,113 @@ TEST(Backward, net_fc_backward_not_have_b) {
ASSERT_EQ(2UL, net->ops_.size()); ASSERT_EQ(2UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0]; f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); ASSERT_EQ("sigmoid_grad", d_sigmoid.Type());
f::OperatorBase &d_mul = *net->ops_[1]; f::OperatorBase &d_mul = *net->ops_[1];
ASSERT_EQ("mul_grad", d_mul.type_); ASSERT_EQ("mul_grad", d_mul.Type());
} }
TEST(Backward, net_input_of_network_not_need_grad) { TEST(Backward, net_input_of_network_not_need_grad) {
ops::NetOp net; ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, net.AppendOp(f::OpRegistry::CreateOp(
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {})); "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {{"mul_result", {"mul_tmp_0"}},
{"mul_tmp_1", "add_tmp_1", "hidden1"}, {})); {"add_result", {"add_tmp_0"}},
{"Out", {"hidden0"}}},
{}));
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_tmp_1"}},
{"add_result", {"add_tmp_1"}},
{"Out", {"hidden1"}}},
{}));
net.CompleteAddOp(); net.CompleteAddOp();
auto bwd = Backward(net, {"X"}); // X@GRAD is not need. auto bwd = Backward(net, {"x"}); // x@GRAD is not need.
ASSERT_TRUE(bwd->IsNetOp()); ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get()); auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
std::unordered_set<std::string> all_output = std::unordered_set<std::string>( auto output_vars = bwd_net->OutputVars(true);
bwd_net->outputs_.begin(), bwd_net->outputs_.end()); std::unordered_set<std::string> all_outputs =
all_output.erase(f::kEmptyVarName); std::unordered_set<std::string>(output_vars.begin(), output_vars.end());
all_outputs.erase(f::kEmptyVarName);
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
ASSERT_NE(all_output.find(f::GradVarName(out)), all_output.end()); ASSERT_NE(all_outputs.find(f::GradVarName(out)), all_outputs.end());
} }
// Not Generated X // Not Generated X
ASSERT_EQ(all_output.find(f::GradVarName("X")), all_output.end()); ASSERT_EQ(all_outputs.find(f::GradVarName("X")), all_outputs.end());
ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get()); auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(f::kEmptyVarName, ASSERT_EQ(f::kEmptyVarName,
first_fc_grad->ops_[2]->Output(f::GradVarName("A"))); first_fc_grad->ops_[2]->Output(f::GradVarName("X")));
} }
TEST(Backward, net_shared_weight) { TEST(Backward, net_shared_weight) {
ops::NetOp net; ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); {{"Out", {"out"}}}, {}));
net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
{{"Out", {"FinalOut"}}}, {}));
net.CompleteAddOp(); net.CompleteAddOp();
auto bwd = f::Backward(net, {}); auto bwd = f::Backward(net, {});
ASSERT_TRUE(bwd->IsNetOp()); ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get()); auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size()); ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->type_); ASSERT_EQ("add", bwd_net->ops_[2]->Type());
} }
TEST(Backward, op_register_grad_not_for_network) { TEST(Backward, op_register_grad_not_for_network) {
auto fwd = f::OpRegistry::CreateOp( auto fwd =
"fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"}, f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
{{"temporary_index", std::vector<int>{0, 1}}}); {{"mul_result", {"mul_out"}},
{"add_result", {"add_out"}},
{"Out", {"out1"}}},
{{"temporary_index", std::vector<int>{0, 1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
} }
TEST(Backward, op_all_input_are_not_need) { TEST(Backward, op_all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp(
auto backward = f::Backward(*fwd, {"X", "b"}); "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
auto backward = f::Backward(*fwd, {"x", "b"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<ops::NetOp *>(backward.get()); auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty()); ASSERT_TRUE(net->ops_.empty());
} }
TEST(Backward, op_all_output_are_not_need) { TEST(Backward, op_all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp(
auto backward = f::Backward(*fwd, {"Out"}); "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
auto backward = f::Backward(*fwd, {"out"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<ops::NetOp *>(backward.get()); auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty()); ASSERT_TRUE(net->ops_.empty());
} }
TEST(Backward, op_part_of_output_are_not_need) { TEST(Backward, op_part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); auto fwd = f::OpRegistry::CreateOp("many_output_op", {{"x", {"X"}}},
{{"y", {"Y"}}, {"z", {"Z"}}}, {});
auto backward = f::Backward(*fwd, {"Z"}); auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<ops::NetOp *>(backward.get()); auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 2UL); ASSERT_EQ(net->ops_.size(), 2UL);
auto &fill_zero = *net->ops_[0]; auto &fill_zero = *net->ops_[0];
ASSERT_EQ("fill_zeros_like", fill_zero.type_); ASSERT_EQ("fill_zeros_like", fill_zero.Type());
ASSERT_EQ(1UL, fill_zero.inputs_.size()); ASSERT_EQ(1UL, fill_zero.Inputs("Src").size());
ASSERT_EQ("Z", fill_zero.inputs_[0]); ASSERT_EQ("Z", fill_zero.Input("Src"));
ASSERT_EQ(1UL, fill_zero.outputs_.size()); ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size());
ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.outputs_[0]); ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst"));
auto &d_many_out = *net->ops_[1]; auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_); ASSERT_EQ("many_output_op_grad", d_many_out.Type());
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.Inputs().size()); // I/O/OG
ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
d_many_out.Input(f::GradVarName("z"))); d_many_out.Input(f::GradVarName("z")));
ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y"))); ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
...@@ -327,44 +340,62 @@ TEST(Backward, op_part_of_output_are_not_need) { ...@@ -327,44 +340,62 @@ TEST(Backward, op_part_of_output_are_not_need) {
} }
TEST(Backward, op_part_of_input_are_not_need) { TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); auto fwd = f::OpRegistry::CreateOp("mul", {{"X", {"a"}}, {"Y", {"b"}}},
{{"Out", {"out"}}}, {});
auto backward = f::Backward(*fwd, {"a"}); auto backward = f::Backward(*fwd, {"a"});
auto &grad_mul = *backward; auto &grad_mul = *backward;
ASSERT_EQ(grad_mul.type_, "mul_grad"); ASSERT_EQ(grad_mul.Type(), "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); ASSERT_EQ(grad_mul.Inputs().size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL); ASSERT_EQ(grad_mul.Outputs().size(), 2UL);
ASSERT_EQ(grad_mul.Output(f::GradVarName("A")), f::kEmptyVarName); ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName);
ASSERT_EQ(grad_mul.Output(f::GradVarName("B")), f::GradVarName("b")); ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b"));
ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out")); ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
ASSERT_EQ(grad_mul.Input("A"), "a"); ASSERT_EQ(grad_mul.Input("X"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b"); ASSERT_EQ(grad_mul.Input("Y"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out"); ASSERT_EQ(grad_mul.Input("Out"), "out");
} }
TEST(Backward, linear_net_intermediate_variable_has_no_grad) { TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ops::NetOp net; ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, net.AppendOp(f::OpRegistry::CreateOp(
{"mul_out1", "add_out1", "out1"}, {})); "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {{"mul_result", {"mul_out1"}},
{"mul_out2", "tmp_out2", "out2"}, {})); {"add_result", {"add_out1"}},
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"Out", {"out1"}}},
{"mul_out3", "tmp_out3", "out3"}, {})); {}));
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
{{"mul_result", {"mul_out2"}},
{"add_result", {"tmp_out2"}},
{"Out", {"out2"}}},
{}));
net.AppendOp(f::OpRegistry::CreateOp(
"fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
{{"mul_result", {"mul_out3"}},
{"add_result", {"tmp_out3"}},
{"Out", {"out3"}}},
{}));
net.CompleteAddOp(); net.CompleteAddOp();
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<ops::NetOp *>(backward.get()); auto bwd_net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL); ASSERT_EQ(bwd_net->ops_.size(), 3UL);
auto &grad_fc = *bwd_net->ops_[0]; auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_.size(),
3UL /* external input number */ const char *all = paddle::operators::NetOp::kAll;
EXPECT_EQ(grad_fc.Inputs(all).size(),
2UL /* external input number */
+ 1UL /* external output number*/ + 1UL /* external output number*/
+ 1UL /* number of gradient of external output*/ + 1UL /* number of gradient of external output*/
+ 2U /* internal variable number*/); + 2U /* internal variable number*/);
EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/ EXPECT_EQ(grad_fc.Outputs(all).size(),
+ 2UL /* input number of rowwise_add */ 2UL /* input number of mul*/
+ 1UL /* input number of sigmod */); + 2UL /* input number of rowwise_add
EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL); */
EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL); + 1UL /* input number of sigmod */);
EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL);
EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL);
} }
...@@ -283,6 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { ...@@ -283,6 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int> init_list) { DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list); *this = make_ddim(init_list);
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -15,9 +15,6 @@ limitations under the License. */ ...@@ -15,9 +15,6 @@ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
package paddle.framework; package paddle.framework;
// Attribute Type for paddle's Op.
// Op contains many attributes. Each type of attributes could be different.
// The AttrType will be shared between AttrDesc and AttrProto.
enum AttrType { enum AttrType {
INT = 0; INT = 0;
FLOAT = 1; FLOAT = 1;
...@@ -25,4 +22,61 @@ enum AttrType { ...@@ -25,4 +22,61 @@ enum AttrType {
INTS = 3; INTS = 3;
FLOATS = 4; FLOATS = 4;
STRINGS = 5; STRINGS = 5;
} }
\ No newline at end of file
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
message Attr {
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional float f = 4;
optional string s = 5;
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
};
message Var {
required string parameter = 1;
repeated string arguments = 2;
};
required string type = 3;
repeated Var inputs = 1;
repeated Var outputs = 2;
repeated Attr attrs = 4;
};
// OpProto describes a C++ framework::OperatorBase derived class.
message OpProto {
// VarProto describes the C++ type framework::Variable.
message Var {
required string name = 1;
required string comment = 2;
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
optional bool not_in_gradient = 5 [ default = false ];
}
// AttrProto describes the C++ type Attribute.
message Attr {
required string name = 1;
required AttrType type = 2;
required string comment = 3;
// If that attribute is generated, it means the Paddle third
// language binding has responsibility to fill that
// attribute. End-User should not set that attribute.
optional bool generated = 4 [ default = false ];
}
required string type = 1;
repeated Var inputs = 2;
repeated Var outputs = 3;
repeated Attr attrs = 4;
required string comment = 5;
}
...@@ -13,105 +13,53 @@ express or implied. See the License for the specific language governing ...@@ -13,105 +13,53 @@ express or implied. See the License for the specific language governing
permissions and limitations under the License. */ permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
typedef std::vector<int> Ints;
enum class OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
const Ints* AttrFormat(const AttributeMap& attrs, const std::string& key) { static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
return (attrs.count(key) > 0) ? &boost::get<Ints>(attrs.at(key)) : nullptr; bool is_grad, OperatorBase::VarNameMap* vars) {
} const auto& src_inout =
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
Ints* AttrFormat(AttributeMap& attrs, const std::string& key) { auto& dst_inout = *vars;
return (attrs.count(key) > 0) ? &boost::get<Ints>(attrs.at(key)) : nullptr; const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_;
}
static void TransOpArg(const OperatorBase* src_op,
std::vector<std::string>& grad_inputs,
std::vector<std::string>& grad_outputs,
AttributeMap& grad_attrs,
std::unordered_map<std::string, int>& grad_idxs,
const std::string& src_type, const std::string& dst_type,
int& idx, bool is_grad) {
const std::vector<std::string>& src_inout =
(src_type == "input_format") ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = AttrFormat(src_op->Attrs(), src_type);
std::vector<std::string>& dst_inout =
(dst_type == "input_format") ? grad_inputs : grad_outputs;
std::vector<int>* dst_format = AttrFormat(grad_attrs, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list = const auto& src_arg_list =
(src_type == "input_format") ? proto.inputs() : proto.outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
std::string src_name = arg.name(); if (arg.not_in_gradient() && !is_grad) continue;
std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name; const std::string src_name = arg.name();
grad_idxs[dst_name] = idx++; std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
int src_arg_idx = src_op->in_out_idxs_->at(src_name); dst_inout[dst_name].reserve(src_inout.at(src_name).size());
int src_begin = for (auto& var_name : src_inout.at(src_name)) {
src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx); std::string s = is_grad ? GradVarName(var_name) : var_name;
int src_end = src_format == nullptr ? src_arg_idx + 1 dst_inout[dst_name].emplace_back(s);
: src_format->at(src_arg_idx + 1);
for (int i = src_begin; i < src_end; ++i) {
std::string s =
is_grad ? src_inout[i] + kGradVarSuffix
: (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]);
dst_inout.emplace_back(s);
}
if (dst_format != nullptr) {
dst_format->push_back(dst_inout.size());
} }
} }
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
const std::string& grad_op_type = OpRegistry::grad_ops().at(op->Type()); auto it = OpRegistry::op_info_map().find(op->Type());
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
AttributeMap grad_attrs(op->Attrs()); "'%s' has not been registered.", op->Type());
grad_attrs.erase("input_format"); PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
grad_attrs.erase("output_format"); op->Type());
if (op->Attrs().count("input_format") > 0) { std::string grad_op_type = it->second.grad_op_type_;
grad_attrs["output_format"] = std::vector<int>({0}); PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
} op->Type());
if (op->Attrs().count("input_format") > 0 ||
op->Attrs().count("output_format") > 0) { OperatorBase::VarNameMap inputs;
grad_attrs["input_format"] = std::vector<int>({0}); OperatorBase::VarNameMap outputs;
} TransOpArg(op, OpArgType::IN, false, &inputs); // I
TransOpArg(op, OpArgType::OUT, false, &inputs); // O
std::vector<std::string> grad_inputs, grad_outputs; TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
TransOpArg(op, OpArgType::IN, true, &outputs); // IG
using VarIndexMap = std::unordered_map<std::string, int>;
VarIndexMap* grad_idxs = new VarIndexMap; it = OpRegistry::op_info_map().find(grad_op_type);
int in_idx = 0; PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
int out_idx = 0; "'%s' has not been registered.", grad_op_type);
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs());
"input_format", "input_format", in_idx, false); // I
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, false); // G
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, true); // OG
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"input_format", "output_format", out_idx, true); // IG
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->inputs_ = grad_inputs;
grad_op->outputs_ = grad_outputs;
grad_op->attrs_ = grad_attrs;
grad_op->in_out_idxs_.reset(grad_idxs);
return grad_op;
} }
} // namespace framework } // namespace framework
......
...@@ -8,24 +8,15 @@ USE_OP(add_two); ...@@ -8,24 +8,15 @@ USE_OP(add_two);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase)
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
};
class MutiInOutOpMaker : public OpProtoAndCheckerMaker { class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
public: public:
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input"); AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetMultiple(); AddInput("In2_mult", "a multiple input").AsDuplicable();
AddInput("In3", "another single input"); AddInput("In3", "another single input");
AddOutput("Out1", "a single output"); AddOutput("Out1", "a single output");
AddOutput("Out2_mult", "a multiple output").SetMultiple(); AddOutput("Out2_mult", "a multiple output").AsDuplicable();
AddComment("test op with multiple inputs and outputs"); AddComment("test op with multiple inputs and outputs");
} }
}; };
...@@ -35,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { ...@@ -35,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("In1", "a single input"); AddInput("In1", "a single input");
AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient(); AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
AddInput("In3_mult", "another multiple input").SetMultiple(); AddInput("In3_mult", "another multiple input").AsDuplicable();
AddOutput("Out1_mult", "a multiple output").SetMultiple(); AddOutput("Out1_mult", "a multiple output").AsDuplicable();
AddOutput("Out2", "a single output").IgnoreGradient(); AddOutput("Out2", "a single output").NotInGradient();
AddComment("op with inputs and outputs ignored in gradient calculating"); AddComment("op with inputs and outputs ignored in gradient calculating");
} }
}; };
...@@ -49,35 +40,33 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { ...@@ -49,35 +40,33 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
TEST(GradOpBuilder, AddTwo) { TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<f::OperatorBase> add_op( std::shared_ptr<f::OperatorBase> add_op(f::OpRegistry::CreateOp(
f::OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_add_op = std::shared_ptr<f::OperatorBase> grad_add_op =
f::OpRegistry::CreateGradOp(*add_op); f::OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4); EXPECT_EQ(grad_add_op->Inputs().size(), 4UL);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Outputs().size(), 2UL);
EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y"); EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out"); EXPECT_EQ(grad_add_op->Input("Out"), "out");
EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD"); EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out"));
EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD"); EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x"));
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD"); EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
} }
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
TEST(GradOpBuilder, MutiInOut) { TEST(GradOpBuilder, MutiInOut) {
f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 4, 5}},
{"output_format", std::vector<int>{0, 1, 3}}};
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
"mult_io", {"in1", "in2_1", "in2_2", "in2_3", "in3"}, "mult_io", {{"In1", {"in1"}},
{"out1", "out2_1", "out2_2"}, attrs)); {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
{"In3", {"in3"}}},
{{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_test_op = std::shared_ptr<f::OperatorBase> grad_test_op =
f::OpRegistry::CreateGradOp(*test_op); f::OpRegistry::CreateGradOp(*test_op);
ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In2_mult"), EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"})); std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
...@@ -91,7 +80,7 @@ TEST(GradOpBuilder, MutiInOut) { ...@@ -91,7 +80,7 @@ TEST(GradOpBuilder, MutiInOut) {
std::vector<std::string>( std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")})); {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"), std::vector<std::string>({f::GradVarName("in2_1"),
...@@ -101,31 +90,28 @@ TEST(GradOpBuilder, MutiInOut) { ...@@ -101,31 +90,28 @@ TEST(GradOpBuilder, MutiInOut) {
} }
TEST(GradOpBuilder, IOIgnoredInGradient) { TEST(GradOpBuilder, IOIgnoredInGradient) {
f::AttributeMap attrs{{"input_format", std::vector<int>{0, 1, 3, 5}},
{"output_format", std::vector<int>{0, 2, 3}}};
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
"io_ignored", {"in1", "in2_1", "in2_2", "in3_1", "in3_2"}, "io_ignored", {{"In1", {"in1"}},
{"out1_1", "out1_2", "out2"}, attrs)); {"In2_mult", {"in2_1", "in2_2"}},
{"In3_mult", {"in3_1", "in3_2"}}},
{{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
std::shared_ptr<f::OperatorBase> grad_test_op = std::shared_ptr<f::OperatorBase> grad_test_op =
f::OpRegistry::CreateGradOp(*test_op); f::OpRegistry::CreateGradOp(*test_op);
// 'In2' and 'Out2' are ignored in gradient calculating // 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Input("In1"), "in1");
EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
std::vector<std::string>({f::kEmptyVarName, f::kEmptyVarName}));
EXPECT_EQ(grad_test_op->Inputs("In3_mult"), EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"})); std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"})); std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName);
EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
std::vector<std::string>( std::vector<std::string>(
{f::GradVarName("out1_1"), f::GradVarName("out1_2")})); {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
f::GradVarName("out2")); f::GradVarName("out2"));
ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
std::vector<std::string>( std::vector<std::string>(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Protocol Message for 3rd-party language binding.
//
// Paddle Python package will use `OpProto` to generate op creation methods.
// The op creation methods take user's input and generate `OpDesc` proto
// message,
// then pass `OpDesc` to C++ side and create Op pointer.
//
syntax = "proto2";
package paddle.framework;
import "attribute.proto";
// Attribute protocol message for 3rd-party language binding.
// It will store the Op support what attribute and what type.
message AttrProto {
// Supported attribute name. e.g. `scale` for cosine op.
required string name = 1;
// Supported attribute type.
required AttrType type = 2;
// Supported attribute comments. It helps 3rd-party language generate
// doc-string.
required string comment = 3;
// If that attribute is generated, it means the Paddle third language
// binding has responsibility to fill that attribute. End-User should
// not set that attribute.
optional bool generated = 4 [ default = false ];
}
// Input or output message for 3rd-party language binding.
// It contains parameter name and its comments.
message VarProto {
// Input or output name in that op creation function.
// e.g. `cos(a, b, output, ...)`, "a", "b", "output" are names.
required string name = 1;
// The comment for that input. It helps 3rd-party language generate
// doc-string.
required string comment = 2;
// Is that input/output could be a list or not.
// If so, that Op should write a attributed named `input_format` or
// `output_format`.
//
// e.g.
// If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W`
// could be multiple, so the multiple of `X` and `W` is True, and OpDesc
// will hold a attribute of them.
//
// The Op desc of same fc could be
// {
// "type": "fc",
// "input": ["X1", "X2", "W1", "W2", "b"],
// "output": "fc.out",
// "attrs" : {
// "input_format": [0, 2, 4, 5]
// }
// }
//
optional bool multiple = 3 [ default = false ];
// It marks that output is a temporary output. That output is not used by
// user, but used by other op internally as input. If other op is not use
// that output, it could be optimized early.
//
// Attribute temporary_index will be set in OpDesc if there is some
// outputs are temporary.
//
// output = [ "xxx.out1", "xxx.tmp", "xxx.out2"],
// attrs = {
// "temporary_index": [1]
// }
optional bool temporary = 4 [ default = false ];
// The gradient of operator can be ignored immediately
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
// can be ignored for the future optimized on graph.
optional bool ignore_gradient = 6;
}
// Op protocol message for 3rd-party language binding.
// It contains all information for generating op creation method.
message OpProto {
// The input information to generate op creation method.
repeated VarProto inputs = 1;
// The output information to generate op creation method.
repeated VarProto outputs = 2;
// The attribute information to generate op creation method.
repeated AttrProto attrs = 3;
// The comments for that Op. It helps 3rd-party language generate
// doc-string. The whole documentation of that Op is generated by comment,
// inputs, outputs, attrs together.
required string comment = 4;
// The type of that Op.
required string type = 5;
}
#include <gtest/gtest.h>
#include <paddle/framework/op_proto.pb.h>
TEST(TestOpProto, ALL) {
paddle::framework::OpProto proto;
{
auto ipt = proto.mutable_inputs()->Add();
*ipt->mutable_name() = "a";
*ipt->mutable_comment() = "the one input of cosine op";
}
{
auto ipt = proto.mutable_inputs()->Add();
*ipt->mutable_name() = "b";
*ipt->mutable_comment() = "the other input of cosine op";
}
{
auto opt = proto.mutable_outputs()->Add();
*opt->mutable_name() = "output";
*opt->mutable_comment() = "the output of cosine op";
}
{
auto attr = proto.mutable_attrs()->Add();
*attr->mutable_name() = "scale";
attr->set_type(paddle::framework::AttrType::FLOAT);
*attr->mutable_comment() = "the scale attribute of cosine op";
}
proto.set_type("cos");
*proto.mutable_comment() = "cosine op, output = scale * cos(a, b)";
ASSERT_TRUE(proto.IsInitialized());
}
\ No newline at end of file
...@@ -17,5 +17,48 @@ limitations under the License. */ ...@@ -17,5 +17,48 @@ limitations under the License. */
#include <vector> #include <vector>
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
auto it = op_info_map().find(type);
PADDLE_ENFORCE(it != op_info_map().end(),
"Operator '%s' has not been registered.", type);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op);
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
}
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
VarNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
auto& var_names_in_proto = var.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
return ret_val;
}
std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
}
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,320 +17,104 @@ limitations under the License. */ ...@@ -17,320 +17,104 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <type_traits> #include <type_traits>
#include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
protected:
struct VariableBuilder {
VarProto* var_;
std::function<void()> on_multiple_;
std::function<void()> on_temporary_;
VariableBuilder& SetMultiple() {
var_->set_multiple(true);
on_multiple_();
return *this;
}
VariableBuilder& SetTemporary() {
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
var_->set_temporary(true);
on_temporary_();
return *this;
}
VariableBuilder& IgnoreGradient() {
var_->set_ignore_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name,
const std::string& comment) {
VarProto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
nullptr};
}
VariableBuilder AddOutput(const std::string& name,
const std::string& comment) {
VarProto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
[=] { this->SetHasTemporaryOutput(); }};
}
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
AttrProto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void SetHasMultiple(const std::string& in_out, bool* flag) {
if (!*flag) {
AddAttr<std::vector<int>>(in_out + "_format",
"The multiple index of " + in_out +
"\n"
R"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.
e.g.
input = ["a", "b", "c", "d", "e", "f"]
input_format = [0, 4, 5, 6]
means
The number of all input variables this op is six, and they are segmented into
three inputs.
The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC",
/*generated*/ true);
*flag = true;
}
}
void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
void SetHasMultipleOutput() {
SetHasMultiple("output", &has_multiple_output_);
}
void SetHasTemporaryOutput() {
if (!has_temporary_output_) {
AddAttr<std::vector<int>>("temporary_index",
R"DOC(The temporary index of output.
Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.
Add a mark to which output is temporary is helpful for future optimization.
)DOC",
/*generated*/ true)
.SetDefault(std::vector<int>());
has_temporary_output_ = true;
}
}
void CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
bool has_multiple_input_{false};
bool has_multiple_output_{false};
bool has_temporary_output_{false};
};
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OperatorBase*()>; using VarNameMap = OperatorBase::VarNameMap;
using VarIndexMap = std::unordered_map<std::string, int>; using OpCreator = std::function<OperatorBase*(
using VarNameList = std::vector<std::string>; const std::string& /*type*/, const VarNameMap& /*inputs*/,
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public: public:
template <typename OpType, typename ProtoMakerType> struct OpInfo {
static void RegisterOp(const std::string& op_type) { OpCreator creator_;
op_creators()[op_type] = [] { return new OpType; }; std::string grad_op_type_;
OpAttrChecker& op_checker = op_checkers()[op_type]; OpProto* proto_;
OpProto& op_proto = protos()[op_type]; OpAttrChecker* checker_;
auto maker = ProtoMakerType(&op_proto, &op_checker); };
maker.Validate();
op_proto.set_type(op_type);
PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
VarIndexMaps()[op_type].reset(new VarIndexMap());
auto& varmap = *VarIndexMaps()[op_type];
int idx = 0;
for (auto& var : op_proto.inputs()) {
varmap[var.name()] = idx++;
}
idx = 0;
for (auto& var : op_proto.outputs()) {
varmap[var.name()] = idx++;
}
}
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [] { return new GradOpType; };
grad_ops()[op_type] = grad_op_type;
}
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type);
auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);
GenerateTempVariableName(op);
{ template <typename OpType, typename ProtoMakerType, typename GradOpType>
auto var_index_it = VarIndexMaps().find(type); static void RegisterOp(const std::string& op_type,
if (var_index_it != VarIndexMaps().end()) { const std::string& grad_op_type) {
op->in_out_idxs_ = var_index_it->second; PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
} "'%s' is registered more than once.", op_type);
OpInfo op_info;
op_info.creator_ = [](const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs);
};
op_info.grad_op_type_ = grad_op_type;
if (std::type_index(typeid(ProtoMakerType)) !=
std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
op_info.proto_->set_type(op_type);
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
} else {
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
} }
op_info_map().insert(std::make_pair(op_type, op_info));
op->Init(); // register gradient op
return std::shared_ptr<OperatorBase>(op); if (!grad_op_type.empty()) {
} RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(inputs));
std::vector<std::string> outputs;
outputs.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(outputs));
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
} }
return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) { static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
PADDLE_ENFORCE(!op.IsNetOp(), const VarNameMap& inputs,
"Use framework::Backward to get backward ops"); const VarNameMap& outputs,
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op)); AttributeMap attrs);
grad_op->Init();
return grad_op;
}
static std::unordered_map<std::string, OpProto>& protos() { static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unordered_map<std::string, OpProto> protos_;
return protos_;
}
static std::unordered_map<std::string, std::string>& grad_ops() { static VarNameMap ConvertOpDescVarsToVarNameMap(
static std::unordered_map<std::string, std::string> grad_ops_; const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars);
return grad_ops_;
}
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>& static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
VarIndexMaps() {
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
return maps_;
}
static std::unordered_map<std::string, OpCreator>& op_creators() { static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, OpCreator> op_creators_; static std::unordered_map<std::string, const OpInfo> op_info_map_;
return op_creators_; return op_info_map_;
}
private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
}
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
if (outname == kTempVarName) {
outname += op->type_;
outname += "@";
outname += std::to_string(gUniqId.fetch_add(1));
}
}
} }
}; };
class Registrar { class Registrar {
public: public:
// In our design, various kinds of classes, e.g., operators and kernels, have // In our design, various kinds of classes, e.g., operators and kernels,
// their corresponding registry and registrar. The action of registration is // have their corresponding registry and registrar. The action of
// in the constructor of a global registrar variable, which, however, are not // registration is in the constructor of a global registrar variable, which,
// used in the code that calls package framework, and would be removed from // however, are not used in the code that calls package framework, and would
// the generated binary file by the linker. To avoid such removal, we add // be removed from the generated binary file by the linker. To avoid such
// Touch to all registrar classes and make USE_OP macros to call this // removal, we add Touch to all registrar classes and make USE_OP macros to
// method. So, as long as the callee code calls USE_OP, the global // call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker. // registrar variable won't be removed by the linker.
void Touch() {} void Touch() {}
}; };
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar { class OpRegistrar : public Registrar {
public: public:
explicit OpRegistrar(const char* op_type) { explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); OpRegistrar(const char* op_type, const char* grad_op_type) {
} OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
}; grad_op_type);
template <typename GradOpType>
class GradOpRegistrar : public Registrar {
public:
GradOpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
} }
}; };
...@@ -356,30 +140,30 @@ class OpKernelRegistrar : public Registrar { ...@@ -356,30 +140,30 @@ class OpKernelRegistrar : public Registrar {
/** /**
* Macro to register Operator. * Macro to register Operator.
*/ */
#define REGISTER_OP(op_type, op_class, op_maker_class) \ #define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ class _OpClass_##op_type##_ : public op_class { \
__op_registrar_##op_type##__(#op_type); \ public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
class _OpGradClass_##op_type##_ : public grad_op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \
}; \
static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \ int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \ __op_registrar_##op_type##__.Touch(); \
return 0; \ return 0; \
} }
/** #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
* Macro to register Gradient Operator. REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
*/
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
return 0; \
}
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
...@@ -395,14 +179,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -395,14 +179,6 @@ class OpKernelRegistrar : public Registrar {
return 0; \ return 0; \
} }
/**
* Macro to Forbid user register Gradient Operator.
*/
#define NO_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace")
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \ #define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
...@@ -410,7 +186,8 @@ class OpKernelRegistrar : public Registrar { ...@@ -410,7 +186,8 @@ class OpKernelRegistrar : public Registrar {
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel we will use and tell the compiler to * Macro to mark what Operator and Kernel
* we will use and tell the compiler to
* link them into target. * link them into target.
*/ */
#define USE_OP_ITSELF(op_type) \ #define USE_OP_ITSELF(op_type) \
...@@ -421,23 +198,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -421,23 +198,6 @@ class OpKernelRegistrar : public Registrar {
static int use_op_itself_##op_type##_ __attribute__((unused)) = \ static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type() TouchOpRegistrar_##op_type()
// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
// be compiled. `NO_GRAD` should be removed after all gradient ops are
// compeleted.
#define NO_GRAD
#ifndef NO_GRAD
#define USE_OP_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_gradient_##op_type, \
"USE_OP_GRADIENT must be called in global namespace"); \
extern int TouchOpGradientRegistrar_##op_type(); \
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
TouchOpGradientRegistrar_##op_type()
#else
#define USE_OP_GRADIENT(op_type)
#endif
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \ #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
...@@ -447,7 +207,8 @@ class OpKernelRegistrar : public Registrar { ...@@ -447,7 +207,8 @@ class OpKernelRegistrar : public Registrar {
__attribute__((unused)) = \ __attribute__((unused)) = \
TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
// TODO(fengjiayi): The following macros seems ugly, do we have better method? // TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
#ifdef PADDLE_ONLY_CPU #ifdef PADDLE_ONLY_CPU
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU) #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
...@@ -457,18 +218,13 @@ class OpKernelRegistrar : public Registrar { ...@@ -457,18 +218,13 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU) USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif #endif
#define USE_NO_GRAD_OP(op_type) \ #define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU);
#define USE_CPU_OP(op_type) \ #define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); \ USE_OP_KERNEL(op_type)
USE_OP_GRADIENT(op_type)
#define USE_OP(op_type) \
USE_NO_GRAD_OP(op_type); \
USE_OP_GRADIENT(op_type)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -7,8 +7,7 @@ namespace paddle { ...@@ -7,8 +7,7 @@ namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorBase { class CosineOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase) using OperatorBase::OperatorBase;
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
...@@ -29,8 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -29,8 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase { class MyTestOp : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase) using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
...@@ -40,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -40,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "input of cosine op").SetMultiple(); AddInput("input", "input of cosine op").AsDuplicable();
AddOutput("output", "output of cosine op").SetTemporary(); AddOutput("output", "output of cosine op").AsIntermediate();
auto my_checker = [](int i) { auto my_checker = [](int i) {
PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
}; };
...@@ -53,16 +51,24 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -53,16 +51,24 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP(cos_sim, paddle::framework::CosineOp, static void BuildVar(const std::string& param_name,
paddle::framework::CosineOpProtoAndCheckerMaker); std::initializer_list<const char*> arguments,
REGISTER_OP(my_test_op, paddle::framework::MyTestOp, paddle::framework::OpDesc::Var* var) {
paddle::framework::MyTestOpProtoAndCheckerMaker); var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
op_desc.add_inputs("aa"); BuildVar("input", {"aa"}, op_desc.add_inputs());
op_desc.add_outputs("bb"); BuildVar("output", {"bb"}, op_desc.add_outputs());
float scale = 3.3; float scale = 3.3;
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
...@@ -70,8 +76,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -70,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
std::shared_ptr<paddle::framework::OperatorBase> op = auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -82,8 +87,8 @@ TEST(OpRegistry, CreateOp) { ...@@ -82,8 +87,8 @@ TEST(OpRegistry, CreateOp) {
TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
op_desc.add_inputs("aa"); BuildVar("input", {"aa"}, op_desc.add_inputs());
op_desc.add_outputs("bb"); BuildVar("output", {"bb"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
...@@ -107,33 +112,23 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -107,33 +112,23 @@ TEST(OpRegistry, IllegalAttr) {
TEST(OpRegistry, DefaultValue) { TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
op_desc.add_inputs("aa"); BuildVar("input", {"aa"}, op_desc.add_inputs());
op_desc.add_outputs("bb"); BuildVar("output", {"bb"}, op_desc.add_outputs());
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
std::shared_ptr<paddle::framework::OperatorBase> op = auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
} }
static void SetInputFormat(paddle::framework::OpDesc* desc) {
auto attr = desc->add_attrs();
attr->set_name("input_format");
attr->set_type(paddle::framework::INTS);
attr->mutable_ints()->Add(0);
attr->mutable_ints()->Add(1);
}
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op"); op_desc.set_type("my_test_op");
op_desc.add_inputs("ii"); BuildVar("input", {"ii"}, op_desc.add_inputs());
op_desc.add_outputs("oo"); BuildVar("output", {"oo"}, op_desc.add_outputs());
SetInputFormat(&op_desc);
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
...@@ -173,7 +168,6 @@ TEST(OpRegistry, CustomChecker) { ...@@ -173,7 +168,6 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
SetInputFormat(&op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope; paddle::framework::Scope scope;
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -34,83 +34,172 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -34,83 +34,172 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif #endif
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, auto& ins = Inputs(name);
"Input Output Indices could not be nullptr"); PADDLE_ENFORCE_EQ(ins.size(), 1UL,
auto it = in_out_idxs_->find(name); "Op %s input %s should contain only one variable", type_,
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", name);
name); return ins[0];
if (attrs_.count("input_format") == 0) {
return inputs_.at((size_t)it->second);
} else {
const auto& input_format = GetAttr<std::vector<int>>("input_format");
int idx = input_format[it->second];
return inputs_.at((size_t)idx);
}
} }
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { const std::vector<std::string>& OperatorBase::Inputs(
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "IO Idx could not be nullptr"); const std::string& name) const {
auto input_format = GetAttr<std::vector<int>>("input_format"); auto it = inputs_.find(name);
auto offset = in_out_idxs_->at(name); PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <= name);
static_cast<int>(inputs_.size()), return it->second;
"Input Out Of Range");
return std::vector<std::string>{
inputs_.begin() + input_format.at(offset),
inputs_.begin() + input_format.at(offset + 1)};
} }
const std::string& OperatorBase::Output(const std::string& name) const { const std::string& OperatorBase::Output(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr"); auto& outs = Outputs(name);
auto it = in_out_idxs_->find(name); PADDLE_ENFORCE_EQ(outs.size(), 1UL,
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", "Op %s output %s should contain only one variable", type_,
name); name);
if (attrs_.count("output_format") == 0) { return outs[0];
return outputs_.at((size_t)it->second);
} else {
const auto& output_format = GetAttr<std::vector<int>>("output_format");
int idx = output_format[it->second];
return outputs_.at((size_t)idx);
}
} }
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { const std::vector<std::string>& OperatorBase::Outputs(
PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr"); const std::string& name) const {
auto output_format = GetAttr<std::vector<int>>("output_format"); auto it = outputs_.find(name);
auto offset = in_out_idxs_->at(name); PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <= name);
static_cast<int>(outputs_.size()), return it->second;
"Output Out of Range");
return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)};
} }
std::string OperatorBase::DebugString() const { std::string OperatorBase::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << "Op(" << type_ << "), inputs:("; ss << "Op(" << type_ << "), inputs:{";
for (size_t i = 0; i < inputs_.size(); ++i) { for (auto it = inputs_.begin(); it != inputs_.end();) {
ss << inputs_[i]; auto& input = *it;
if (i != inputs_.size() - 1) { ss << input.first << "[";
for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (i != input.second.size() - 1) {
ss << ", ";
}
}
ss << "]";
++it;
if (it != inputs_.end()) {
ss << ", "; ss << ", ";
} }
} }
ss << "), outputs:("; ss << "}, outputs:{";
for (size_t i = 0; i < outputs_.size(); ++i) { for (auto it = outputs_.begin(); it != outputs_.end();) {
ss << outputs_[i]; auto& output = *it;
if (i != outputs_.size() - 1) { ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (i != output.second.size() - 1) {
ss << ", ";
}
}
ss << "]";
++it;
if (it != outputs_.end()) {
ss << ", "; ss << ", ";
} }
} }
ss << ")."; ss << "}.";
return ss.str(); return ss.str();
} }
void OperatorBase::Rename(const std::string& old_name, void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) { const std::string& new_name) {
std::replace(inputs_.begin(), inputs_.end(), old_name, new_name); for (auto& input : inputs_) {
std::replace(outputs_.begin(), outputs_.end(), old_name, new_name); std::replace(input.second.begin(), input.second.end(), old_name, new_name);
}
for (auto& output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name,
new_name);
}
}
OperatorBase::OperatorBase(const std::string& type,
const OperatorBase::VarNameMap& inputs,
const OperatorBase::VarNameMap& outputs,
const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
// push all outputs into ret_val
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
auto it = OpRegistry::op_info_map().find(type_);
PADDLE_ENFORCE(
it != OpRegistry::op_info_map().end(),
"Operator %s not registered, cannot figure out intermediate outputs",
type_);
PADDLE_ENFORCE(
it->second.proto_ != nullptr,
"Operator %s has no OpProto, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs
for (auto& o : it->second.proto_->outputs()) {
// ignore all intermediate output
if (o.intermediate()) continue;
auto out = outputs_.find(o.name());
if (out != outputs_.end()) {
ret_val.reserve(ret_val.size() + out->second.size());
ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
}
}
return ret_val;
}
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
} }
} // namespace framework } // namespace framework
......
...@@ -20,8 +20,7 @@ limitations under the License. */ ...@@ -20,8 +20,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
...@@ -63,16 +62,10 @@ class ExecutionContext; ...@@ -63,16 +62,10 @@ class ExecutionContext;
*/ */
class OperatorBase { class OperatorBase {
public: public:
OperatorBase() {} // TODO(yi): This constructor is to be removed. using VarNameMap = std::map<std::string, std::vector<std::string>>;
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, OperatorBase(const std::string& type, const VarNameMap& inputs,
const AttributeMap& attrs, const VarNameMap& outputs, const AttributeMap& attrs);
std::unordered_map<std::string, int>* in_out_idxs)
: type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
in_out_idxs_(in_out_idxs) {}
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
...@@ -85,10 +78,6 @@ class OperatorBase { ...@@ -85,10 +78,6 @@ class OperatorBase {
virtual std::string DebugString() const; virtual std::string DebugString() const;
/// Init will be called after CreateOperator, you can put some initialization
/// logic here.
virtual void Init() {}
/// InferShape infer the size of Variables used by this Operator with /// InferShape infer the size of Variables used by this Operator with
/// information inside scope /// information inside scope
virtual void InferShape(const Scope& scope) const = 0; virtual void InferShape(const Scope& scope) const = 0;
...@@ -104,39 +93,134 @@ class OperatorBase { ...@@ -104,39 +93,134 @@ class OperatorBase {
/// rename inputs outputs name /// rename inputs outputs name
void Rename(const std::string& old_name, const std::string& new_name); void Rename(const std::string& old_name, const std::string& new_name);
const VarNameMap& Inputs() const { return inputs_; }
const VarNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto` //! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy. const std::vector<std::string>& Inputs(const std::string& name) const;
std::vector<std::string> Inputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto` //! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const; const std::string& Output(const std::string& name) const;
//! Get an output which has multiple variables. //! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const; const std::vector<std::string>& Outputs(const std::string& name) const;
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
const std::string Type() const { return type_; } const std::string& Type() const { return type_; }
const std::vector<std::string> Inputs() const { return inputs_; } void SetType(const std::string& type) { type_ = type; }
const std::vector<std::string> Outputs() const { return outputs_; }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
const std::unordered_map<std::string, int>* InOutIdx() const {
return in_out_idxs_.get();
}
public: // Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
protected:
std::string type_; std::string type_;
// NOTE: in case of OpGrad, inputs_ contains: // NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs) // I (Inputs)opear
// O (Outputs) // O (Outputs)
// OG (Output Gradients) // OG (Output Gradients)
std::vector<std::string> inputs_; VarNameMap inputs_;
// NOTE: in case of OpGrad, outputs_ contains // NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients) // IG (Inputs Gradients)
std::vector<std::string> outputs_; VarNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
// store the arguments' offset described in op_desc. };
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(CLS) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new CLS(*this)); \
}
// Macro for define a default constructor for Operator.
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this));
}
};
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
}; };
class InferShapeContext { class InferShapeContext {
...@@ -144,16 +228,12 @@ class InferShapeContext { ...@@ -144,16 +228,12 @@ class InferShapeContext {
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {} : op_(op), scope_(scope) {}
size_t InputSize() const { return op_.inputs_.size(); } size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
size_t OutputSize() const { return op_.outputs_.size(); }
const Variable* InputVar(const size_t index) const {
return scope_.FindVar(op_.inputs_.at(index));
} }
Variable* OutputVar(const size_t index) const { size_t OutputSize(const std::string& name) const {
return scope_.FindVar(op_.outputs_.at(index)); return op_.Outputs(name).size();
} }
const Variable* InputVar(const std::string& name) const { const Variable* InputVar(const std::string& name) const {
...@@ -185,27 +265,9 @@ class InferShapeContext { ...@@ -185,27 +265,9 @@ class InferShapeContext {
return res; return res;
} }
template <typename T>
const T* Input(const size_t index) const {
auto var = InputVar(index);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
return &var->Get<T>();
}
template <typename T>
T* Output(const size_t index) const {
auto var = OutputVar(index);
PADDLE_ENFORCE_NOT_NULL(
var,
"Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope",
index, op_.outputs_[index]);
return var->GetMutable<T>();
}
template <typename T> template <typename T>
const T* Input(const std::string& name) const { const T* Input(const std::string& name) const {
auto var = InputVar(name); auto* var = InputVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
return &var->Get<T>(); return &var->Get<T>();
} }
...@@ -242,7 +304,7 @@ class InferShapeContext { ...@@ -242,7 +304,7 @@ class InferShapeContext {
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, "MultiOutput(%s:%s) should not be nullptr", name, var, "MultiOutput(%s:%s) should not be nullptr.", name,
sub_name); sub_name);
return var->GetMutable<T>(); return var->GetMutable<T>();
}); });
...@@ -281,6 +343,10 @@ class ExecutionContext : public InferShapeContext { ...@@ -281,6 +343,10 @@ class ExecutionContext : public InferShapeContext {
platform::Place GetPlace() const { return device_context_->GetPlace(); } platform::Place GetPlace() const { return device_context_->GetPlace(); }
const platform::DeviceContext* device_context() const {
return device_context_;
}
const platform::DeviceContext* device_context_; const platform::DeviceContext* device_context_;
}; };
...@@ -300,14 +366,6 @@ class OpKernel { ...@@ -300,14 +366,6 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
OperatorWithKernel() {} // TODO(yi): This constructor is to be removed.
OperatorWithKernel(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
...@@ -331,6 +389,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -331,6 +389,10 @@ class OperatorWithKernel : public OperatorBase {
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
OperatorWithKernel(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const Scope& scope) const override { void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope)); InferShape(InferShapeContext(*this, scope));
} }
...@@ -357,15 +419,5 @@ class OperatorWithKernel : public OperatorBase { ...@@ -357,15 +419,5 @@ class OperatorWithKernel : public OperatorBase {
virtual void InferShape(const InferShapeContext& ctx) const = 0; virtual void InferShape(const InferShapeContext& ctx) const = 0;
}; };
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const std::vector<std::string>& inputs, \
const std::vector<std::string>& outputs, \
const ::paddle::framework::AttributeMap& attrs, \
std::unordered_map<std::string, int>* in_out_idxs) \
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -23,22 +23,22 @@ static int op_run_num = 0; ...@@ -23,22 +23,22 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase) OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs)
void Init() override { x = 1; } : OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
op_run_num++; ++op_run_num;
ASSERT_EQ((int)inputs_.size(), 1); ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
ASSERT_EQ((int)outputs_.size(), 1); ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr); ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
ASSERT_EQ(x, 1); ASSERT_EQ(x, 1);
ASSERT_NE(scope.FindVar(outputs_[0]), nullptr); ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
} }
public: public:
float x = 0; int x{0};
}; };
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...@@ -56,14 +56,25 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { ...@@ -56,14 +56,25 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, static void BuildVar(const std::string& param_name,
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
*var->mutable_arguments()->Add() = arg_name;
}
}
REGISTER_OP_WITHOUT_GRADIENT(
test_operator, paddle::framework::OpWithoutKernelTest,
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) { TEST(OperatorBase, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("test_operator"); op_desc.set_type("test_operator");
*op_desc.mutable_inputs()->Add() = "IN1"; BuildVar("input", {"IN1"}, op_desc.add_inputs());
*op_desc.mutable_outputs()->Add() = "OUT1"; BuildVar("output", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
...@@ -100,7 +111,8 @@ static int cpu_kernel_run_num = 0; ...@@ -100,7 +111,8 @@ static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
public: public:
DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel) using OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override {} void InferShape(const framework::InferShapeContext& ctx) const override {}
}; };
...@@ -117,35 +129,15 @@ class CPUKernelTest : public OpKernel { ...@@ -117,35 +129,15 @@ class CPUKernelTest : public OpKernel {
} }
}; };
// multiple inputs test
class OperatorMultiInputsTest : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase)
void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr);
ASSERT_EQ(x, 1);
ASSERT_NE(scope.FindVar(outputs_[0]), nullptr);
ASSERT_EQ(Input("x"), "IN1");
ASSERT_EQ(Input("y"), "OUT1");
}
public:
float x = 0;
};
class OpKernelTestMultiInputsProtoAndCheckerMaker class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker { : public OpProtoAndCheckerMaker {
public: public:
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
OpAttrChecker* op_checker) OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("xs", "inputs of test op").SetMultiple(); AddInput("xs", "inputs of test op").AsDuplicable();
AddInput("k", "input of test op"); AddInput("k", "input of test op");
AddOutput("ys", "outputs of test op").SetMultiple(); AddOutput("ys", "outputs of test op").AsDuplicable();
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.LargerThan(0.0); .LargerThan(0.0);
...@@ -193,8 +185,9 @@ class CPUKernalMultiInputsTest : public OpKernel { ...@@ -193,8 +185,9 @@ class CPUKernalMultiInputsTest : public OpKernel {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpKernelTestProtoAndCheckerMaker); op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>); paddle::framework::CPUKernelTest<float, float>);
...@@ -202,8 +195,9 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, ...@@ -202,8 +195,9 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
TEST(OpKernel, all) { TEST(OpKernel, all) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("op_with_kernel"); op_desc.set_type("op_with_kernel");
*op_desc.mutable_inputs()->Add() = "IN1"; BuildVar("x", {"IN1"}, op_desc.add_inputs());
*op_desc.mutable_outputs()->Add() = "OUT1"; BuildVar("y", {"OUT1"}, op_desc.add_outputs());
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
...@@ -218,8 +212,9 @@ TEST(OpKernel, all) { ...@@ -218,8 +212,9 @@ TEST(OpKernel, all) {
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
} }
REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
paddle::framework::CPUKernalMultiInputsTest); paddle::framework::CPUKernalMultiInputsTest);
...@@ -229,32 +224,15 @@ TEST(OpKernel, multi_inputs) { ...@@ -229,32 +224,15 @@ TEST(OpKernel, multi_inputs) {
OpDesc op_desc; OpDesc op_desc;
op_desc.set_type("op_multi_inputs_with_kernel"); op_desc.set_type("op_multi_inputs_with_kernel");
*op_desc.mutable_inputs()->Add() = "x0"; BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
*op_desc.mutable_inputs()->Add() = "x1"; BuildVar("k", {"k0"}, op_desc.add_inputs());
*op_desc.mutable_inputs()->Add() = "x2"; BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());
*op_desc.mutable_inputs()->Add() = "k0";
*op_desc.mutable_outputs()->Add() = "y0";
*op_desc.mutable_outputs()->Add() = "y1";
auto attr = op_desc.mutable_attrs()->Add(); auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale"); attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.14); attr->set_f(3.14);
auto attr0 = op_desc.mutable_attrs()->Add();
attr0->set_name("input_format");
attr0->set_type(paddle::framework::AttrType::INTS);
auto input_format = attr0->mutable_ints();
input_format->Add(0); // x0
input_format->Add(3); // k
input_format->Add(4); // end
auto attr1 = op_desc.mutable_attrs()->Add();
attr1->set_name("output_format");
attr1->set_type(paddle::framework::AttrType::INTS);
auto output_format = attr1->mutable_ints();
output_format->Add(0); // y0
output_format->Add(2); // y1
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
scope.NewVar("x0")->GetMutable<Tensor>(); scope.NewVar("x0")->GetMutable<Tensor>();
...@@ -267,3 +245,21 @@ TEST(OpKernel, multi_inputs) { ...@@ -267,3 +245,21 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
} }
class OperatorClone : public paddle::framework::OperatorBase {
public:
DEFINE_OP_CLONE_METHOD(OperatorClone);
OperatorClone(const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const paddle::framework::Scope& scope) const override {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {}
};
TEST(Operator, Clone) {
OperatorClone a("ABC", {}, {}, {});
auto b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
}
\ No newline at end of file
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor_py.h" #include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "paddle/string/to_string.h" #include "paddle/string/to_string.h"
...@@ -30,8 +31,8 @@ limitations under the License. */ ...@@ -30,8 +31,8 @@ limitations under the License. */
namespace py = pybind11; namespace py = pybind11;
USE_OP(add_two); USE_OP(add_two);
USE_CPU_OP(onehot_cross_entropy); USE_OP(onehot_cross_entropy);
USE_NO_GRAD_OP(sgd); USE_OP(sgd);
USE_OP(mul); USE_OP(mul);
USE_OP(mean); USE_OP(mean);
USE_OP(sigmoid); USE_OP(sigmoid);
...@@ -47,41 +48,6 @@ namespace framework { ...@@ -47,41 +48,6 @@ namespace framework {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("type",
[](const typename ClassType::type &op) -> std::string {
return op.type_;
})
.def("outputs",
[](const typename ClassType::type &op) -> std::vector<std::string> {
return op.outputs_;
})
.def("inputs",
[](const typename ClassType::type &op) -> std::vector<std::string> {
return op.inputs_;
})
.def("support_gpu", &ClassType::type::SupportGPU)
.def("temp_outputs",
[](const typename ClassType::type &op) -> std::vector<std::string> {
auto iter = op.attrs_.find("temporary_index");
std::vector<std::string> ret;
if (iter == op.attrs_.end()) {
return ret;
} else {
auto tmp_idx = boost::get<std::vector<int>>(iter->second);
ret.reserve(tmp_idx.size());
for (auto &index : tmp_idx) {
ret.push_back(op.outputs_.at(index));
}
return ret;
}
})
.def("__str__", &ClassType::type::DebugString);
}
static size_t UniqueIntegerGenerator() { static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator; static std::atomic<size_t> generator;
return generator.fetch_add(1); return generator.fetch_add(1);
...@@ -172,13 +138,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -172,13 +138,16 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> { m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto &protos = OpRegistry::protos(); auto &op_info_map = OpRegistry::op_info_map();
std::vector<py::bytes> ret_values; std::vector<py::bytes> ret_values;
for (auto it = protos.begin(); it != protos.end(); ++it) { for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
PADDLE_ENFORCE(it->second.IsInitialized(), const OpProto *proto = it->second.proto_;
"OpProto must all be initialized"); if (proto == nullptr) {
continue;
}
PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
std::string str; std::string str;
PADDLE_ENFORCE(it->second.SerializeToString(&str), PADDLE_ENFORCE(proto->SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle."); "Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str)); ret_values.push_back(py::bytes(str));
} }
...@@ -215,47 +184,69 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -215,47 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>()) .def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>); .def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base( py::class_<OperatorBase>(m, "Operator")
m, "Operator"); .def_static("create",
[](py::bytes protobin) {
operator_base.def_static("create", [](py::bytes protobin) { OpDesc desc;
OpDesc desc; PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), "Cannot parse user input to OpDesc");
"Cannot parse user input to OpDesc"); PADDLE_ENFORCE(desc.IsInitialized(),
PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s",
"User OpDesc is not initialized, reason %s", desc.InitializationErrorString());
desc.InitializationErrorString()); return OpRegistry::CreateOp(desc);
return OpRegistry::CreateOp(desc); })
}); .def("backward",
[](const OperatorBase &forwardOp,
operator_base.def("backward", const std::unordered_set<std::string> &no_grad_vars) {
[](const OperatorBase &forwardOp, return Backward(forwardOp, no_grad_vars).release();
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars);
});
ExposeOperator(operator_base);
py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &operators::NetOp::AddOp)
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
}) })
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
[](const OperatorBase &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
.def("support_gpu", &OperatorBase::SupportGPU);
py::class_<operators::NetOp, OperatorBase>(m, "Net")
.def_static("create",
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
.def("append_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AppendOp(op); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) { .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp(); self->CompleteAddOp();
}); });
ExposeOperator(net); // recurrent_op
py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
});
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
......
...@@ -105,6 +105,8 @@ class Tensor { ...@@ -105,6 +105,8 @@ class Tensor {
template <typename T> template <typename T>
inline Tensor Slice(const int& begin_idx, const int& end_idx) const; inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const { return holder_->place(); }
private: private:
template <typename T> template <typename T>
inline void check_memory_size() const; inline void check_memory_size() const;
......
...@@ -4,6 +4,10 @@ file(GLOB cpp_files . *Op.cpp) ...@@ -4,6 +4,10 @@ file(GLOB cpp_files . *Op.cpp)
list(APPEND h_files Function.h) list(APPEND h_files Function.h)
list(APPEND cpp_files Function.cpp) list(APPEND cpp_files Function.cpp)
list(APPEND cpp_files BufferArg.cpp) list(APPEND cpp_files BufferArg.cpp)
list(APPEND cpp_files GemmFunctor.cpp)
if(USE_EIGEN_FOR_BLAS)
list(APPEND cpp_files EigenGemm.cpp)
endif(USE_EIGEN_FOR_BLAS)
if(WITH_GPU) if(WITH_GPU)
file(GLOB cu_files . *OpGpu.cu) file(GLOB cu_files . *OpGpu.cu)
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "ConvOp.h" #include "ConvOp.h"
#include "GemmFunctor.h"
namespace paddle { namespace paddle {
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/BaseMatrix.h" #include "paddle/math/BaseMatrix.h"
namespace paddle { namespace paddle {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
template <class T>
struct EigenBlasGemm {
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, int>,
Eigen::Aligned>
Matrix;
static void compute(const bool transA,
const bool transB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
Eigen::array<int, 2> sizeA;
if (transA) {
sizeA[0] = K;
sizeA[1] = M;
CHECK_EQ(M, lda);
} else {
sizeA[0] = M;
sizeA[1] = K;
CHECK_EQ(K, lda);
}
Eigen::array<int, 2> sizeB;
if (transB) {
sizeB[0] = N;
sizeB[1] = K;
CHECK_EQ(K, ldb);
} else {
sizeB[0] = K;
sizeB[1] = N;
CHECK_EQ(N, ldb);
}
Eigen::array<int, 2> sizeC;
sizeC[0] = M;
sizeC[1] = N;
CHECK_EQ(N, ldc);
const Matrix a(const_cast<T*>(A), sizeA);
const Matrix b(const_cast<T*>(B), sizeB);
Matrix c(C, sizeC);
typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims;
dims[0] = DimPair(1, 0);
dims[0].first = transA ? 0 : 1;
dims[0].second = transB ? 1 : 0;
Eigen::DefaultDevice device;
if (alpha == T(1) && beta == T(0)) {
c.device(device) = a.contract(b, dims);
} else if (alpha == T(1) && beta == T(1)) {
c.device(device) += a.contract(b, dims);
} else {
c.device(device) = alpha * a.contract(b, dims) + beta * c;
}
}
};
#ifdef PADDLE_TYPE_DOUBLE
template class EigenBlasGemm<double>;
#else
template class EigenBlasGemm<float>;
#endif
} // namespace paddle
...@@ -85,7 +85,6 @@ public: ...@@ -85,7 +85,6 @@ public:
} }
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -108,19 +107,19 @@ public: ...@@ -108,19 +107,19 @@ public:
int M = outputChannels / groups_; int M = outputChannels / groups_;
int N = outputHeight * outputWidth; int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth; int K = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans, BlasGemm<Device, real>::compute(false,
CblasNoTrans, false,
M, M,
N, N,
K, K,
1.0f, 1.0f,
filterData + g * filterOffset, filterData + g * filterOffset,
K, K,
colData, colData,
N, N,
beta, beta,
outputData + g * outputOffset, outputData + g * outputOffset,
N); N);
} }
inputData += inputChannels * inputHeight * inputWidth; inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
...@@ -188,8 +187,6 @@ public: ...@@ -188,8 +187,6 @@ public:
} }
Col2ImFunctor<kCFO, Device, real> col2im; Col2ImFunctor<kCFO, Device, real> col2im;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -205,19 +202,19 @@ public: ...@@ -205,19 +202,19 @@ public:
colData = inputGrad + g * inputOffset; colData = inputGrad + g * inputOffset;
scale = 1.0f; scale = 1.0f;
} }
gemm(CblasTrans, BlasGemm<Device, real>::compute(true,
CblasNoTrans, false,
M, M,
N, N,
K, K,
1.0f, 1.0f,
filterData + g * filterOffset, filterData + g * filterOffset,
M, M,
outputGrad + g * outputOffset, outputGrad + g * outputOffset,
N, N,
scale, scale,
colData, colData,
N); N);
if (needIm2col) { if (needIm2col) {
col2im(inputGrad + g * inputOffset, col2im(inputGrad + g * inputOffset,
imShape, imShape,
...@@ -299,7 +296,6 @@ public: ...@@ -299,7 +296,6 @@ public:
} }
Im2ColFunctor<kCFO, Device, real> im2col; Im2ColFunctor<kCFO, Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = imShape.getElements(); size_t inputOffset = imShape.getElements();
size_t outputOffset = size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth; (outputChannels / groups_) * outputHeight * outputWidth;
...@@ -321,19 +317,19 @@ public: ...@@ -321,19 +317,19 @@ public:
int M = outputChannels / groups_; int M = outputChannels / groups_;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
int N = inputChannels / groups_ * filterHeight * filterWidth; int N = inputChannels / groups_ * filterHeight * filterWidth;
gemm(CblasNoTrans, BlasGemm<Device, real>::compute(false,
CblasTrans, true,
M, M,
N, N,
K, K,
1.0f, 1.0f,
outputGrad + g * outputOffset, outputGrad + g * outputOffset,
K, K,
colData, colData,
K, K,
i == 0 ? beta : 1.0f, i == 0 ? beta : 1.0f,
filterGrad + g * filterOffset, filterGrad + g * filterOffset,
N); N);
} }
inputData += inputChannels * inputHeight * inputWidth; inputData += inputChannels * inputHeight * inputWidth;
outputGrad += outputChannels * outputHeight * outputWidth; outputGrad += outputChannels * outputHeight * outputWidth;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "GemmFunctor.h"
#include "paddle/math/MathFunctions.h"
namespace paddle {
template <class T>
struct BlasGemm<DEVICE_TYPE_CPU, T> {
static void compute(const bool transA,
const bool transB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
#ifdef PADDLE_USE_EIGEN_FOR_BLAS
EigenBlasGemm<T>::compute(
transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
#else
gemm<T>(transA == false ? CblasNoTrans : CblasTrans,
transB == false ? CblasNoTrans : CblasTrans,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
#endif
}
};
template <class T>
struct BlasGemm<DEVICE_TYPE_GPU, T> {
static void compute(const bool transA,
const bool transB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
hl_matrix_mul((T*)A,
transA == false ? HPPL_OP_N : HPPL_OP_T,
(T*)B,
transB == false ? HPPL_OP_N : HPPL_OP_T,
C,
M,
N,
K,
alpha,
beta,
lda,
ldb,
ldc);
}
};
template struct BlasGemm<DEVICE_TYPE_CPU, real>;
template struct BlasGemm<DEVICE_TYPE_GPU, real>;
} // namespace paddle
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/math/MathFunctions.h" #include "TensorType.h"
namespace paddle { namespace paddle {
...@@ -24,73 +24,42 @@ namespace paddle { ...@@ -24,73 +24,42 @@ namespace paddle {
// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul // of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul
// interface. // interface.
template <DeviceType Device, class T> template <DeviceType Device, class T>
class GemmFunctor { struct BlasGemm {
public: static void compute(const bool transA,
void operator()(const CBLAS_TRANSPOSE transA, const bool transB,
const CBLAS_TRANSPOSE TransB, const int M,
const int M, const int N,
const int N, const int K,
const int K, const T alpha,
const T alpha, const T* A,
const T* A, const int lda,
const int lda, const T* B,
const T* B, const int ldb,
const int ldb, const T beta,
const T beta, T* C,
T* C, const int ldc);
const int ldc);
}; };
// TODO(hedaoyuan): Since the definition of the real type in the Paddle
// conflicts with the Eigen library, so compile the Eigen code can not
// include the Paddle header file. And need an EigenBlasGemm template class
// that does not contain the DeviceType parameter.
// I will fix this problem and merge BlasGemm and EigenBlasGemm into one.
template <class T> template <class T>
class GemmFunctor<DEVICE_TYPE_CPU, T> { struct EigenBlasGemm {
public: static void compute(const bool transA,
void operator()(const CBLAS_TRANSPOSE transA, const bool transB,
const CBLAS_TRANSPOSE TransB, const int M,
const int M, const int N,
const int N, const int K,
const int K, const T alpha,
const T alpha, const T* A,
const T* A, const int lda,
const int lda, const T* B,
const T* B, const int ldb,
const int ldb, const T beta,
const T beta, T* C,
T* C, const int ldc);
const int ldc) {
gemm<T>(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}
};
template <class T>
class GemmFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE TransB,
const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
hl_matrix_mul((T*)A,
transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
(T*)B,
TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
C,
M,
N,
K,
alpha,
beta,
lda,
ldb,
ldc);
}
}; };
} // namespace paddle } // namespace paddle
...@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) { ...@@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector<Argument>& inArgs) {
auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>( auto mat = dynamic_cast<SparsePrefetchRowCpuMatrix*>(
para->getMat(PARAMETER_VALUE).get()); para->getMat(PARAMETER_VALUE).get());
para->clearGradient(); para->clearGradient();
mat->clearIndices(); if (mat) mat->clearIndices();
} }
} }
} }
......
...@@ -184,7 +184,7 @@ public: ...@@ -184,7 +184,7 @@ public:
} }
void backward(const UpdateCallback& callback) override { void backward(const UpdateCallback& callback) override {
if (biases_) { if (biases_ && biases_->getWGrad()) {
backwardActivation(); backwardActivation();
biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
biases_->getParameterPtr()->incUpdate(callback); biases_->getParameterPtr()->incUpdate(callback);
......
...@@ -57,11 +57,14 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, ...@@ -57,11 +57,14 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
} }
void MKLDNNFcLayer::convertWeightsFromPaddle() { void MKLDNNFcLayer::convertWeightsFromPaddle() {
if (FLAGS_use_mkldnn_wgt) { if (hasInitedWgt_) {
return; return;
} }
if (hasInitedWgt_) { // TODO(TJ): dst format should get from wgtVal_
int dstFmt = PARAM_FORMAT_MKLDNN_OI;
int srcFmt = weight_->getParameterPtr()->getHeaderFormat();
if (srcFmt == dstFmt) {
return; return;
} }
...@@ -78,6 +81,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { ...@@ -78,6 +81,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() {
MatrixPtr paddleWgtT; MatrixPtr paddleWgtT;
paddleWgt->transpose(paddleWgtT, true); paddleWgt->transpose(paddleWgtT, true);
weight_->getW()->copyFrom(*paddleWgtT); weight_->getW()->copyFrom(*paddleWgtT);
weight_->getParameterPtr()->setHeaderFormat(dstFmt);
hasInitedWgt_ = true; hasInitedWgt_ = true;
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Layer.h"
namespace paddle {
/**
* A layer applies a linear transformation to each element in each row of
* the input matrix. For each element, the layer first re-scale it and then
* adds a bias to it.
*
* \f[
* y = wx + b
* \f]
*
* Here, w is the scale and b is the bias. Both w and b are trainable scalars.
*
*/
class ScaleShiftLayer : public Layer {
protected:
std::unique_ptr<Weight> scale_;
std::unique_ptr<Weight> offset_;
public:
explicit ScaleShiftLayer(const LayerConfig& config) : Layer(config) {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
};
REGISTER_LAYER(scale_shift, ScaleShiftLayer);
bool ScaleShiftLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
CHECK_EQ(inputLayers_.size(), 1U);
scale_.reset(new Weight(1, 1, parameters_[0]));
if (biasParameter_.get() != NULL) {
offset_ = std::unique_ptr<Weight>(new Weight(1, 1, biasParameter_));
}
return true;
}
void ScaleShiftLayer::forward(PassType passType) {
Layer::forward(passType);
MatrixPtr inV = getInputValue(0);
resetOutput(inV->getHeight(), inV->getWidth());
MatrixPtr outV = getOutputValue();
real scaleValue = scale_->getW()->getElement(0, 0);
outV->mulScalar(*inV, scaleValue);
if (offset_) {
real offsetValue = offset_->getW()->getElement(0, 0);
outV->add(offsetValue);
}
}
void ScaleShiftLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV = getInputValue(0);
MatrixPtr inG = getInputGrad(0);
MatrixPtr outV = getOutputValue();
MatrixPtr outG = getOutputGrad();
/* Calculate the parameter gradient for the current layer */
if (scale_->getWGrad()) {
MatrixPtr rowSumMtx;
Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_);
// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij}
rowSumMtx->sumOfProducts(
/* b= */ *inV, /* c= */ *outG, /* scaleSum= */ 1, /* scaleDest= */ 0.);
// this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji}
scale_->getWGrad()->sumCols(
/* b= */ *rowSumMtx, /* scaleSum= */ 1., /* scaleDest= */ 1.);
scale_->getParameterPtr()->incUpdate(callback);
}
if (offset_ && offset_->getWGrad()) {
MatrixPtr rowSumMtx;
Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_);
rowSumMtx->sumRows(*outG, 1., 0.);
offset_->getWGrad()->sumCols(*rowSumMtx, 1., 1.);
offset_->getParameterPtr()->incUpdate(callback);
}
/* Calculate the input layers error */
if (inG) {
real scaleValue = scale_->getW()->getElement(0, 0);
inG->add(*outG, scaleValue);
}
}
} // namespace paddle
...@@ -330,9 +330,7 @@ void MKLDNNTester::run(const TestConfig& dnn, ...@@ -330,9 +330,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
log_ = log; log_ = log;
lvl_ = level; lvl_ = level;
// Firstly test FLAGS_use_mkldnn_wgt = false // Firstly test mkldnn init from PARAM_FORMAT_ORIGINAL weight
FLAGS_use_mkldnn_wgt = false;
// reset and run once
reset(dnn, ref, batchSize); reset(dnn, ref, batchSize);
randomWgtDatas(); randomWgtDatas();
clearWgtDiffs(); clearWgtDiffs();
...@@ -342,17 +340,32 @@ void MKLDNNTester::run(const TestConfig& dnn, ...@@ -342,17 +340,32 @@ void MKLDNNTester::run(const TestConfig& dnn,
runOnce(); runOnce();
} }
// Then test FLAGS_use_mkldnn_wgt = true if (parameters_[DNN].empty()) {
FLAGS_use_mkldnn_wgt = true; // has no paramters
// after run once the mkldnn weight has been stored in dnnlayer return;
}
// After run some iterations, the mkldnn weight has been stored in dnnLayer
// and we can also get the mkldnn weight parameter header format.
// Weight parameter should always be index 0 (and bias index 1).
// TODO(TJ): should also consider mean and var format when batchnorm ready
int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat();
int refWgtFmt = parameters_[REF][0]->getHeaderFormat();
if (dnnWgtFmt == refWgtFmt) {
// weight format are equal, so no need check more
return;
}
// then save the weights and restart again // then save the weights and restart again
vector<VectorPtr> dnnWgts, refWgts; vector<VectorPtr> dnnWgts, refWgts;
CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size());
saveWgt(parameters_[DNN], dnnWgts); saveWgt(parameters_[DNN], dnnWgts);
saveWgt(parameters_[REF], refWgts); saveWgt(parameters_[REF], refWgts);
// restart again with flag true // restart again with dnn weight format
reset(dnn, ref, batchSize); reset(dnn, ref, batchSize);
// TODO(TJ): should also considerate mean and var format when batchnorm ready
parameters_[DNN][0]->setHeaderFormat(dnnWgtFmt);
// restore wgt // restore wgt
restoreWgt(dnnWgts, parameters_[DNN]); restoreWgt(dnnWgts, parameters_[DNN]);
......
...@@ -108,7 +108,7 @@ private: ...@@ -108,7 +108,7 @@ private:
* if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the * if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the
* max(diff/ref) * max(diff/ref)
* else return sum(abs(a-b)) / sum(abs(b)) * else return sum(abs(a-b)) / sum(abs(b))
* The return value should smaller than eps when passing. * The return value should be smaller than eps when passing.
*/ */
double getDelta(const real* d1, double getDelta(const real* d1,
const real* d2, const real* d2,
......
...@@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) { ...@@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) {
} }
} }
TEST(Layer, ScaleShiftLayer) {
const size_t batchSize = 16;
const size_t size = 32;
TestConfig config;
config.layerConfig.set_type("scale_shift");
config.layerConfig.set_size(size);
config.biasSize = 1;
config.inputDefs.push_back(
{INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false);
}
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
initMain(argc, argv); initMain(argc, argv);
......
...@@ -269,7 +269,8 @@ TEST(Compare, img_conv2) { ...@@ -269,7 +269,8 @@ TEST(Compare, img_conv2) {
bool useGpu = FLAGS_use_gpu; bool useGpu = FLAGS_use_gpu;
double eps = FLAGS_checkgrad_eps; double eps = FLAGS_checkgrad_eps;
FLAGS_use_gpu = true; FLAGS_use_gpu = true;
FLAGS_checkgrad_eps = 1e-2; // Sometimes, this unit test will fail with 1e-2
FLAGS_checkgrad_eps = 4e-2;
compareNetwork(config_file_a, config_file_b); compareNetwork(config_file_a, config_file_b);
FLAGS_use_gpu = useGpu; FLAGS_use_gpu = useGpu;
FLAGS_checkgrad_eps = eps; FLAGS_checkgrad_eps = eps;
......
add_subdirectory(detail) add_subdirectory(detail)
cc_library(memory SRCS memory.cc) cc_library(memory SRCS memory.cc)
cc_library(memcpy SRCS memcpy.cc DEPS device_context) cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory cc_library(paddle_memory
DEPS DEPS
......
...@@ -27,7 +27,7 @@ limitations under the License. */ ...@@ -27,7 +27,7 @@ limitations under the License. */
// between host and device. Allocates too much would reduce the amount // between host and device. Allocates too much would reduce the amount
// of memory available to the system for paging. So, by default, we // of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory. // should set false to use_pinned_memory.
DEFINE_bool(use_pinned_memory, false, "If set, allocate cpu pinned memory."); DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory.");
namespace paddle { namespace paddle {
namespace memory { namespace memory {
......
...@@ -16,8 +16,6 @@ limitations under the License. */ ...@@ -16,8 +16,6 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {
......
...@@ -13,22 +13,38 @@ See the License for the specific language governing permissions and ...@@ -13,22 +13,38 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include <algorithm> // for transform
#include <cstring> // for memcpy
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include "glog/logging.h"
#include "paddle/memory/detail/buddy_allocator.h" #include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h" #include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/gpu_info.h"
#include <cstring> // for memcpy DECLARE_double(fraction_of_gpu_memory_to_use);
namespace paddle { namespace paddle {
namespace memory { namespace memory {
detail::BuddyAllocator* GetCPUBuddyAllocator() { using BuddyAllocator = detail::BuddyAllocator;
static detail::BuddyAllocator* a = nullptr;
if (a == nullptr) { std::once_flag cpu_allocator_flag;
a = new detail::BuddyAllocator(new detail::CPUAllocator, std::once_flag gpu_allocator_flag;
platform::CpuMinChunkSize(),
platform::CpuMaxChunkSize()); BuddyAllocator* GetCPUBuddyAllocator() {
} static std::unique_ptr<BuddyAllocator> a{nullptr};
return a;
std::call_once(cpu_allocator_flag, [&]() {
a.reset(new BuddyAllocator(new detail::CPUAllocator,
platform::CpuMinChunkSize(),
platform::CpuMaxChunkSize()));
});
return a.get();
} }
template <> template <>
...@@ -48,20 +64,36 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) { ...@@ -48,20 +64,36 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
static detail::BuddyAllocator** as = NULL; using BuddyAllocVec = std::vector<BuddyAllocator*>;
if (as == NULL) { static std::unique_ptr<BuddyAllocVec, void (*)(BuddyAllocVec * p)> as{
new BuddyAllocVec, [](BuddyAllocVec* p) {
std::for_each(p->begin(), p->end(),
[](BuddyAllocator* p) { delete p; });
}};
// GPU buddy allocators
auto& allocators = *as.get();
// GPU buddy allocator initialization
std::call_once(gpu_allocator_flag, [&]() {
int gpu_num = platform::GetDeviceCount(); int gpu_num = platform::GetDeviceCount();
as = new detail::BuddyAllocator*[gpu_num]; allocators.reserve(gpu_num);
for (int gpu = 0; gpu < gpu_num; gpu++) { for (int gpu = 0; gpu < gpu_num; gpu++) {
platform::SetDeviceId(gpu); platform::SetDeviceId(gpu);
as[gpu] = new detail::BuddyAllocator(new detail::GPUAllocator, allocators.emplace_back(new BuddyAllocator(new detail::GPUAllocator,
platform::GpuMinChunkSize(), platform::GpuMinChunkSize(),
platform::GpuMaxChunkSize()); platform::GpuMaxChunkSize()));
} }
} VLOG(3) << "\n\nNOTE: each GPU device use "
<< FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n"
<< "You can set environment variable '"
<< platform::kEnvFractionGpuMemoryToUse
<< "' to change the fraction of GPU usage.\n\n";
});
platform::SetDeviceId(gpu_id); platform::SetDeviceId(gpu_id);
return as[gpu_id]; return allocators[gpu_id];
} }
template <> template <>
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
namespace paddle { namespace paddle {
......
...@@ -41,8 +41,11 @@ function(op_library TARGET) ...@@ -41,8 +41,11 @@ function(op_library TARGET)
endif() endif()
endfunction() endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
...@@ -50,7 +53,7 @@ op_library(add_op SRCS add_op.cc add_op.cu) ...@@ -50,7 +53,7 @@ op_library(add_op SRCS add_op.cc add_op.cu)
op_library(mean_op SRCS mean_op.cc mean_op.cu) op_library(mean_op SRCS mean_op.cc mean_op.cu)
op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function)
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
...@@ -62,7 +65,6 @@ op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) ...@@ -62,7 +65,6 @@ op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS op_desc tensor op_registry operator net_op) DEPS framework_proto tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu) SRCS uniform_random_op.cc uniform_random_op.cu)
...@@ -18,17 +18,15 @@ namespace paddle { ...@@ -18,17 +18,15 @@ namespace paddle {
namespace operators { namespace operators {
class AddOp : public framework::OperatorWithKernel { class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1); ctx.Input<Tensor>("Y")->dims(),
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set"); "Two input of Add Op's dimension must be same.");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
"Outputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of Add Op's dimension must be same.");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
...@@ -48,7 +46,9 @@ The equation is: Out = X + Y ...@@ -48,7 +46,9 @@ The equation is: Out = X + Y
}; };
class AddOpGrad : public framework::OperatorWithKernel { class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override {} void InferShape(const framework::InferShapeContext &ctx) const override {}
}; };
...@@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>); ops::AddKernel<paddle::platform::CPUPlace, float>);
...@@ -28,9 +28,9 @@ template <typename Place, typename T> ...@@ -28,9 +28,9 @@ template <typename Place, typename T>
class AddKernel : public framework::OpKernel { class AddKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto input0 = context.Input<Tensor>(0); auto* input0 = context.Input<Tensor>("X");
auto input1 = context.Input<Tensor>(1); auto* input1 = context.Input<Tensor>("Y");
auto output = context.Output<Tensor>(0); auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
......
...@@ -18,36 +18,31 @@ namespace paddle { ...@@ -18,36 +18,31 @@ namespace paddle {
namespace operators { namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel { class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, auto *X = ctx.Input<Tensor>("X");
"Input size of OnehotCrossEntropyOp must be two"); auto *label = ctx.Input<Tensor>("label");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1,
"Output size of OnehotCrossEntropyOp must be one"); PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
"0-th input of OnehotCrossEntropyOp should be set"); PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1), ctx.Output<Tensor>("Y")->Resize({X->dims()[0]});
"1-th input of OnehotCrossEntropyOp should be set");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0),
"Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>(0)->dims().size(), 2);
PADDLE_ENFORCE_EQ(ctx.Output<Tensor>(0)->dims().size(), 1,
"label's dimension must be 1.");
ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
} }
}; };
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp, public:
framework::OperatorWithKernel) using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
// TODO(superjom) add enforce here after helper functions ready dX->Resize(X->dims());
X_grad->Resize(X->dims());
} }
}; };
...@@ -72,12 +67,9 @@ OnehotCrossEntropy Operator. ...@@ -72,12 +67,9 @@ OnehotCrossEntropy Operator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
REGISTER_OP_CPU_KERNEL( ops::OnehotCrossEntropyGradientOp);
onehot_cross_entropy, REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>); ops::OnehotCrossEntropyOpKernel<float>);
REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad, REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp); ops::OnehotCrossEntropyGradientOpKernel<float>);
REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
...@@ -12,10 +12,122 @@ ...@@ -12,10 +12,122 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h" #include "paddle/platform/assert.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__host__ __device__ T clipping_log(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20;
T v = log(x);
if (v == INFINITY) {
return kApproInf;
}
if (v == -INFINITY) {
return -kApproInf;
}
return v;
}
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
const int N, const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
Y[i] = -clipping_log(X[i * D + label[i]]);
}
}
// TODO(qingqing): make zero setting an common function.
template <typename T>
__global__ void zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
}
}
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
const int* label, const int N,
const int D) {
// TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
// CUDA_1D_KERNEL_LOOP(i, N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
int idx = i * D + label[i];
dX[idx] = -dY[i] / X[idx];
}
}
template <typename T>
class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>("label")->data<int>();
auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace());
T* Ydata = Y->data<T>();
int N = X->dims()[0];
int D = X->dims()[1];
int block = 512;
int grid = (N + block - 1) / block;
// TODO(qingqing) launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
}
};
template <typename T>
class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto label = ctx.Input<Tensor>("label");
auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
auto* dYdata = dY->template data<T>();
auto* Xdata = X->template data<T>();
auto* label_data = label->data<int>();
int N = X->dims()[0];
int D = X->dims()[1];
int block = 512;
int grid = (N * D + block - 1) / block;
zero<T><<<grid, block>>>(dXdata, N * D);
grid = (N + block - 1) / block;
// TODO(qingqing): launch kernel on specified stream
// base on ExecutionContext.
CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
label_data, N, D);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
onehot_cross_entropy, ops::OnehotCrossEntropyOpCUDAKernel<float>);
ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>); REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
T tolerable_value(T x) { inline T tolerable_value(const T x) {
static_assert(std::is_floating_point<T>::value, static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, " "tolerable_value works only on float, "
"double and double double."); "double and double double.");
...@@ -39,13 +39,16 @@ T tolerable_value(T x) { ...@@ -39,13 +39,16 @@ T tolerable_value(T x) {
return x; return x;
} }
template <typename Place, typename T> template <typename T>
class OnehotCrossEntropyOpKernel : public framework::OpKernel { class OnehotCrossEntropyOpKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>(); const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>(); const int* label_data = ctx.Input<Tensor>("label")->data<int>();
auto Y = ctx.Output<Tensor>("Y"); auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace()); Y->mutable_data<T>(ctx.GetPlace());
...@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { ...@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
} }
}; };
template <typename Place, typename T> template <typename T>
class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
auto dX = ctx.Output<Tensor>(framework::GradVarName("X")); auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dY = ctx.Input<Tensor>(framework::GradVarName("Y")); auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
...@@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
const int batch_size = X->dims()[0]; const int batch_size = X->dims()[0];
const int class_num = X->dims()[1]; const int class_num = X->dims()[1];
// TODO(qingqing): make zero setting an common function.
memset(dXdata, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int index = i * class_num + label_data[i];
dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
......
...@@ -18,19 +18,13 @@ namespace paddle { ...@@ -18,19 +18,13 @@ namespace paddle {
namespace operators { namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, ctx.Output<framework::Tensor>("Dst")->Resize(
"Input size of FillZerosLikeOp must be one."); ctx.Input<framework::Tensor>("Src")->dims());
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL,
"Output size of AddOp must be one.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
"Input of FillZerosLikeOp must be set.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0),
"Output of FillZerosLikeOp must be set.");
ctx.Output<framework::Tensor>(0)->Resize(
ctx.Input<framework::Tensor>(0)->dims());
} }
}; };
...@@ -52,7 +46,8 @@ The output will have the same size with input. ...@@ -52,7 +46,8 @@ The output will have the same size with input.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
...@@ -23,7 +23,7 @@ template <typename Place, typename T> ...@@ -23,7 +23,7 @@ template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel { class FillZerosLikeKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>(0); auto* output = context.Output<framework::Tensor>("Dst");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output); auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(T(0)); t.device(context.GetEigenDevice<Place>()) = t.constant(T(0));
......
...@@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size, ...@@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) { const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
int index_ = indices[i]; int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
} }
...@@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, ...@@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
// slice size // slice size
int slice_size = 1; int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Gathering // Gathering
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
......
...@@ -35,7 +35,7 @@ TEST(Gather, GatherData) { ...@@ -35,7 +35,7 @@ TEST(Gather, GatherData) {
p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace()); p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace());
p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace()); p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace());
for (size_t i = 0; i < 12; ++i) p_src[i] = i; for (int i = 0; i < 12; ++i) p_src[i] = i;
p_index[0] = 1; p_index[0] = 1;
p_index[1] = 0; p_index[1] = 0;
...@@ -43,6 +43,10 @@ TEST(Gather, GatherData) { ...@@ -43,6 +43,10 @@ TEST(Gather, GatherData) {
Gather<int>(CPUPlace(), src, index, output); Gather<int>(CPUPlace(), src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
delete src;
delete index;
delete output;
} }
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -19,34 +16,36 @@ namespace paddle { ...@@ -19,34 +16,36 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class GaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean"); float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std"); float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
// TODO(dzh): attribute does not support unsigned int. unsigned int seed =
// And we need a global random seed configuration. static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
int seed = context.op_.GetAttr<int>("seed"); std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
std::mt19937 g(seed); engine.seed(seed);
std::normal_distribution<T> distribution(mean, std); std::normal_distribution<T> dist(mean, std);
ssize_t size = framework::product(tensor->dims()); ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) { for (ssize_t i = 0; i < size; ++i) {
data[i] = distribution(g); data[i] = dist(engine);
} }
} }
}; };
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel) public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0); auto* tensor = context.Output<framework::Tensor>("Out");
auto dims = GetAttr<std::vector<int>>("dims"); auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL, PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set."); "dims can be one int or array. dims must be set.");
...@@ -66,8 +65,8 @@ Use to initialize tensor with gaussian random generator. ...@@ -66,8 +65,8 @@ Use to initialize tensor with gaussian random generator.
)DOC"); )DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f); AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f); AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f);
AddAttr<int>("seed", AddAttr<int>("seed",
"Random seed of generator." "Random seed of generator."
"0 means use system wide seed") "0 means use system wide seed")
...@@ -79,5 +78,6 @@ Use to initialize tensor with gaussian random generator. ...@@ -79,5 +78,6 @@ Use to initialize tensor with gaussian random generator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory> #include <thrust/device_ptr.h>
#include <random> #include <thrust/iterator/counting_iterator.h>
#include "paddle/platform/dynload/curand.h" #include <thrust/random.h>
#include "paddle/platform/gpu_info.h" #include <thrust/transform.h>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class GaussianRandomKernel : public framework::OpKernel { struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
return dist(rng);
}
};
template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean"); auto* tensor = context.Output<framework::Tensor>("Out");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
int seed = context.op_.GetAttr<int>("seed"); static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
curandGenerator_t g; T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( T std = static_cast<T>(context.op_.GetAttr<float>("std"));
&g, CURAND_RNG_PSEUDO_DEFAULT)); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
PADDLE_ENFORCE( ssize_t N = framework::product(tensor->dims());
platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); thrust::transform(index_sequence_begin, index_sequence_begin + N,
platform::dynload::curandGenerateNormal( thrust::device_ptr<T>(data),
g, data, framework::product(tensor->dims()), mean, std); GaussianGenerator<T>(mean, std, seed));
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(gaussian_random,
REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); paddle::operators::GPUGaussianRandomKernel<float>);
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
else()
cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
template <>
void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
platform::is_cpu_place(matrix_b.place()) &&
platform::is_cpu_place(matrix_out->place()),
"Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const float alpha, const float* A,
const float* B, const float beta, float* C,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K,
const double alpha, const double* A,
const double* B, const double beta,
double* C,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));
}
template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, float alpha,
framework::Tensor* matrix_out,
float beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, float>(
transA, transB, M, N, K, alpha, matrix_a.data<float>(),
matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
bool trans_a,
const framework::Tensor& matrix_b,
bool trans_b, double alpha,
framework::Tensor* matrix_out,
double beta,
platform::DeviceContext* context) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (trans_a == false) ? dim_a[1] : dim_a[0];
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, double>(
transA, transB, M, N, K, alpha, matrix_a.data<double>(),
matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
// Support continuous memory now
// If transA = N, and transB = N
// Then matrixA: M * K, matrixB: K * N matrixC : M * N
// For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C, platform::DeviceContext* context);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context);
} // namespace math
} // namespace operators
} // namespace paddle
#include "paddle/operators/math/math_function.h"
#include "gtest/gtest.h"
#ifndef PADDLE_ONLY_CPU
TEST(math_function, notrans_mul_trans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor out_gpu;
paddle::framework::Tensor out;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place);
float* out_ptr = out.data<float>();
EXPECT_EQ(out_ptr[0], 5);
EXPECT_EQ(out_ptr[1], 14);
EXPECT_EQ(out_ptr[2], 14);
EXPECT_EQ(out_ptr[3], 50);
}
TEST(math_function, trans_mul_notrans) {
paddle::framework::Tensor input1;
paddle::framework::Tensor input1_gpu;
paddle::framework::Tensor input2_gpu;
paddle::framework::Tensor out_gpu;
paddle::framework::Tensor out;
auto* cpu_place = new paddle::platform::CPUPlace();
float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place);
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context);
out.CopyFrom<float>(out_gpu, *cpu_place);
float* out_ptr = out.data<float>();
EXPECT_EQ(out_ptr[0], 9);
EXPECT_EQ(out_ptr[1], 12);
EXPECT_EQ(out_ptr[2], 15);
EXPECT_EQ(out_ptr[3], 12);
EXPECT_EQ(out_ptr[4], 17);
EXPECT_EQ(out_ptr[5], 22);
EXPECT_EQ(out_ptr[6], 15);
EXPECT_EQ(out_ptr[7], 22);
EXPECT_EQ(out_ptr[8], 29);
}
#endif
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册