diff --git a/.clang_format.hook b/.clang_format.hook new file mode 100755 index 0000000000000000000000000000000000000000..1d928216867c0ba3897d71542fea44debf8d72a0 --- /dev/null +++ b/.clang_format.hook @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +readonly VERSION="3.8" + +version=$(clang-format -version) + +if ! [[ $version == *"$VERSION"* ]]; then + echo "clang-format version check failed." + echo "a version contains '$VERSION' is needed, but get '$version'" + echo "you can install the right version, and make an soft-link to '\$PATH' env" + exit -1 +fi + +clang-format $@ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb8c88787d37faf9ce4d7d856a307c11f1085d98..a772125df64aaf2eafe6cb9e022f62cc29043eb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,10 +19,10 @@ - id: end-of-file-fixer - repo: local hooks: - - id: clang-format + - id: clang-format-with-version-check name: clang-format description: Format files with ClangFormat. - entry: clang-format -i + entry: ./.clang_format.hook -i language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ - repo: https://github.com/PaddlePaddle/pre-commit-golang diff --git a/CMakeLists.txt b/CMakeLists.txt index c75b83e50cf9cef8290c37f88b38cdc3d77df39c..ad559672ad2f83a3d62cdf332b47c6cf1e730f70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,8 +36,8 @@ include(simd) ################################ Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) -option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) -option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) +option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND}) +option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND}) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) @@ -55,6 +55,7 @@ option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(GLIDE_INSTALL "Download and install go dependencies " ON) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) +option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) @@ -137,9 +138,9 @@ set(EXTERNAL_LIBS ) if(WITH_GPU) - list(APPEND EXTERNAL_LIB ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDA_LIBRARIES} ${CUDA_rt_LIBRARY}) if(NOT WITH_DSO) - list(APPEND EXTERNAL_LIB ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) + list(APPEND EXTERNAL_LIBS ${CUDNN_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_curand_LIBRARY}) endif(NOT WITH_DSO) endif(WITH_GPU) diff --git a/Dockerfile b/Dockerfile index 41b6729124228cec16be35d9b26da8042824b0b0..98f61ba586a681e53b435d592c8e43b1cc964139 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,9 +34,6 @@ RUN apt-get update && \ net-tools && \ apt-get clean -y -# paddle is using numpy.flip, which is introduced since 1.12.0 -RUN pip --no-cache-dir install 'numpy>=1.12.0' - # Install Go and glide RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \ tar -xz -C /usr/local && \ @@ -58,33 +55,22 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 # FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter # version util jupyter fixes this issue. RUN pip install --upgrade pip && \ - pip install -U 'protobuf==3.1.0' && \ - pip install -U wheel pillow BeautifulSoup && \ + pip install -U wheel && \ pip install -U docopt PyYAML sphinx && \ - pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \ - pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \ + pip install -U sphinx-rtd-theme==0.1.9 recommonmark + +RUN pip install pre-commit 'ipython==5.3.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ - pip install opencv-python rarfile 'scipy>=0.19.0' 'nltk>=3.2.2' + pip install opencv-python + +COPY ./python/requirements.txt /root/ +RUN pip install -r /root/requirements.txt # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use # the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2 RUN apt-get install -y libssl-dev libffi-dev RUN pip install certifi urllib3[secure] -# TODO(qijun) The template library Eigen doesn't work well with GCC 5 -# coming with the default Docker image, so we switch to use GCC 4.8 -# by default. And I will check Eigen library later. - -RUN ln -sf gcc-4.8 /usr/bin/gcc && \ - ln -sf gcc-ar-4.8 /usr/bin/gcc-ar && \ - ln -sf gcc-nm-4.8 /usr/bin/gcc-nm && \ - ln -sf gcc-ranlib-4.8 /usr/bin/gcc-ranlib && \ - ln -sf gcc-4.8 /usr/bin/x86_64-linux-gnu-gcc && \ - ln -sf gcc-ar-4.8 /usr/bin/x86_64-linux-gnu-gcc-ar && \ - ln -sf gcc-nm-4.8 /usr/bin/x86_64-linux-gnu-gcc-nm && \ - ln -sf gcc-ranlib-4.8 /usr/bin/x86_64-linux-gnu-gcc-ranlib && \ - ln -sf g++-4.8 /usr/bin/g++ && \ - ln -sf g++-4.8 /usr/bin/x86_64-linux-gnu-g++ # Install woboq_codebrowser to /woboq RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \ diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 209f9078a637ac581d90212a48216eb388c477ed..51c3b918cc4ef4cf6c8052ccc14028a872309fcf 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -28,6 +28,10 @@ if(NOT WITH_TIMER) add_definitions(-DPADDLE_DISABLE_TIMER) endif(NOT WITH_TIMER) +if(USE_EIGEN_FOR_BLAS) + add_definitions(-DPADDLE_USE_EIGEN_FOR_BLAS) +endif(USE_EIGEN_FOR_BLAS) + if(NOT WITH_PROFILER) add_definitions(-DPADDLE_DISABLE_PROFILER) endif(NOT WITH_PROFILER) diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake index 69f40df51680a104c47d9335c070c570dcaff59a..2c84061ff572de4687b4d496f8ded6deee8d1011 100644 --- a/cmake/cudnn.cmake +++ b/cmake/cudnn.cmake @@ -2,7 +2,7 @@ if(NOT WITH_GPU) return() endif() -set(CUDNN_ROOT "" CACHE PATH "CUDNN ROOT") +set(CUDNN_ROOT "/usr" CACHE PATH "CUDNN ROOT") find_path(CUDNN_INCLUDE_DIR cudnn.h PATHS ${CUDNN_ROOT} ${CUDNN_ROOT}/include $ENV{CUDNN_ROOT} $ENV{CUDNN_ROOT}/include ${CUDA_TOOLKIT_INCLUDE} diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index db09232c0e69016bf18c1d981e4620e9e804ff7c..0eeccbf7d8a1df17351c8914df6dabf005802787 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -73,10 +73,18 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c) FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") -ADD_LIBRARY(cblas STATIC ${dummyfile}) +IF(${CBLAS_PROVIDER} MATCHES MKL) + ADD_LIBRARY(cblas SHARED ${dummyfile}) +ELSE() + ADD_LIBRARY(cblas STATIC ${dummyfile}) +ENDIF() TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES}) IF(NOT ${CBLAS_FOUND}) ADD_DEPENDENCIES(cblas extern_openblas) LIST(APPEND external_project_dependencies cblas) +ELSE() + IF("${CBLAS_PROVIDER}" STREQUAL "MKLML") + ADD_DEPENDENCIES(cblas mklml) + ENDIF() ENDIF(NOT ${CBLAS_FOUND}) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index b27eb71550b68b5c27e47bf067ae0df329bbd628..ff246b2eb4ed97dd14d45763569b661cefd203c8 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -9,13 +9,6 @@ function(CheckCompilerCXX11Flag) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") endif() - if(NOT ANDROID) - # TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem. - # Use Debug mode instead for now. - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9) - set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE) - endif() - endif() elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # Apple Clang is a different compiler than upstream Clang which havs different version numbers. @@ -160,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. -LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread) +LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) if(CMAKE_BUILD_TYPE STREQUAL "Debug") diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index cb330ea5e1b914587a725c9b90a33053f3fbbc3d..a4a843c610feb2b378c22c4b4097cd238ccd61ab 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -362,6 +362,11 @@ trans .. autoclass:: paddle.v2.layer.trans :noindex: +scale_shift +----------- +.. autoclass:: paddle.v2.layer.scale_shift + :noindex: + Sampling Layers =============== diff --git a/doc/design/cluster_train/large_model_dist_train.md b/doc/design/cluster_train/large_model_dist_train.md new file mode 100644 index 0000000000000000000000000000000000000000..0c4b5bc24c854b7062d509249bea9c50d42bd5f1 --- /dev/null +++ b/doc/design/cluster_train/large_model_dist_train.md @@ -0,0 +1,101 @@ +# Alalysis of large model distributed training in Paddle + +***NOTE: This is only some note for how we implemeted this scheme in V1, not a new design.*** + +## What is it + +We often encounter cases that the embedding layer parameters(sparse) are so large that we can not store it in the trainer's memory when training. So we need to put them to several servers, and fetch them row by row instead of fetch all of the parameters. + +## How to use + +Specify command-line argument like `--loadsave_parameters_in_pserver=true --ports_num_for_sparse=1 --use_old_updater=1` when starting the paddle trainer. And also add something like `--ports_num_for_sparse=1 --pserver_num_threads=5` when starting pserver processes. + +Accrodingly, configure your embedding layers like: + +```python +SPARSE_REMOTE=True + +w1 = data_layer(name="w1", size=dict_size) +emb1 = embedding_layer(input=w1, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE)) +w2 = data_layer(name="w2", size=dict_size) +emb2 = embedding_layer(input=w2, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE)) +... +``` + +## Implementation details + +```c++ +enum MatType { + MAT_NORMAL, + MAT_NORMAL_SHARED, + MAT_VALUE_SHARED, + MAT_SPARSE_ROW_IDS, + MAT_SPARSE_ROW_AUTO_GROW, + MAT_CACHE_ROW, + MAT_SPARSE_ROW, + MAT_SPARSE_ROW_PREFETCH, + MAT_SPARSE_ROW_PREFETCH_FULL_SIZE, +}; +``` + +`MAT_SPARSE_ROW_PREFETCH` is what we use when configured to fetch only row of matrix when training. + +In `trainer_internal.cpp:L93 trainOneBatch`: + +```c++ + if (config_->getOptConfig().use_sparse_remote_updater()) { + REGISTER_TIMER("prefetch"); + gradientMachine_->prefetch(inArgs); + parameterUpdater_->getParametersRemote(); + } +``` + +When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver. + +In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`: + +```c++ +if (fullSize) { + ... +} else { +getParams = [&] { + parameterClient_->getParameterSparse( + /* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType); +}; +applyL1 = [](Parameter& para, real decayRate) { + para.getMat(PARAMETER_VALUE)->applyL1(/*lr=*/1.0f, decayRate); +}; +} +``` + +Calling `parameterClient_->getParameterSparse` will do remote call to pserver's `getParameterSparse`: + +```c++ +void ParameterServer2::getParameterSparse(const SendParameterRequest& request, + std::vector& inputBuffers, + SendParameterResponse* response, + std::vector* outputBuffers) { + (void)inputBuffers; + auto& buffer = *readWriteBuffer_; + size_t numReals = 0; + for (const auto& block : request.blocks()) { + numReals += getParameterConfig(block).dims(1); + } + buffer.resize(numReals); + + VLOG(3) << "pserver: getParameterSparse, numReals=" << numReals; + + ReadLockGuard guard(parameterMutex_); + size_t offset = 0; + for (const auto& block : request.blocks()) { + size_t width = getParameterConfig(block).dims(1); + Buffer buf = {buffer.data() + offset, width}; + int type = request.send_back_parameter_type(); + sendBackParameterSparse(block, type, response, &buf, width, outputBuffers); + offset += width; + } +} +``` + +`getParameterConfig(block).dims(1)` returns the width of the current "parameter block"(a shard of parameter object), +then `getParameterSparse` remote call returns only one row of data to the client. diff --git a/doc/design/mkldnn/README.MD b/doc/design/mkldnn/README.MD index e956994431fbb43438c56dcd96ad8313cf516090..fe8da907d9d45a2164031430ac5b7a3d5523967a 100644 --- a/doc/design/mkldnn/README.MD +++ b/doc/design/mkldnn/README.MD @@ -101,6 +101,7 @@ if use_mkldnn 5. 在**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue`和`mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。 6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。 7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。 +8. 关于MKLDNN参数的保存。由于MKLDNN参数的格式与PaddlePaddle原有的格式存在不一样的情况,所以需要在保存参数时同时保存该格式信息。目前准备扩展[Header](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/Parameter.h#L247)里面的`int32_t version`。这个值不管是在v1还是在v2里面,一直保存的是0,所以可以充分利用这个信息,定义一个枚举处理所有MKLDNN的参数格式,从而`MKLDNNLayer`就可以从输入的参数中获取需要的格式信息。 ## References diff --git a/doc/getstarted/build_and_install/build_from_source_en.md b/doc/getstarted/build_and_install/build_from_source_en.md index c0608ede8e57b224dae4b3d510d704a8b0918b53..2f1461489495618718d5abaeab9cbeda9b93700f 100644 --- a/doc/getstarted/build_and_install/build_from_source_en.md +++ b/doc/getstarted/build_and_install/build_from_source_en.md @@ -68,7 +68,7 @@ As a simple example, consider the following: 1. **BLAS Dependencies(optional)** - CMake will search BLAS libraries from system. If not found, OpenBLAS will be downloaded, built and installed automatically. + CMake will search BLAS libraries from the system. If not found, OpenBLAS will be downloaded, built and installed automatically. To utilize preinstalled BLAS, you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`. ```bash @@ -131,9 +131,9 @@ As a simple example, consider the following: To build GPU version, you will need the following installed: 1. a CUDA-capable GPU - 2. A supported version of Linux with a gcc compiler and toolchain + 2. A supported version of Linux with a GCC compiler and toolchain 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) - 4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) + 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn) The CUDA development environment relies on tight integration with the host development environment, including the host compiler and C runtime libraries, and is therefore only supported on @@ -172,6 +172,7 @@ export PATH=/bin:$PATH # install PaddlePaddle Python modules. sudo pip install /opt/paddle/share/wheels/*.whl ``` + ## Build on Centos 7 ### Install Dependencies @@ -192,9 +193,9 @@ sudo pip install /opt/paddle/share/wheels/*.whl To build GPU version, you will need the following installed: 1. a CUDA-capable GPU - 2. A supported version of Linux with a gcc compiler and toolchain + 2. A supported version of Linux with a GCC compiler and toolchain 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) - 4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) + 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn) The CUDA development environment relies on tight integration with the host development environment, including the host compiler and C runtime libraries, and is therefore only supported on @@ -222,7 +223,7 @@ mkdir build && cd build ``` Finally, you can build and install PaddlePaddle: - + ```bash # you can add build option here, such as: cmake3 .. -DCMAKE_INSTALL_PREFIX= diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp index b3287552db87d25edbf6e7f3d5e68121df49e9d6..629449bbd497a7444144c533ad079b3ae6b51438 100644 --- a/paddle/capi/gradient_machine.cpp +++ b/paddle/capi/gradient_machine.cpp @@ -146,3 +146,19 @@ paddle_error paddle_gradient_machine_randomize_param( m->machine->randParameters(); return kPD_NO_ERROR; } + +paddle_error paddle_gradient_machine_get_layer_output( + paddle_gradient_machine machine, + const char* layerName, + paddle_arguments args) { + auto m = cast(machine); + auto out = paddle::capi::cast(args); + if (m == nullptr || layerName == nullptr || out == nullptr || + m->machine == nullptr) { + return kPD_NULLPTR; + } + + auto layerOutput = m->machine->getLayerOutput(layerName); + out->args.push_back(layerOutput); + return kPD_NO_ERROR; +} diff --git a/paddle/capi/gradient_machine.h b/paddle/capi/gradient_machine.h index c613ade5b24efbbf52f21c7ee86dd3189981c5ef..28eeb23e3bbdd4cc22a25c14170bf56c294f8cd7 100644 --- a/paddle/capi/gradient_machine.h +++ b/paddle/capi/gradient_machine.h @@ -39,7 +39,11 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference( /** * @brief Create a gradient machine used for model inference, using config with * parameters which is generated by `paddle merge_model`. - * @param [out] machine that used for model inference. + * Example: + * paddle merge_model \ + * --model_dir="pass-00000" \ + * --model_file="merged_model.paddle" + * @param [out] machine that used for model inference * @param [in] mergedModel * @param [in] size * @return paddle_error @@ -97,6 +101,18 @@ paddle_gradient_machine_randomize_param(paddle_gradient_machine machine); PD_API paddle_error paddle_gradient_machine_destroy(paddle_gradient_machine machine); +/** + * @brief Get the output of the layer named `layerName`. + * @param [in] gradient machine that have run a inference + * @param [in] layerName name of specified layer + * @param [out] args output of the specified layer + * @return paddle_error + */ +PD_API paddle_error +paddle_gradient_machine_get_layer_output(paddle_gradient_machine machine, + const char* layerName, + paddle_arguments args); + #ifdef __cplusplus } #endif diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9024ed2fd427d7d21d7899a3b8df61d86f08a2cd..68304c9fc8b8fa13cb1f99b82517abc87c71496c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc) cc_test(scope_test SRCS scope_test.cc DEPS scope) -proto_library(attribute_proto SRCS attribute.proto) -proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto) -proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto) -cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) -cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) +proto_library(framework_proto SRCS framework.proto) -cc_library(attribute SRCS attribute.cc DEPS op_desc op_proto) +cc_library(attribute SRCS attribute.cc DEPS framework_proto) -cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute) +cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator) -cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder) +cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) +cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) -py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.proto) +py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) @@ -42,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) cc_library(backward SRCS backward.cc DEPS net_op) -cc_test(backward_test SRCS backward_test.cc DEPS backward) +cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) if(WITH_PYTHON) cc_library(paddle_pybind SHARED diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 4c5790693b7e48396e945d09f4fdc72b86aa5978..9eb07acdff1d00dd926f1cee9c24f9f151006d7e 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -44,7 +44,7 @@ AttrType AttrTypeID>() { return STRINGS; } -Attribute GetAttrValue(const AttrDesc& attr_desc) { +Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case paddle::framework::AttrType::INT: { return attr_desc.i(); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 49a62bedb6aadab5ff05d8aa7dda42fe983314a0..08b47cabd4c2225c50022bd35734dcc2663324d6 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -20,8 +20,7 @@ limitations under the License. */ #include #include -#include "paddle/framework/attribute.pb.h" -#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/framework.pb.h" #include "paddle/platform/enforce.h" #include "paddle/platform/variant.h" @@ -37,7 +36,7 @@ typedef std::unordered_map AttributeMap; template AttrType AttrTypeID(); -Attribute GetAttrValue(const AttrDesc& attr_desc); +Attribute GetAttrValue(const OpDesc::Attr& attr_desc); // check whether a value(attribute) fit a certain limit template diff --git a/paddle/framework/attribute.proto b/paddle/framework/attribute.proto deleted file mode 100644 index 13ae312c10e934566384b8bd0f41dacd6c01fc2f..0000000000000000000000000000000000000000 --- a/paddle/framework/attribute.proto +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -syntax = "proto2"; -package paddle.framework; - -// Attribute Type for paddle's Op. -// Op contains many attributes. Each type of attributes could be different. -// The AttrType will be shared between AttrDesc and AttrProto. -enum AttrType { - INT = 0; - FLOAT = 1; - STRING = 2; - INTS = 3; - FLOATS = 4; - STRINGS = 5; -} \ No newline at end of file diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 437a44a8aafa650d654a1a77c60613abe07679fe..bfda18724cc8ed23a40e0626ff07a290d26aa9d2 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -15,31 +15,44 @@ #include "paddle/framework/backward.h" #include +#include + #include "paddle/framework/op_registry.h" #include "paddle/operators/net_op.h" +#include "paddle/operators/recurrent_op.h" namespace paddle { namespace framework { -static bool AllInSet(const std::vector& names, - const std::string& suffix, - const std::unordered_set& set) { +template +static void ForEachVarName(const Map& names, T callback) { for (auto& name : names) { - if (set.find(name + suffix) == set.end()) { - return false; + for (auto& n : name.second) { + if (callback(n)) return; } } - return true; } -static std::shared_ptr NOP() { - auto net_op = std::make_shared(); - net_op->type_ = "@NOP@"; +// return whether all the names + suffixes in the set +static bool AllInSet( + const std::map>& names, + const std::string& suffix, const std::unordered_set& set) { + bool all_in_set = true; + ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) { + all_in_set = set.find(n + suffix) != set.end(); + return !all_in_set; + }); + return all_in_set; +} + +static std::unique_ptr NOP() { + auto net_op = new operators::NetOp(); + net_op->SetType("@NOP@"); net_op->CompleteAddOp(); - return net_op; + return std::unique_ptr(net_op); } -// Get backward operator from a forward operator, recursively implementation. +// Get backward operator from a forward operator, a recursive implementation. // // no_grad_names the gradient variable names without gradient calculating. // @@ -47,122 +60,152 @@ static std::shared_ptr NOP() { // BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and // pass `uniq_id` through recursive calling. // -// returns The backward operator. For simple situation, it is a simple -// operator. For complex situation, it is a NetOp. +// returns The backward operator. In a simple situation, it may be a simple +// operator, in a complex situation, it maybe a NetOp. // // See Backward.h for details -static std::shared_ptr BackwardRecursive( - const OperatorBase& forwardOp, - std::unordered_set& no_grad_names, size_t& uniq_id); -std::shared_ptr BackwardRecursive( +static std::unique_ptr BackwardRecursive( const OperatorBase& forwardOp, std::unordered_set& no_grad_names, size_t& uniq_id) { // If all input gradients of forwarding operator do not need to calculate, // just return an NOP. Not return null ptr because NOP does not take // too much time for calculation, but it is useful for simplifying logic. - if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) { + if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/, + no_grad_names /*set*/)) { return NOP(); } // All output gradients of forwarding operator do not need to calculate. // Then all input gradients cannot be computed at all, and we put them into // `no_grad_names` set. Return an NOP. - if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) { - for (auto& name : forwardOp.inputs_) { - // Mark all input is not need - no_grad_names.insert(name + kGradVarSuffix); - } + if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/, + no_grad_names /*set*/)) { + ForEachVarName(forwardOp.Inputs(), + [&no_grad_names](const std::string& name) -> bool { + no_grad_names.insert(GradVarName(name)); + return false; + }); return NOP(); } // Returned gradient network - auto net = std::make_shared(); + auto net = std::unique_ptr(new operators::NetOp()); if (forwardOp.IsNetOp()) { // Because forwardOp is a net op, it can static_cast. auto& forwardNet = static_cast(forwardOp); // Map from output gradient variable name to operator's indices in - // backward net. That operator generates that variable. + // backward net's ops_. That operator generates that variable. std::unordered_map> dup_output_ops; size_t local_op_id = 0; - // reversely travel forwardNet + // reversely travel forwardNet and collect all duplicate outputs. for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it, ++local_op_id) { - auto fwd = *it; + auto& fwd = *it; auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); - net->AddOp(bwd); - for (auto& out : bwd->outputs_) { - dup_output_ops[out].emplace_back(local_op_id); - } + ForEachVarName(bwd->Outputs(), + [&dup_output_ops, local_op_id](const std::string& out) { + dup_output_ops[out].emplace_back(local_op_id); + return false; + }); + net->AppendOp(std::move(bwd)); } // Get unique ID for this method. auto uid = uniq_id++; // TODO(dzh): more comment - using Pos = std::pair>; + // multiple operators which have the same output (y for example) may + // overwrite the same y variable when backward, special operations are token + // to handle this case. For each duplicate output, rename it to an alias + // (original name with a offset), append an `add` op for its operator, + // and finally sum all the alias variable to the final output variable y. + using Pos = std::pair>; std::list insert_position; for (auto& dup_output_op : dup_output_ops) { const std::string& name = dup_output_op.first; auto& dup_op = dup_output_op.second; + // no duplicate output if (dup_op.size() == 1) continue; - std::vector dup_outputs; + // process the duplicate outputs + std::vector dup_outputs; for (size_t i = 0; i < dup_op.size(); ++i) { + // rename each duplicate output to an alias auto op_offset = dup_op[i]; dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i)); net->ops_[op_offset]->Rename(name, dup_outputs.back()); } + // collect all the offset to append `add` op for each alias insert_position.push_back( - {dup_op.back(), - OpRegistry::CreateOp( - "add", {dup_outputs}, {name}, - {{"input_format", - std::vector{0, static_cast(dup_outputs.size())}}})}); + {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}}, + {{"Out", {name}}}, {})}); } + // make sure the inserted `add` ops follow the BFS order. insert_position.sort( [](const Pos& l, const Pos& r) { return l.first > r.first; }); for (auto& pos : insert_position) { - net->InsertOp(pos.first + 1, pos.second); + net->InsertOp(pos.first + 1, std::move(pos.second)); } - } else { - std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); - for (std::string& grad_input : grad_op->inputs_) { + std::unique_ptr grad_op(OpRegistry::CreateGradOp(forwardOp)); + + ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( + const std::string& grad_input) { if (no_grad_names.count(grad_input)) { // +1 for \0 std::string prefix = grad_input.substr( 0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); - grad_input = prefix + kZeroVarSuffix; + grad_op->Rename(grad_input, prefix + kZeroVarSuffix); // If part of input gradient of that operator is not calculated, fill // zero variables to that input gradient. - net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix}, - {grad_input}, {})); - } - } - - for (std::string& grad_output : grad_op->outputs_) { - if (no_grad_names.count(grad_output)) { - grad_output = kEmptyVarName; + net->AppendOp(OpRegistry::CreateOp("fill_zeros_like", + {{"Src", {prefix}}}, + {{"Dst", {grad_input}}}, {})); } + return false; + }); + + ForEachVarName(grad_op->Outputs(), + [&no_grad_names, &grad_op](const std::string& grad_output) { + if (no_grad_names.count(grad_output)) { + grad_op->Rename(grad_output, kEmptyVarName); + } + return false; + }); + + // process recurrent gradient op as a special operator. + if (forwardOp.Type() == "recurrent_op") { + // NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or + // this will result in infinite loop. + const auto& rnnop = + *static_cast(&forwardOp); + auto rnn_grad_op = + static_cast(grad_op.get()); + const auto& stepnet_op = + *static_cast(&rnnop.stepnet()); + // create stepnet's gradient op + rnn_grad_op->set_stepnet( + BackwardRecursive(stepnet_op, no_grad_names, uniq_id)); } if (net->ops_.empty()) { // Current no aux op is added to network return grad_op; } - net->AddOp(grad_op); + net->AppendOp(std::move(grad_op)); } - net->type_ = "@GENERATED_BACKWARD@"; + net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); - return net; + return std::unique_ptr( + static_cast(net.release())); } // See header for comments -std::shared_ptr Backward( +std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars) { std::unordered_set no_grad_names; diff --git a/paddle/framework/backward.h b/paddle/framework/backward.h index c181919dc165cf0b49362f85e22ceb4131bbd387..1ecf69881b3126c2904920b9f4b77bfcccc9cf86 100644 --- a/paddle/framework/backward.h +++ b/paddle/framework/backward.h @@ -20,7 +20,7 @@ namespace framework { // Create the backward operator from a forward operator. // TODO(yuyang18): Add more API reference comment. -extern std::shared_ptr Backward( +extern std::unique_ptr Backward( const OperatorBase& forwardOp, const std::unordered_set& no_grad_vars); } // namespace framework diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index da3b9c8bed7cd123f2f8ef982a5f0e23abcc0ec7..b93ab66f2f5b9cffa6d51b6e36afe552125970e4 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -28,21 +28,13 @@ using OpAttrChecker = framework::OpAttrChecker; using Scope = framework::Scope; using DeviceContext = platform::DeviceContext; -class EmptyOp : public OperatorBase { - public: - DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase) - - void InferShape(const Scope &scope) const override {} - void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} -}; - class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").IgnoreGradient(); - AddInput("b", "Bias of Add").IgnoreGradient(); - AddOutput("Out", "Out of Add").IgnoreGradient(); + AddInput("X", "Input X of Add").NotInGradient(); + AddInput("b", "Bias of Add").NotInGradient(); + AddOutput("Out", "Out of Add").NotInGradient(); AddComment("Add Op"); } }; @@ -51,8 +43,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker { public: MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("A", "A"); - AddInput("B", "B"); + AddInput("X", "A"); + AddInput("Y", "B"); AddOutput("Out", "Out"); AddComment("Mul"); } @@ -63,7 +55,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker { SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "X"); - AddOutput("Y", "Y"); + AddOutput("Out", "Y"); AddComment("Sigmoid"); } }; @@ -73,21 +65,25 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "X input"); - AddOutput("Y", "Y output"); + AddOutput("Out", "Y output"); AddComment("NoGradOp, same input output. no Grad"); } }; class FcOp : public operators::NetOp { public: - void Init() override { - AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, - {Output("mul_result")}, {})); - auto b_name = Input("b"); + FcOp(const std::string &type, const VarNameMap &inputs, + const VarNameMap &outputs, const AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + AppendOp(OpRegistry::CreateOp("mul", + {{"X", {Input("X")}}, {"Y", {Input("W")}}}, + {{"Out", {Output("mul_result")}}}, {})); + auto input_b = Inputs("b"); std::string before_act = "mul_result"; - if (b_name != kEmptyVarName) { - AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name}, - {Output("add_result")}, {})); + if (input_b.size() != 0) { + AppendOp(OpRegistry::CreateOp( + "rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}}, + {{"Out", {Output("add_result")}}}, {})); before_act = "add_result"; } else { auto out_varname = Output("add_result"); @@ -96,8 +92,8 @@ class FcOp : public operators::NetOp { } } - AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")}, - {})); + AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}}, + {{"Out", {Output("Out")}}}, {})); CompleteAddOp(false); } }; @@ -109,8 +105,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker { AddInput("X", "x"); AddInput("W", "w"); AddInput("b", "b"); - AddOutput("mul_result", "").SetTemporary(); - AddOutput("add_result", "").SetTemporary(); + AddOutput("mul_result", "").AsIntermediate(); + AddOutput("add_result", "").AsIntermediate(); AddOutput("Out", ""); AddComment(""); } @@ -141,7 +137,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker { public: AddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "x").SetMultiple(); + AddInput("X", "x").AsDuplicable(); AddOutput("Y", "y"); AddComment(""); } @@ -152,51 +148,48 @@ class AddOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); -REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); -REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); -REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); -REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); -REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); -REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); -REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); -REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); -REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); -REGISTER_OP(fc, f::FcOp, f::FcOpMaker); -REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); -REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); +REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, + f::NOP); +REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); +REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); +REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); +REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP); +REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); +REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, + f::NOP); TEST(Backward, simple_op_grad) { - auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto fwd = f::OpRegistry::CreateOp( + "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(4UL, gop->inputs_.size()); - ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]); - ASSERT_EQ("rowwise_add_grad", gop->type_); - ASSERT_EQ(f::GradVarName("X"), gop->outputs_[0]); - ASSERT_EQ(f::GradVarName("b"), gop->outputs_[1]); - - ASSERT_EQ(f::GradVarName("X"), gop->Output(f::GradVarName("X"))); + ASSERT_EQ(1UL, gop->Inputs().size()); + ASSERT_EQ("rowwise_add_grad", gop->Type()); + ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); + ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); } TEST(Backward, simple_op_not_need_grad) { - auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto fwd = f::OpRegistry::CreateOp( + "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); ASSERT_NE(fwd, nullptr); - auto gop = f::Backward(*fwd, {"X"}); - ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(), - f::GradVarName("X")), - gop->outputs_.end()); + auto gop = f::Backward(*fwd, {"x"}); + ASSERT_EQ(gop->Output(f::GradVarName("X")), f::kEmptyVarName); - auto no_input_gop = f::Backward(*fwd, {"X", "b"}); + auto no_input_gop = f::Backward(*fwd, {"x", "b"}); ASSERT_NE(no_input_gop, nullptr); ASSERT_TRUE(no_input_gop->IsNetOp()); - ASSERT_EQ(0UL, - std::static_pointer_cast(no_input_gop)->ops_.size()); + ASSERT_EQ(0UL, static_cast(no_input_gop.get())->ops_.size()); } TEST(Backward, net_fc_backward_normal) { - std::shared_ptr fwd = f::OpRegistry::CreateOp( - "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {}); + std::shared_ptr fwd = + f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}}, + {{"mul_result", {"mul_res"}}, + {"add_result", {"add_re"}}, + {"Out", {"out"}}}, + {}); ASSERT_NE(fwd, nullptr); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); @@ -207,19 +200,22 @@ TEST(Backward, net_fc_backward_normal) { ASSERT_EQ(3UL, net->ops_.size()); f::OperatorBase &d_sigmoid = *net->ops_[0]; - ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + ASSERT_EQ("sigmoid_grad", d_sigmoid.Type()); f::OperatorBase &d_add = *net->ops_[1]; - ASSERT_EQ("rowwise_add_grad", d_add.type_); + ASSERT_EQ("rowwise_add_grad", d_add.Type()); f::OperatorBase &d_mul = *net->ops_[2]; - ASSERT_EQ("mul_grad", d_mul.type_); + ASSERT_EQ("mul_grad", d_mul.Type()); } TEST(Backward, net_fc_backward_not_have_b) { std::shared_ptr fwd = - f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName}, - {"mul_result", "add_result", "tmp"}, {}); + f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {}}}, + {{"mul_result", {"mul_res"}}, + {"add_result", {"add_res"}}, + {"Out", {"tmp"}}}, + {}); ASSERT_NE(fwd, nullptr); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); @@ -230,96 +226,113 @@ TEST(Backward, net_fc_backward_not_have_b) { ASSERT_EQ(2UL, net->ops_.size()); f::OperatorBase &d_sigmoid = *net->ops_[0]; - ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + ASSERT_EQ("sigmoid_grad", d_sigmoid.Type()); f::OperatorBase &d_mul = *net->ops_[1]; - ASSERT_EQ("mul_grad", d_mul.type_); + ASSERT_EQ("mul_grad", d_mul.Type()); } TEST(Backward, net_input_of_network_not_need_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, - {"mul_tmp_0", "add_tmp_0", "hidden0"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, - {"mul_tmp_1", "add_tmp_1", "hidden1"}, {})); + net.AppendOp(f::OpRegistry::CreateOp( + "fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}}, + {{"mul_result", {"mul_tmp_0"}}, + {"add_result", {"add_tmp_0"}}, + {"Out", {"hidden0"}}}, + {})); + net.AppendOp(f::OpRegistry::CreateOp( + "fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}}, + {{"mul_result", {"mul_tmp_1"}}, + {"add_result", {"add_tmp_1"}}, + {"Out", {"hidden1"}}}, + {})); net.CompleteAddOp(); - auto bwd = Backward(net, {"X"}); // X@GRAD is not need. + auto bwd = Backward(net, {"x"}); // x@GRAD is not need. ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); - std::unordered_set all_output = std::unordered_set( - bwd_net->outputs_.begin(), bwd_net->outputs_.end()); - all_output.erase(f::kEmptyVarName); + auto output_vars = bwd_net->OutputVars(true); + std::unordered_set all_outputs = + std::unordered_set(output_vars.begin(), output_vars.end()); + all_outputs.erase(f::kEmptyVarName); for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) { - ASSERT_NE(all_output.find(f::GradVarName(out)), all_output.end()); + ASSERT_NE(all_outputs.find(f::GradVarName(out)), all_outputs.end()); } // Not Generated X - ASSERT_EQ(all_output.find(f::GradVarName("X")), all_output.end()); + ASSERT_EQ(all_outputs.find(f::GradVarName("X")), all_outputs.end()); ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ(f::kEmptyVarName, - first_fc_grad->ops_[2]->Output(f::GradVarName("A"))); + first_fc_grad->ops_[2]->Output(f::GradVarName("X"))); } TEST(Backward, net_shared_weight) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); - net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}}, + {{"Out", {"out"}}}, {})); + net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}}, + {{"Out", {"FinalOut"}}}, {})); net.CompleteAddOp(); auto bwd = f::Backward(net, {}); ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); - ASSERT_EQ("add", bwd_net->ops_[2]->type_); + ASSERT_EQ("add", bwd_net->ops_[2]->Type()); } TEST(Backward, op_register_grad_not_for_network) { - auto fwd = f::OpRegistry::CreateOp( - "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"}, - {{"temporary_index", std::vector{0, 1}}}); + auto fwd = + f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}}, + {{"mul_result", {"mul_out"}}, + {"add_result", {"add_out"}}, + {"Out", {"out1"}}}, + {{"temporary_index", std::vector{0, 1}}}); ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); } TEST(Backward, op_all_input_are_not_need) { - auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); - auto backward = f::Backward(*fwd, {"X", "b"}); + auto fwd = f::OpRegistry::CreateOp( + "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); + auto backward = f::Backward(*fwd, {"x", "b"}); ASSERT_TRUE(backward->IsNetOp()); auto net = static_cast(backward.get()); ASSERT_TRUE(net->ops_.empty()); } TEST(Backward, op_all_output_are_not_need) { - auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); - auto backward = f::Backward(*fwd, {"Out"}); + auto fwd = f::OpRegistry::CreateOp( + "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); + auto backward = f::Backward(*fwd, {"out"}); ASSERT_TRUE(backward->IsNetOp()); auto net = static_cast(backward.get()); ASSERT_TRUE(net->ops_.empty()); } TEST(Backward, op_part_of_output_are_not_need) { - auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); + auto fwd = f::OpRegistry::CreateOp("many_output_op", {{"x", {"X"}}}, + {{"y", {"Y"}}, {"z", {"Z"}}}, {}); auto backward = f::Backward(*fwd, {"Z"}); ASSERT_TRUE(backward->IsNetOp()); auto net = static_cast(backward.get()); ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; - ASSERT_EQ("fill_zeros_like", fill_zero.type_); - ASSERT_EQ(1UL, fill_zero.inputs_.size()); - ASSERT_EQ("Z", fill_zero.inputs_[0]); - ASSERT_EQ(1UL, fill_zero.outputs_.size()); - ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.outputs_[0]); + ASSERT_EQ("fill_zeros_like", fill_zero.Type()); + ASSERT_EQ(1UL, fill_zero.Inputs("Src").size()); + ASSERT_EQ("Z", fill_zero.Input("Src")); + ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size()); + ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst")); auto &d_many_out = *net->ops_[1]; - ASSERT_EQ("many_output_op_grad", d_many_out.type_); - ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG + ASSERT_EQ("many_output_op_grad", d_many_out.Type()); + ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.Inputs().size()); // I/O/OG ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, d_many_out.Input(f::GradVarName("z"))); ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y"))); @@ -327,44 +340,62 @@ TEST(Backward, op_part_of_output_are_not_need) { } TEST(Backward, op_part_of_input_are_not_need) { - auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {}); + auto fwd = f::OpRegistry::CreateOp("mul", {{"X", {"a"}}, {"Y", {"b"}}}, + {{"Out", {"out"}}}, {}); auto backward = f::Backward(*fwd, {"a"}); auto &grad_mul = *backward; - ASSERT_EQ(grad_mul.type_, "mul_grad"); - ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); - ASSERT_EQ(grad_mul.outputs_.size(), 2UL); - ASSERT_EQ(grad_mul.Output(f::GradVarName("A")), f::kEmptyVarName); - ASSERT_EQ(grad_mul.Output(f::GradVarName("B")), f::GradVarName("b")); + ASSERT_EQ(grad_mul.Type(), "mul_grad"); + ASSERT_EQ(grad_mul.Inputs().size(), 2UL + 1UL + 1UL); + ASSERT_EQ(grad_mul.Outputs().size(), 2UL); + ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName); + ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b")); ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out")); - ASSERT_EQ(grad_mul.Input("A"), "a"); - ASSERT_EQ(grad_mul.Input("B"), "b"); + ASSERT_EQ(grad_mul.Input("X"), "a"); + ASSERT_EQ(grad_mul.Input("Y"), "b"); ASSERT_EQ(grad_mul.Input("Out"), "out"); } TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ops::NetOp net; - net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, - {"mul_out1", "add_out1", "out1"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, - {"mul_out2", "tmp_out2", "out2"}, {})); - net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, - {"mul_out3", "tmp_out3", "out3"}, {})); + net.AppendOp(f::OpRegistry::CreateOp( + "fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}}, + {{"mul_result", {"mul_out1"}}, + {"add_result", {"add_out1"}}, + {"Out", {"out1"}}}, + {})); + net.AppendOp(f::OpRegistry::CreateOp( + "fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}}, + {{"mul_result", {"mul_out2"}}, + {"add_result", {"tmp_out2"}}, + {"Out", {"out2"}}}, + {})); + net.AppendOp(f::OpRegistry::CreateOp( + "fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}}, + {{"mul_result", {"mul_out3"}}, + {"add_result", {"tmp_out3"}}, + {"Out", {"out3"}}}, + {})); net.CompleteAddOp(); + auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); ASSERT_TRUE(backward->IsNetOp()); auto bwd_net = static_cast(backward.get()); ASSERT_EQ(bwd_net->ops_.size(), 3UL); auto &grad_fc = *bwd_net->ops_[0]; - EXPECT_EQ(grad_fc.inputs_.size(), - 3UL /* external input number */ + + const char *all = paddle::operators::NetOp::kAll; + EXPECT_EQ(grad_fc.Inputs(all).size(), + 2UL /* external input number */ + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ + 2U /* internal variable number*/); - EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/ - + 2UL /* input number of rowwise_add */ - + 1UL /* input number of sigmod */); - EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL); - EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL); - EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL); - EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL); + EXPECT_EQ(grad_fc.Outputs(all).size(), + 2UL /* input number of mul*/ + + 2UL /* input number of rowwise_add + */ + + 1UL /* input number of sigmod */); + EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); } diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 545c1dcc2a1682839d90194002fdbb748d85e808..cfd3e8dfdec0e92620aef5cd246b4622b779ce19 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -283,6 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } - } // namespace framework } // namespace paddle diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto new file mode 100644 index 0000000000000000000000000000000000000000..ae44a1ffd45dacdc44a72edc630e771e7a2f2990 --- /dev/null +++ b/paddle/framework/framework.proto @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto2"; +package paddle.framework; + +enum AttrType { + INT = 0; + FLOAT = 1; + STRING = 2; + INTS = 3; + FLOATS = 4; + STRINGS = 5; +} + +// OpDesc describes an instance of a C++ framework::OperatorBase +// derived class type. +message OpDesc { + + message Attr { + required string name = 1; + required AttrType type = 2; + optional int32 i = 3; + optional float f = 4; + optional string s = 5; + repeated int32 ints = 6; + repeated float floats = 7; + repeated string strings = 8; + }; + + message Var { + required string parameter = 1; + repeated string arguments = 2; + }; + + required string type = 3; + repeated Var inputs = 1; + repeated Var outputs = 2; + repeated Attr attrs = 4; +}; + +// OpProto describes a C++ framework::OperatorBase derived class. +message OpProto { + + // VarProto describes the C++ type framework::Variable. + message Var { + required string name = 1; + required string comment = 2; + + optional bool duplicable = 3 [ default = false ]; + optional bool intermediate = 4 [ default = false ]; + optional bool not_in_gradient = 5 [ default = false ]; + } + + // AttrProto describes the C++ type Attribute. + message Attr { + required string name = 1; + required AttrType type = 2; + required string comment = 3; + // If that attribute is generated, it means the Paddle third + // language binding has responsibility to fill that + // attribute. End-User should not set that attribute. + optional bool generated = 4 [ default = false ]; + } + + required string type = 1; + repeated Var inputs = 2; + repeated Var outputs = 3; + repeated Attr attrs = 4; + required string comment = 5; +} diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 8bd2bc590272256fed79f4ab38ad52b470e87012..0a2a41f6b62658ac8633a6e384d099f8d6641f33 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -13,105 +13,53 @@ express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace framework { - -typedef std::vector Ints; - enum class OpArgType { IN, OUT }; -const Ints* AttrFormat(const AttributeMap& attrs, const std::string& key) { - return (attrs.count(key) > 0) ? &boost::get(attrs.at(key)) : nullptr; -} - -Ints* AttrFormat(AttributeMap& attrs, const std::string& key) { - return (attrs.count(key) > 0) ? &boost::get(attrs.at(key)) : nullptr; -} - -static void TransOpArg(const OperatorBase* src_op, - std::vector& grad_inputs, - std::vector& grad_outputs, - AttributeMap& grad_attrs, - std::unordered_map& grad_idxs, - const std::string& src_type, const std::string& dst_type, - int& idx, bool is_grad) { - const std::vector& src_inout = - (src_type == "input_format") ? src_op->inputs_ : src_op->outputs_; - - const std::vector* src_format = AttrFormat(src_op->Attrs(), src_type); - - std::vector& dst_inout = - (dst_type == "input_format") ? grad_inputs : grad_outputs; - - std::vector* dst_format = AttrFormat(grad_attrs, dst_type); - - const OpProto& proto = OpRegistry::protos().at(src_op->type_); - +static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, + bool is_grad, OperatorBase::VarNameMap* vars) { + const auto& src_inout = + src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); + auto& dst_inout = *vars; + const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_; const auto& src_arg_list = - (src_type == "input_format") ? proto.inputs() : proto.outputs(); - + src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { - std::string src_name = arg.name(); - std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name; - grad_idxs[dst_name] = idx++; - int src_arg_idx = src_op->in_out_idxs_->at(src_name); - int src_begin = - src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx); - int src_end = src_format == nullptr ? src_arg_idx + 1 - : src_format->at(src_arg_idx + 1); - for (int i = src_begin; i < src_end; ++i) { - std::string s = - is_grad ? src_inout[i] + kGradVarSuffix - : (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]); - dst_inout.emplace_back(s); - } - if (dst_format != nullptr) { - dst_format->push_back(dst_inout.size()); + if (arg.not_in_gradient() && !is_grad) continue; + const std::string src_name = arg.name(); + std::string dst_name = is_grad ? GradVarName(src_name) : src_name; + dst_inout[dst_name].reserve(src_inout.at(src_name).size()); + for (auto& var_name : src_inout.at(src_name)) { + std::string s = is_grad ? GradVarName(var_name) : var_name; + dst_inout[dst_name].emplace_back(s); } } } OperatorBase* BuildGradOp(const OperatorBase* op) { - const std::string& grad_op_type = OpRegistry::grad_ops().at(op->Type()); - - AttributeMap grad_attrs(op->Attrs()); - grad_attrs.erase("input_format"); - grad_attrs.erase("output_format"); - if (op->Attrs().count("input_format") > 0) { - grad_attrs["output_format"] = std::vector({0}); - } - if (op->Attrs().count("input_format") > 0 || - op->Attrs().count("output_format") > 0) { - grad_attrs["input_format"] = std::vector({0}); - } - - std::vector grad_inputs, grad_outputs; - - using VarIndexMap = std::unordered_map; - VarIndexMap* grad_idxs = new VarIndexMap; - int in_idx = 0; - int out_idx = 0; - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, - "input_format", "input_format", in_idx, false); // I - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, - "output_format", "input_format", in_idx, false); // G - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, - "output_format", "input_format", in_idx, true); // OG - TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs, - "input_format", "output_format", out_idx, true); // IG - - OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); - - grad_op->type_ = grad_op_type; - grad_op->inputs_ = grad_inputs; - grad_op->outputs_ = grad_outputs; - grad_op->attrs_ = grad_attrs; - grad_op->in_out_idxs_.reset(grad_idxs); - - return grad_op; + auto it = OpRegistry::op_info_map().find(op->Type()); + PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), + "'%s' has not been registered.", op->Type()); + PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", + op->Type()); + std::string grad_op_type = it->second.grad_op_type_; + PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", + op->Type()); + + OperatorBase::VarNameMap inputs; + OperatorBase::VarNameMap outputs; + TransOpArg(op, OpArgType::IN, false, &inputs); // I + TransOpArg(op, OpArgType::OUT, false, &inputs); // O + TransOpArg(op, OpArgType::OUT, true, &inputs); // OG + TransOpArg(op, OpArgType::IN, true, &outputs); // IG + + it = OpRegistry::op_info_map().find(grad_op_type); + PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), + "'%s' has not been registered.", grad_op_type); + return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs()); } } // namespace framework diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 19e552b7458c966d473bdee99515a2beee1f6089..902c2655e9182d74a48ad13e17a39a3304d5fa57 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -8,24 +8,15 @@ USE_OP(add_two); namespace paddle { namespace framework { -class NOP : public OperatorBase { - public: - DEFINE_OPERATOR_CTOR(NOP, OperatorBase) - - void InferShape(const Scope &scope) const override {} - void Run(const Scope &scope, - const platform::DeviceContext &dev_ctx) const override {} -}; - class MutiInOutOpMaker : public OpProtoAndCheckerMaker { public: MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").SetMultiple(); + AddInput("In2_mult", "a multiple input").AsDuplicable(); AddInput("In3", "another single input"); AddOutput("Out1", "a single output"); - AddOutput("Out2_mult", "a multiple output").SetMultiple(); + AddOutput("Out2_mult", "a multiple output").AsDuplicable(); AddComment("test op with multiple inputs and outputs"); } }; @@ -35,10 +26,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient(); - AddInput("In3_mult", "another multiple input").SetMultiple(); - AddOutput("Out1_mult", "a multiple output").SetMultiple(); - AddOutput("Out2", "a single output").IgnoreGradient(); + AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); + AddInput("In3_mult", "another multiple input").AsDuplicable(); + AddOutput("Out1_mult", "a multiple output").AsDuplicable(); + AddOutput("Out2", "a single output").NotInGradient(); AddComment("op with inputs and outputs ignored in gradient calculating"); } }; @@ -49,35 +40,33 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; TEST(GradOpBuilder, AddTwo) { - std::shared_ptr add_op( - f::OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); + std::shared_ptr add_op(f::OpRegistry::CreateOp( + "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {})); std::shared_ptr grad_add_op = f::OpRegistry::CreateGradOp(*add_op); - EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); - EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); + EXPECT_EQ(grad_add_op->Inputs().size(), 4UL); + EXPECT_EQ(grad_add_op->Outputs().size(), 2UL); EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("Y"), "y"); EXPECT_EQ(grad_add_op->Input("Out"), "out"); - EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD"); - EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD"); - EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD"); + EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out")); + EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x")); + EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); } -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); -REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker); -REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP); +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); TEST(GradOpBuilder, MutiInOut) { - f::AttributeMap attrs{{"input_format", std::vector{0, 1, 4, 5}}, - {"output_format", std::vector{0, 1, 3}}}; std::shared_ptr test_op(f::OpRegistry::CreateOp( - "mult_io", {"in1", "in2_1", "in2_2", "in2_3", "in3"}, - {"out1", "out2_1", "out2_2"}, attrs)); + "mult_io", {{"In1", {"in1"}}, + {"In2_mult", {"in2_1", "in2_2", "in2_3"}}, + {"In3", {"in3"}}}, + {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {})); std::shared_ptr grad_test_op = f::OpRegistry::CreateGradOp(*test_op); - ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Inputs("In2_mult"), std::vector({"in2_1", "in2_2", "in2_3"})); @@ -91,7 +80,7 @@ TEST(GradOpBuilder, MutiInOut) { std::vector( {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); + ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), std::vector({f::GradVarName("in2_1"), @@ -101,31 +90,28 @@ TEST(GradOpBuilder, MutiInOut) { } TEST(GradOpBuilder, IOIgnoredInGradient) { - f::AttributeMap attrs{{"input_format", std::vector{0, 1, 3, 5}}, - {"output_format", std::vector{0, 2, 3}}}; std::shared_ptr test_op(f::OpRegistry::CreateOp( - "io_ignored", {"in1", "in2_1", "in2_2", "in3_1", "in3_2"}, - {"out1_1", "out1_2", "out2"}, attrs)); + "io_ignored", {{"In1", {"in1"}}, + {"In2_mult", {"in2_1", "in2_2"}}, + {"In3_mult", {"in3_1", "in3_2"}}}, + {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {})); std::shared_ptr grad_test_op = f::OpRegistry::CreateGradOp(*test_op); // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In2_mult"), - std::vector({f::kEmptyVarName, f::kEmptyVarName})); EXPECT_EQ(grad_test_op->Inputs("In3_mult"), std::vector({"in3_1", "in3_2"})); EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName); EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), std::vector( {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), f::GradVarName("out2")); - ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); + ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), std::vector( diff --git a/paddle/framework/op_desc.proto b/paddle/framework/op_desc.proto deleted file mode 100644 index d95ba26f88ae181f991440e0df30c80f80a7eb2a..0000000000000000000000000000000000000000 --- a/paddle/framework/op_desc.proto +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -syntax = "proto2"; -package paddle.framework; - -import "attribute.proto"; - -// AttrDesc is used to describe Attributes of an Operator. It contain's -// name, type, and value of Attribute. -// -// e.g, for scale=3.0: name=scala, type=AttrType.FLOAT, value=3.0 -message AttrDesc { - required string name = 1; - required AttrType type = 2; - optional int32 i = 3; - optional float f = 4; - optional string s = 5; - repeated int32 ints = 6; - repeated float floats = 7; - repeated string strings = 8; -}; - -// Protocol Message to describe an Operator. -// -// In PaddlePaddle, Operator is used to do a certain computation such -// as "add", "sub", "cosine", etc. -// (1) Operator needs to know the input and output variable names. -// (2) Some ops may have special attributes such as "scale" in "CosineOp". -// -// 3rd-party language can build this proto message and call -// AddOp(const OpDesc& op_desc) of Paddle core to create an Operator. -message OpDesc { - // input names of this Operator. - repeated string inputs = 1; - - // output names of this Operator. - repeated string outputs = 2; - - // type of this Operator, such as "add", "sub", "fc". - required string type = 3; - - // Attributes of this Operator. e.g., scale=3.0 in cosine op. - repeated AttrDesc attrs = 4; -}; \ No newline at end of file diff --git a/paddle/framework/op_desc_test.cc b/paddle/framework/op_desc_test.cc deleted file mode 100644 index d0c52523b64725ee11c281b086f9ffed6a09e787..0000000000000000000000000000000000000000 --- a/paddle/framework/op_desc_test.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -TEST(OpDesc, Create) { - paddle::framework::OpDesc op_desc; - op_desc.set_type("add"); - op_desc.add_inputs("X"); - op_desc.add_inputs("Y"); - op_desc.add_outputs("Z"); - - auto attr = op_desc.mutable_attrs()->Add(); - attr->set_type(paddle::framework::AttrType::FLOAT); - attr->set_f(3.14); - - // required field name is not set, so IsInitialized should be false. - ASSERT_FALSE(op_desc.IsInitialized()); - - attr->set_name("add"); - // after all required fields are set, IsInitialized should be true now. - ASSERT_TRUE(op_desc.IsInitialized()); -} \ No newline at end of file diff --git a/paddle/framework/op_proto.proto b/paddle/framework/op_proto.proto deleted file mode 100644 index 52292162874b9ca207fb0d3917df41ade096b143..0000000000000000000000000000000000000000 --- a/paddle/framework/op_proto.proto +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// Protocol Message for 3rd-party language binding. -// -// Paddle Python package will use `OpProto` to generate op creation methods. -// The op creation methods take user's input and generate `OpDesc` proto -// message, -// then pass `OpDesc` to C++ side and create Op pointer. -// -syntax = "proto2"; -package paddle.framework; - -import "attribute.proto"; - -// Attribute protocol message for 3rd-party language binding. -// It will store the Op support what attribute and what type. -message AttrProto { - // Supported attribute name. e.g. `scale` for cosine op. - required string name = 1; - - // Supported attribute type. - required AttrType type = 2; - - // Supported attribute comments. It helps 3rd-party language generate - // doc-string. - required string comment = 3; - - // If that attribute is generated, it means the Paddle third language - // binding has responsibility to fill that attribute. End-User should - // not set that attribute. - optional bool generated = 4 [ default = false ]; -} - -// Input or output message for 3rd-party language binding. -// It contains parameter name and its comments. -message VarProto { - // Input or output name in that op creation function. - // e.g. `cos(a, b, output, ...)`, "a", "b", "output" are names. - required string name = 1; - - // The comment for that input. It helps 3rd-party language generate - // doc-string. - required string comment = 2; - - // Is that input/output could be a list or not. - // If so, that Op should write a attributed named `input_format` or - // `output_format`. - // - // e.g. - // If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W` - // could be multiple, so the multiple of `X` and `W` is True, and OpDesc - // will hold a attribute of them. - // - // The Op desc of same fc could be - // { - // "type": "fc", - // "input": ["X1", "X2", "W1", "W2", "b"], - // "output": "fc.out", - // "attrs" : { - // "input_format": [0, 2, 4, 5] - // } - // } - // - optional bool multiple = 3 [ default = false ]; - - // It marks that output is a temporary output. That output is not used by - // user, but used by other op internally as input. If other op is not use - // that output, it could be optimized early. - // - // Attribute temporary_index will be set in OpDesc if there is some - // outputs are temporary. - // - // output = [ "xxx.out1", "xxx.tmp", "xxx.out2"], - // attrs = { - // "temporary_index": [1] - // } - optional bool temporary = 4 [ default = false ]; - - // The gradient of operator can be ignored immediately - // e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2 - // can be ignored for the future optimized on graph. - optional bool ignore_gradient = 6; -} - -// Op protocol message for 3rd-party language binding. -// It contains all information for generating op creation method. -message OpProto { - // The input information to generate op creation method. - repeated VarProto inputs = 1; - - // The output information to generate op creation method. - repeated VarProto outputs = 2; - - // The attribute information to generate op creation method. - repeated AttrProto attrs = 3; - - // The comments for that Op. It helps 3rd-party language generate - // doc-string. The whole documentation of that Op is generated by comment, - // inputs, outputs, attrs together. - required string comment = 4; - - // The type of that Op. - required string type = 5; -} diff --git a/paddle/framework/op_proto_test.cc b/paddle/framework/op_proto_test.cc deleted file mode 100644 index 9c054bde44e77571330cbc59074705f0cfc1cfb6..0000000000000000000000000000000000000000 --- a/paddle/framework/op_proto_test.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include -#include - -TEST(TestOpProto, ALL) { - paddle::framework::OpProto proto; - { - auto ipt = proto.mutable_inputs()->Add(); - *ipt->mutable_name() = "a"; - *ipt->mutable_comment() = "the one input of cosine op"; - } - { - auto ipt = proto.mutable_inputs()->Add(); - *ipt->mutable_name() = "b"; - *ipt->mutable_comment() = "the other input of cosine op"; - } - { - auto opt = proto.mutable_outputs()->Add(); - *opt->mutable_name() = "output"; - *opt->mutable_comment() = "the output of cosine op"; - } - { - auto attr = proto.mutable_attrs()->Add(); - *attr->mutable_name() = "scale"; - attr->set_type(paddle::framework::AttrType::FLOAT); - *attr->mutable_comment() = "the scale attribute of cosine op"; - } - proto.set_type("cos"); - *proto.mutable_comment() = "cosine op, output = scale * cos(a, b)"; - - ASSERT_TRUE(proto.IsInitialized()); -} \ No newline at end of file diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 1caa02a2a1d046778f875d04eeaef957be741302..8eae86e9605da74cdc37caeb9569e7500aac2a63 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -17,5 +17,48 @@ limitations under the License. */ #include namespace paddle { -namespace framework {} // namespace framework +namespace framework { + +std::unique_ptr OpRegistry::CreateOp(const std::string& type, + const VarNameMap& inputs, + const VarNameMap& outputs, + AttributeMap attrs) { + auto it = op_info_map().find(type); + PADDLE_ENFORCE(it != op_info_map().end(), + "Operator '%s' has not been registered.", type); + it->second.checker_->Check(attrs); + auto op = it->second.creator_(type, inputs, outputs, attrs); + return std::unique_ptr(op); +} + +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { + VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); + VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); + AttributeMap attrs; + for (auto& attr : op_desc.attrs()) { + attrs[attr.name()] = GetAttrValue(attr); + } + + return CreateOp(op_desc.type(), inputs, outputs, attrs); +} + +OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap( + const google::protobuf::RepeatedPtrField& op_desc_vars) { + VarNameMap ret_val; + for (auto& var : op_desc_vars) { + auto& var_names = ret_val[var.parameter()]; + auto& var_names_in_proto = var.arguments(); + var_names.reserve(static_cast(var_names_in_proto.size())); + std::copy(var_names_in_proto.begin(), var_names_in_proto.end(), + std::back_inserter(var_names)); + } + return ret_val; +} + +std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { + PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); + return std::unique_ptr(BuildGradOp(&op)); +} + +} // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index cb9164eec1788c2c19176115e8687bed49d8c0b6..4c2d13d639005d2d2710c19f63988333d89bce13 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -17,320 +17,104 @@ limitations under the License. */ #include #include #include +#include #include #include #include "paddle/framework/attribute.h" +#include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/operator.h" #include "paddle/framework/scope.h" namespace paddle { namespace framework { -// this class not only make proto but also init attribute checkers. -class OpProtoAndCheckerMaker { - public: - OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) - : proto_(proto), op_checker_(op_checker) {} - - ~OpProtoAndCheckerMaker() { - PADDLE_ENFORCE(validated_, "should call Validate after build"); - } - - void Validate() { - validated_ = true; - CheckNoDuplicatedInOutAttrs(); - } - - protected: - struct VariableBuilder { - VarProto* var_; - std::function on_multiple_; - std::function on_temporary_; - - VariableBuilder& SetMultiple() { - var_->set_multiple(true); - on_multiple_(); - return *this; - } - - VariableBuilder& SetTemporary() { - PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); - var_->set_temporary(true); - on_temporary_(); - return *this; - } - - VariableBuilder& IgnoreGradient() { - var_->set_ignore_gradient(true); - return *this; - } - }; - - VariableBuilder AddInput(const std::string& name, - const std::string& comment) { - VarProto* input = proto_->add_inputs(); - input->set_name(name); - input->set_comment(comment); - return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, - nullptr}; - } - - VariableBuilder AddOutput(const std::string& name, - const std::string& comment) { - VarProto* output = proto_->add_outputs(); - output->set_name(name); - output->set_comment(comment); - return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, - [=] { this->SetHasTemporaryOutput(); }}; - } - - template - TypedAttrChecker& AddAttr(const std::string& name, - const std::string& comment, - bool generated = false) { - AttrProto* attr = proto_->add_attrs(); - attr->set_name(name); - attr->set_comment(comment); - attr->set_generated(generated); - attr->set_type(AttrTypeID()); - return op_checker_->AddAttrChecker(name); - } - - void AddComment(const std::string& comment) { proto_->set_comment(comment); } - - private: - void SetHasMultiple(const std::string& in_out, bool* flag) { - if (!*flag) { - AddAttr>(in_out + "_format", - "The multiple index of " + in_out + - "\n" - R"DOC( -This attribute is used by Paddle core framework. Paddle's Op support each input -or output could be a list of variable. This attribute is used to show how that -list organized. - -e.g. - input = ["a", "b", "c", "d", "e", "f"] - input_format = [0, 4, 5, 6] - -means - The number of all input variables this op is six, and they are segmented into - three inputs. - - The first input is input[0:4], second is input[4:5], third is input[5:6]. -)DOC", - /*generated*/ true); - *flag = true; - } - } - - void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); } - void SetHasMultipleOutput() { - SetHasMultiple("output", &has_multiple_output_); - } - - void SetHasTemporaryOutput() { - if (!has_temporary_output_) { - AddAttr>("temporary_index", - R"DOC(The temporary index of output. - -Not all output of Paddle Op is used by user. For faster computation, each op -could output some its internal state to other op, other op could take that -output to make compute faster. - -Add a mark to which output is temporary is helpful for future optimization. -)DOC", - /*generated*/ true) - .SetDefault(std::vector()); - has_temporary_output_ = true; - } - } - - void CheckNoDuplicatedInOutAttrs() { - std::unordered_set names; - auto checker = [&](const std::string& name) { - PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); - names.insert(name); - }; - for (auto& attr : proto_->attrs()) { - checker(attr.name()); - } - for (auto& input : proto_->inputs()) { - checker(input.name()); - } - for (auto& output : proto_->outputs()) { - checker(output.name()); - } - } - - OpProto* proto_; - OpAttrChecker* op_checker_; - bool validated_{false}; - bool has_multiple_input_{false}; - bool has_multiple_output_{false}; - bool has_temporary_output_{false}; -}; - class OpRegistry { - using OpCreator = std::function; - using VarIndexMap = std::unordered_map; - using VarNameList = std::vector; + using VarNameMap = OperatorBase::VarNameMap; + using OpCreator = std::function; public: - template - static void RegisterOp(const std::string& op_type) { - op_creators()[op_type] = [] { return new OpType; }; - OpAttrChecker& op_checker = op_checkers()[op_type]; - OpProto& op_proto = protos()[op_type]; - auto maker = ProtoMakerType(&op_proto, &op_checker); - maker.Validate(); - op_proto.set_type(op_type); - PADDLE_ENFORCE( - op_proto.IsInitialized(), - "Fail to initialize %s's OpProto, because %s is not initialized", - op_type, op_proto.InitializationErrorString()); - - VarIndexMaps()[op_type].reset(new VarIndexMap()); - auto& varmap = *VarIndexMaps()[op_type]; - int idx = 0; - for (auto& var : op_proto.inputs()) { - varmap[var.name()] = idx++; - } - idx = 0; - for (auto& var : op_proto.outputs()) { - varmap[var.name()] = idx++; - } - } - - template - static void RegisterGradOp(const std::string& op_type, - const std::string& grad_op_type) { - op_creators()[grad_op_type] = [] { return new GradOpType; }; - grad_ops()[op_type] = grad_op_type; - } - - static std::shared_ptr CreateOp(const std::string& type, - const VarNameList& inputs, - const VarNameList& outputs, - const AttributeMap& attrs) { - auto op_create_it = op_creators().find(type); - PADDLE_ENFORCE(op_create_it != op_creators().end(), - "Operator %s cannot be found.", type); - - auto op = op_create_it->second(); - op->type_ = type; - op->inputs_ = inputs; - op->outputs_ = outputs; - - op->attrs_ = attrs; - op_checkers().at(type).Check(op->attrs_); - - GenerateTempVariableName(op); + struct OpInfo { + OpCreator creator_; + std::string grad_op_type_; + OpProto* proto_; + OpAttrChecker* checker_; + }; - { - auto var_index_it = VarIndexMaps().find(type); - if (var_index_it != VarIndexMaps().end()) { - op->in_out_idxs_ = var_index_it->second; - } + template + static void RegisterOp(const std::string& op_type, + const std::string& grad_op_type) { + PADDLE_ENFORCE(op_info_map().count(op_type) == 0, + "'%s' is registered more than once.", op_type); + OpInfo op_info; + op_info.creator_ = [](const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const AttributeMap& attrs) { + return new OpType(type, inputs, outputs, attrs); + }; + op_info.grad_op_type_ = grad_op_type; + if (std::type_index(typeid(ProtoMakerType)) != + std::type_index(typeid(NOPMaker))) { + op_info.proto_ = new OpProto; + op_info.checker_ = new OpAttrChecker; + auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); + maker.Validate(); + op_info.proto_->set_type(op_type); + PADDLE_ENFORCE( + op_info.proto_->IsInitialized(), + "Fail to initialize %s's OpProto, because %s is not initialized", + op_type, op_info.proto_->InitializationErrorString()); + } else { + op_info.proto_ = nullptr; + op_info.checker_ = nullptr; } - - op->Init(); - return std::shared_ptr(op); - } - - static std::shared_ptr CreateOp(const OpDesc& op_desc) { - std::vector inputs; - inputs.reserve((size_t)op_desc.inputs_size()); - std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), - std::back_inserter(inputs)); - - std::vector outputs; - outputs.reserve((size_t)op_desc.outputs_size()); - std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), - std::back_inserter(outputs)); - - AttributeMap attrs; - for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); + op_info_map().insert(std::make_pair(op_type, op_info)); + // register gradient op + if (!grad_op_type.empty()) { + RegisterOp(grad_op_type, ""); } - - return CreateOp(op_desc.type(), inputs, outputs, attrs); } - static std::shared_ptr CreateGradOp(const OperatorBase& op) { - PADDLE_ENFORCE(!op.IsNetOp(), - "Use framework::Backward to get backward ops"); - std::shared_ptr grad_op(BuildGradOp(&op)); - grad_op->Init(); - return grad_op; - } + static std::unique_ptr CreateOp(const std::string& type, + const VarNameMap& inputs, + const VarNameMap& outputs, + AttributeMap attrs); - static std::unordered_map& protos() { - static std::unordered_map protos_; - return protos_; - } + static std::unique_ptr CreateOp(const OpDesc& op_desc); - static std::unordered_map& grad_ops() { - static std::unordered_map grad_ops_; - return grad_ops_; - } + static VarNameMap ConvertOpDescVarsToVarNameMap( + const google::protobuf::RepeatedPtrField& op_desc_vars); - static std::unordered_map>& - VarIndexMaps() { - static std::unordered_map> maps_; - return maps_; - } + static std::unique_ptr CreateGradOp(const OperatorBase& op); - static std::unordered_map& op_creators() { - static std::unordered_map op_creators_; - return op_creators_; - } - - private: - static std::unordered_map& op_checkers() { - static std::unordered_map op_checkers_; - return op_checkers_; - } - - static void GenerateTempVariableName(OperatorBase* op) { - static std::atomic gUniqId(0UL); - for (auto& outname : op->outputs_) { - if (outname == kTempVarName) { - outname += op->type_; - outname += "@"; - outname += std::to_string(gUniqId.fetch_add(1)); - } - } + static std::unordered_map& op_info_map() { + static std::unordered_map op_info_map_; + return op_info_map_; } }; class Registrar { public: - // In our design, various kinds of classes, e.g., operators and kernels, have - // their corresponding registry and registrar. The action of registration is - // in the constructor of a global registrar variable, which, however, are not - // used in the code that calls package framework, and would be removed from - // the generated binary file by the linker. To avoid such removal, we add - // Touch to all registrar classes and make USE_OP macros to call this - // method. So, as long as the callee code calls USE_OP, the global + // In our design, various kinds of classes, e.g., operators and kernels, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which, + // however, are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_OP macros to + // call this method. So, as long as the callee code calls USE_OP, the global // registrar variable won't be removed by the linker. void Touch() {} }; -template +template class OpRegistrar : public Registrar { public: - explicit OpRegistrar(const char* op_type) { - OpRegistry::RegisterOp(op_type); - } -}; - -template -class GradOpRegistrar : public Registrar { - public: - GradOpRegistrar(const char* op_type, const char* grad_op_type) { - OpRegistry::RegisterGradOp(op_type, grad_op_type); + explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } + OpRegistrar(const char* op_type, const char* grad_op_type) { + OpRegistry::RegisterOp(op_type, + grad_op_type); } }; @@ -356,30 +140,30 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register Operator. */ -#define REGISTER_OP(op_type, op_class, op_maker_class) \ +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ + grad_op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ - static ::paddle::framework::OpRegistrar \ - __op_registrar_##op_type##__(#op_type); \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ + }; \ + class _OpGradClass_##op_type##_ : public grad_op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \ + }; \ + static ::paddle::framework::OpRegistrar< \ + _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ + __op_registrar_##op_type##__(#op_type, #grad_op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ return 0; \ } -/** - * Macro to register Gradient Operator. - */ -#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##op_type##_##grad_op_type, \ - "REGISTER_GRADIENT_OP must be called in global namespace"); \ - static ::paddle::framework::GradOpRegistrar \ - __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ - #grad_op_type); \ - int TouchOpGradientRegistrar_##op_type() { \ - __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \ - return 0; \ - } +#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ + REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) /** * Macro to register OperatorKernel. @@ -395,14 +179,6 @@ class OpKernelRegistrar : public Registrar { return 0; \ } -/** - * Macro to Forbid user register Gradient Operator. - */ -#define NO_GRADIENT(op_type) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##op_type##_##op_type##_grad, \ - "NO_GRADIENT must be called in global namespace") - #define REGISTER_OP_GPU_KERNEL(op_type, ...) \ REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) @@ -410,7 +186,8 @@ class OpKernelRegistrar : public Registrar { REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** - * Macro to mark what Operator and Kernel we will use and tell the compiler to + * Macro to mark what Operator and Kernel + * we will use and tell the compiler to * link them into target. */ #define USE_OP_ITSELF(op_type) \ @@ -421,23 +198,6 @@ class OpKernelRegistrar : public Registrar { static int use_op_itself_##op_type##_ __attribute__((unused)) = \ TouchOpRegistrar_##op_type() -// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use -// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't -// be compiled. `NO_GRAD` should be removed after all gradient ops are -// compeleted. -#define NO_GRAD -#ifndef NO_GRAD -#define USE_OP_GRADIENT(op_type) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __use_op_gradient_##op_type, \ - "USE_OP_GRADIENT must be called in global namespace"); \ - extern int TouchOpGradientRegistrar_##op_type(); \ - static int use_op_gradient_##op_type##_ __attribute__((unused)) = \ - TouchOpGradientRegistrar_##op_type() -#else -#define USE_OP_GRADIENT(op_type) -#endif - #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ @@ -447,7 +207,8 @@ class OpKernelRegistrar : public Registrar { __attribute__((unused)) = \ TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() -// TODO(fengjiayi): The following macros seems ugly, do we have better method? +// TODO(fengjiayi): The following macros +// seems ugly, do we have better method? #ifdef PADDLE_ONLY_CPU #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU) @@ -457,18 +218,13 @@ class OpKernelRegistrar : public Registrar { USE_OP_DEVICE_KERNEL(op_type, GPU) #endif -#define USE_NO_GRAD_OP(op_type) \ - USE_OP_ITSELF(op_type); \ - USE_OP_KERNEL(op_type) +#define USE_CPU_ONLY_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_DEVICE_KERNEL(op_type, CPU); -#define USE_CPU_OP(op_type) \ - USE_OP_ITSELF(op_type); \ - USE_OP_DEVICE_KERNEL(op_type, CPU); \ - USE_OP_GRADIENT(op_type) - -#define USE_OP(op_type) \ - USE_NO_GRAD_OP(op_type); \ - USE_OP_GRADIENT(op_type) +#define USE_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_KERNEL(op_type) } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index e64126c7093a8eebc219afa4979d941ddc1afc97..50c45919c53af22665feeeebe753da283ded2b0c 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -7,8 +7,7 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase) - + using OperatorBase::OperatorBase; void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} void InferShape(const Scope& scope) const override {} @@ -29,8 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase) - + using OperatorBase::OperatorBase; void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} @@ -40,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op").SetMultiple(); - AddOutput("output", "output of cosine op").SetTemporary(); + AddInput("input", "input of cosine op").AsDuplicable(); + AddOutput("output", "output of cosine op").AsIntermediate(); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; @@ -53,16 +51,24 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } // namespace framework } // namespace paddle -REGISTER_OP(cos_sim, paddle::framework::CosineOp, - paddle::framework::CosineOpProtoAndCheckerMaker); -REGISTER_OP(my_test_op, paddle::framework::MyTestOp, - paddle::framework::MyTestOpProtoAndCheckerMaker); +static void BuildVar(const std::string& param_name, + std::initializer_list arguments, + paddle::framework::OpDesc::Var* var) { + var->set_parameter(param_name); + for (auto& arg_name : arguments) { + var->add_arguments(arg_name); + } +} +REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp, + paddle::framework::CosineOpProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp, + paddle::framework::MyTestOpProtoAndCheckerMaker); TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - op_desc.add_inputs("aa"); - op_desc.add_outputs("bb"); + BuildVar("input", {"aa"}, op_desc.add_inputs()); + BuildVar("output", {"bb"}, op_desc.add_outputs()); float scale = 3.3; auto attr = op_desc.mutable_attrs()->Add(); @@ -70,8 +76,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - std::shared_ptr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); @@ -82,8 +87,8 @@ TEST(OpRegistry, CreateOp) { TEST(OpRegistry, IllegalAttr) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - op_desc.add_inputs("aa"); - op_desc.add_outputs("bb"); + BuildVar("input", {"aa"}, op_desc.add_inputs()); + BuildVar("output", {"bb"}, op_desc.add_outputs()); auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); @@ -107,33 +112,23 @@ TEST(OpRegistry, IllegalAttr) { TEST(OpRegistry, DefaultValue) { paddle::framework::OpDesc op_desc; op_desc.set_type("cos_sim"); - op_desc.add_inputs("aa"); - op_desc.add_outputs("bb"); + BuildVar("input", {"aa"}, op_desc.add_inputs()); + BuildVar("output", {"bb"}, op_desc.add_outputs()); ASSERT_TRUE(op_desc.IsInitialized()); - std::shared_ptr op = - paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); ASSERT_EQ(op->GetAttr("scale"), 1.0); } -static void SetInputFormat(paddle::framework::OpDesc* desc) { - auto attr = desc->add_attrs(); - attr->set_name("input_format"); - attr->set_type(paddle::framework::INTS); - attr->mutable_ints()->Add(0); - attr->mutable_ints()->Add(1); -} - TEST(OpRegistry, CustomChecker) { paddle::framework::OpDesc op_desc; op_desc.set_type("my_test_op"); - op_desc.add_inputs("ii"); - op_desc.add_outputs("oo"); - SetInputFormat(&op_desc); + BuildVar("input", {"ii"}, op_desc.add_inputs()); + BuildVar("output", {"oo"}, op_desc.add_outputs()); // attr 'test_attr' is not set bool caught = false; @@ -173,7 +168,6 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); - SetInputFormat(&op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::platform::CPUDeviceContext dev_ctx; paddle::framework::Scope scope; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index d9a013b883abdec4422806f90e36da7410a4fa0c..eadd8f3316ff1ebffb94a56b2e62d661e4e0b38f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include - #include "paddle/framework/operator.h" +#include +#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { @@ -34,83 +34,172 @@ ExecutionContext::GetEigenDevice() const { #endif const std::string& OperatorBase::Input(const std::string& name) const { - PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, - "Input Output Indices could not be nullptr"); - auto it = in_out_idxs_->find(name); - PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", - name); - if (attrs_.count("input_format") == 0) { - return inputs_.at((size_t)it->second); - } else { - const auto& input_format = GetAttr>("input_format"); - int idx = input_format[it->second]; - return inputs_.at((size_t)idx); - } + auto& ins = Inputs(name); + PADDLE_ENFORCE_EQ(ins.size(), 1UL, + "Op %s input %s should contain only one variable", type_, + name); + return ins[0]; } -std::vector OperatorBase::Inputs(const std::string& name) const { - PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "IO Idx could not be nullptr"); - auto input_format = GetAttr>("input_format"); - auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= - static_cast(inputs_.size()), - "Input Out Of Range"); - - return std::vector{ - inputs_.begin() + input_format.at(offset), - inputs_.begin() + input_format.at(offset + 1)}; +const std::vector& OperatorBase::Inputs( + const std::string& name) const { + auto it = inputs_.find(name); + PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_, + name); + return it->second; } const std::string& OperatorBase::Output(const std::string& name) const { - PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr"); - auto it = in_out_idxs_->find(name); - PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_", - name); - if (attrs_.count("output_format") == 0) { - return outputs_.at((size_t)it->second); - } else { - const auto& output_format = GetAttr>("output_format"); - int idx = output_format[it->second]; - return outputs_.at((size_t)idx); - } + auto& outs = Outputs(name); + PADDLE_ENFORCE_EQ(outs.size(), 1UL, + "Op %s output %s should contain only one variable", type_, + name); + return outs[0]; } -std::vector OperatorBase::Outputs(const std::string& name) const { - PADDLE_ENFORCE_NOT_NULL(in_out_idxs_, "InOut Indice could not be nullptr"); - auto output_format = GetAttr>("output_format"); - auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= - static_cast(outputs_.size()), - "Output Out of Range"); - return std::vector{ - outputs_.begin() + output_format.at(offset), - outputs_.begin() + output_format.at(offset + 1)}; +const std::vector& OperatorBase::Outputs( + const std::string& name) const { + auto it = outputs_.find(name); + PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_, + name); + return it->second; } std::string OperatorBase::DebugString() const { std::stringstream ss; - ss << "Op(" << type_ << "), inputs:("; - for (size_t i = 0; i < inputs_.size(); ++i) { - ss << inputs_[i]; - if (i != inputs_.size() - 1) { + ss << "Op(" << type_ << "), inputs:{"; + for (auto it = inputs_.begin(); it != inputs_.end();) { + auto& input = *it; + ss << input.first << "["; + for (size_t i = 0; i < input.second.size(); ++i) { + ss << input.second[i]; + if (i != input.second.size() - 1) { + ss << ", "; + } + } + ss << "]"; + ++it; + if (it != inputs_.end()) { ss << ", "; } } - ss << "), outputs:("; - for (size_t i = 0; i < outputs_.size(); ++i) { - ss << outputs_[i]; - if (i != outputs_.size() - 1) { + ss << "}, outputs:{"; + for (auto it = outputs_.begin(); it != outputs_.end();) { + auto& output = *it; + ss << output.first << "["; + for (size_t i = 0; i < output.second.size(); ++i) { + ss << output.second[i]; + if (i != output.second.size() - 1) { + ss << ", "; + } + } + ss << "]"; + ++it; + if (it != outputs_.end()) { ss << ", "; } } - ss << ")."; + ss << "}."; return ss.str(); } void OperatorBase::Rename(const std::string& old_name, const std::string& new_name) { - std::replace(inputs_.begin(), inputs_.end(), old_name, new_name); - std::replace(outputs_.begin(), outputs_.end(), old_name, new_name); + for (auto& input : inputs_) { + std::replace(input.second.begin(), input.second.end(), old_name, new_name); + } + for (auto& output : outputs_) { + std::replace(output.second.begin(), output.second.end(), old_name, + new_name); + } +} + +OperatorBase::OperatorBase(const std::string& type, + const OperatorBase::VarNameMap& inputs, + const OperatorBase::VarNameMap& outputs, + const AttributeMap& attrs) + : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { + static std::atomic gUniqId(0UL); + for (auto& output : outputs_) { + for (auto& output_name : output.second) { + if (output_name == kTempVarName) { + output_name += type_; + output_name += "@"; + output_name += std::to_string(gUniqId.fetch_add(1)); + } + } + } +} + +std::vector OperatorBase::OutputVars(bool has_intermediate) const { + std::vector ret_val; + if (has_intermediate) { + // push all outputs into ret_val + for (auto& o : outputs_) { + ret_val.reserve(ret_val.size() + o.second.size()); + ret_val.insert(ret_val.end(), o.second.begin(), o.second.end()); + } + return ret_val; + } + auto it = OpRegistry::op_info_map().find(type_); + PADDLE_ENFORCE( + it != OpRegistry::op_info_map().end(), + "Operator %s not registered, cannot figure out intermediate outputs", + type_); + PADDLE_ENFORCE( + it->second.proto_ != nullptr, + "Operator %s has no OpProto, cannot figure out intermediate outputs", + type_); + + // get all OpProto::Var for outputs + for (auto& o : it->second.proto_->outputs()) { + // ignore all intermediate output + if (o.intermediate()) continue; + auto out = outputs_.find(o.name()); + if (out != outputs_.end()) { + ret_val.reserve(ret_val.size() + out->second.size()); + ret_val.insert(ret_val.end(), out->second.begin(), out->second.end()); + } + } + return ret_val; +} + +void OpProtoAndCheckerMaker::Validate() { + validated_ = true; + CheckNoDuplicatedInOutAttrs(); +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( + const std::string& name, const std::string& comment) { + auto* input = proto_->add_inputs(); + input->set_name(name); + input->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{input}; +} + +OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( + const std::string& name, const std::string& comment) { + auto* output = proto_->add_outputs(); + output->set_name(name); + output->set_comment(comment); + return OpProtoAndCheckerMaker::VariableBuilder{output}; +} + +void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { + std::unordered_set names; + auto checker = [&](const std::string& name) { + PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name); + names.insert(name); + }; + for (auto& attr : proto_->attrs()) { + checker(attr.name()); + } + for (auto& input : proto_->inputs()) { + checker(input.name()); + } + for (auto& output : proto_->outputs()) { + checker(output.name()); + } } } // namespace framework diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 68e7fedcd6102435a3c30326aa91043b8abecb9e..807298088981b969622174be753ea0da72067243 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -20,8 +20,7 @@ limitations under the License. */ #include #include "paddle/framework/attribute.h" -#include "paddle/framework/op_desc.pb.h" -#include "paddle/framework/op_proto.pb.h" +#include "paddle/framework/framework.pb.h" #include "paddle/framework/scope.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -63,16 +62,10 @@ class ExecutionContext; */ class OperatorBase { public: - OperatorBase() {} // TODO(yi): This constructor is to be removed. - OperatorBase(const std::string& type, const std::vector& inputs, - const std::vector& outputs, - const AttributeMap& attrs, - std::unordered_map* in_out_idxs) - : type_(type), - inputs_(inputs), - outputs_(outputs), - attrs_(attrs), - in_out_idxs_(in_out_idxs) {} + using VarNameMap = std::map>; + + OperatorBase(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs); virtual ~OperatorBase() {} @@ -85,10 +78,6 @@ class OperatorBase { virtual std::string DebugString() const; - /// Init will be called after CreateOperator, you can put some initialization - /// logic here. - virtual void Init() {} - /// InferShape infer the size of Variables used by this Operator with /// information inside scope virtual void InferShape(const Scope& scope) const = 0; @@ -104,39 +93,134 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); + const VarNameMap& Inputs() const { return inputs_; } + const VarNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. - //! TODO add a vector_view to prevent memory copy. - std::vector Inputs(const std::string& name) const; + const std::vector& Inputs(const std::string& name) const; //! Get a output with argument's name described in `op_proto` const std::string& Output(const std::string& name) const; //! Get an output which has multiple variables. //! TODO add a vector_view to prevent memory copy. - std::vector Outputs(const std::string& name) const; + const std::vector& Outputs(const std::string& name) const; + + virtual std::vector OutputVars(bool has_intermediate) const; - const std::string Type() const { return type_; } - const std::vector Inputs() const { return inputs_; } - const std::vector Outputs() const { return outputs_; } + const std::string& Type() const { return type_; } + void SetType(const std::string& type) { type_ = type; } const AttributeMap& Attrs() const { return attrs_; } - const std::unordered_map* InOutIdx() const { - return in_out_idxs_.get(); - } - public: + // Return a new operator instance, which is as same as this. + // Use unique_ptr to prevent caller forget to delete this pointer. + virtual std::unique_ptr Clone() const = 0; + + protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: - // I (Inputs) + // I (Inputs)opear // O (Outputs) // OG (Output Gradients) - std::vector inputs_; + VarNameMap inputs_; + // NOTE: in case of OpGrad, outputs_ contains // IG (Inputs Gradients) - std::vector outputs_; + VarNameMap outputs_; AttributeMap attrs_; - // store the arguments' offset described in op_desc. - std::shared_ptr> in_out_idxs_; +}; + +// Macro for define a clone method. +// If you are writing an kernel operator, `Clone` will be defined when you +// register it. i.e. `Clone` method is not needed to define by yourself. +#define DEFINE_OP_CLONE_METHOD(CLS) \ + std::unique_ptr Clone() const final { \ + return std::unique_ptr(new CLS(*this)); \ + } + +// Macro for define a default constructor for Operator. +// You can also use +// using PARENT_CLASS::PARENT_CLASS; +// to use parent's constructor. +#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \ + CLS(const std::string& type, const VarNameMap& inputs, \ + const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ + : PARENT_CLS(type, inputs, outputs, attrs) {} + +class NOP : public OperatorBase { + public: + using OperatorBase::OperatorBase; + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} + std::unique_ptr Clone() const override { + return std::unique_ptr(new NOP(*this)); + } +}; + +// this class not only make proto but also init attribute checkers. +class OpProtoAndCheckerMaker { + public: + OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : proto_(proto), op_checker_(op_checker) {} + + ~OpProtoAndCheckerMaker() { + PADDLE_ENFORCE(validated_, "should call Validate after build"); + } + + void Validate(); + + protected: + struct VariableBuilder { + OpProto::Var* var_; + + VariableBuilder& AsDuplicable() { + var_->set_duplicable(true); + return *this; + } + + VariableBuilder& AsIntermediate() { + var_->set_intermediate(true); + return *this; + } + + VariableBuilder& NotInGradient() { + var_->set_not_in_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, const std::string& comment); + + VariableBuilder AddOutput(const std::string& name, + const std::string& comment); + + template + TypedAttrChecker& AddAttr(const std::string& name, + const std::string& comment, + bool generated = false) { + auto* attr = proto_->add_attrs(); + attr->set_name(name); + attr->set_comment(comment); + attr->set_generated(generated); + attr->set_type(AttrTypeID()); + return op_checker_->AddAttrChecker(name); + } + + void AddComment(const std::string& comment) { proto_->set_comment(comment); } + + private: + void CheckNoDuplicatedInOutAttrs(); + + OpProto* proto_; + OpAttrChecker* op_checker_; + bool validated_{false}; +}; + +class NOPMaker : public OpProtoAndCheckerMaker { + public: + NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} }; class InferShapeContext { @@ -144,16 +228,12 @@ class InferShapeContext { InferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} - size_t InputSize() const { return op_.inputs_.size(); } - - size_t OutputSize() const { return op_.outputs_.size(); } - - const Variable* InputVar(const size_t index) const { - return scope_.FindVar(op_.inputs_.at(index)); + size_t InputSize(const std::string& name) const { + return op_.Inputs(name).size(); } - Variable* OutputVar(const size_t index) const { - return scope_.FindVar(op_.outputs_.at(index)); + size_t OutputSize(const std::string& name) const { + return op_.Outputs(name).size(); } const Variable* InputVar(const std::string& name) const { @@ -185,27 +265,9 @@ class InferShapeContext { return res; } - template - const T* Input(const size_t index) const { - auto var = InputVar(index); - PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index); - return &var->Get(); - } - - template - T* Output(const size_t index) const { - auto var = OutputVar(index); - PADDLE_ENFORCE_NOT_NULL( - var, - "Output(%d) not be nullptr, which means variable [%s] does not " - "exist in scope", - index, op_.outputs_[index]); - return var->GetMutable(); - } - template const T* Input(const std::string& name) const { - auto var = InputVar(name); + auto* var = InputVar(name); PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name); return &var->Get(); } @@ -242,7 +304,7 @@ class InferShapeContext { [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); PADDLE_ENFORCE_NOT_NULL( - var, "MultiOutput(%s:%s) should not be nullptr", name, + var, "MultiOutput(%s:%s) should not be nullptr.", name, sub_name); return var->GetMutable(); }); @@ -281,6 +343,10 @@ class ExecutionContext : public InferShapeContext { platform::Place GetPlace() const { return device_context_->GetPlace(); } + const platform::DeviceContext* device_context() const { + return device_context_; + } + const platform::DeviceContext* device_context_; }; @@ -300,14 +366,6 @@ class OpKernel { class OperatorWithKernel : public OperatorBase { public: - OperatorWithKernel() {} // TODO(yi): This constructor is to be removed. - OperatorWithKernel(const std::string& type, - const std::vector& inputs, - const std::vector& outputs, - const AttributeMap& attrs, - std::unordered_map* in_out_idxs) - : OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {} - struct OpKernelKey { platform::Place place_; @@ -331,6 +389,10 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; + OperatorWithKernel(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void InferShape(const Scope& scope) const override { InferShape(InferShapeContext(*this, scope)); } @@ -357,15 +419,5 @@ class OperatorWithKernel : public OperatorBase { virtual void InferShape(const InferShapeContext& ctx) const = 0; }; -#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \ - public: \ - Class() { /* TODO(yi): This constructor is to be removed. */ \ - } \ - Class(const std::string& type, const std::vector& inputs, \ - const std::vector& outputs, \ - const ::paddle::framework::AttributeMap& attrs, \ - std::unordered_map* in_out_idxs) \ - : ParentClass(type, inputs, outputs, attrs, in_out_idxs) {} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 7dbd5b14ab6ec89ae9940a3d12ec9d2b169153ad..2425b87779f6af01b0e8a91b5f574a28385f0efd 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -23,22 +23,22 @@ static int op_run_num = 0; class OpWithoutKernelTest : public OperatorBase { public: - DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase) - - void Init() override { x = 1; } + OpWithoutKernelTest(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs), x(1) {} void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { - op_run_num++; - ASSERT_EQ((int)inputs_.size(), 1); - ASSERT_EQ((int)outputs_.size(), 1); - ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr); + ++op_run_num; + ASSERT_EQ(static_cast(inputs_.size()), 1); + ASSERT_EQ(static_cast(outputs_.size()), 1); + ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr); ASSERT_EQ(x, 1); - ASSERT_NE(scope.FindVar(outputs_[0]), nullptr); + ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr); } public: - float x = 0; + int x{0}; }; class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -56,14 +56,25 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { } // namespace framework } // namespace paddle -REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, - paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); +static void BuildVar(const std::string& param_name, + std::initializer_list arguments, + paddle::framework::OpDesc::Var* var) { + var->set_parameter(param_name); + for (auto& arg_name : arguments) { + *var->mutable_arguments()->Add() = arg_name; + } +} + +REGISTER_OP_WITHOUT_GRADIENT( + test_operator, paddle::framework::OpWithoutKernelTest, + paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); TEST(OperatorBase, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("test_operator"); - *op_desc.mutable_inputs()->Add() = "IN1"; - *op_desc.mutable_outputs()->Add() = "OUT1"; + BuildVar("input", {"IN1"}, op_desc.add_inputs()); + BuildVar("output", {"OUT1"}, op_desc.add_outputs()); + auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); @@ -100,7 +111,8 @@ static int cpu_kernel_run_num = 0; class OpWithKernelTest : public OperatorWithKernel { public: - DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel) + using OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext& ctx) const override {} }; @@ -117,35 +129,15 @@ class CPUKernelTest : public OpKernel { } }; -// multiple inputs test -class OperatorMultiInputsTest : public OperatorBase { - public: - DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase) - - void Init() override { x = 1; } - void InferShape(const Scope& scope) const override {} - void Run(const Scope& scope, - const platform::DeviceContext& dev_ctx) const override { - ASSERT_EQ(scope.FindVar(inputs_[0]), nullptr); - ASSERT_EQ(x, 1); - ASSERT_NE(scope.FindVar(outputs_[0]), nullptr); - ASSERT_EQ(Input("x"), "IN1"); - ASSERT_EQ(Input("y"), "OUT1"); - } - - public: - float x = 0; -}; - class OpKernelTestMultiInputsProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("xs", "inputs of test op").SetMultiple(); + AddInput("xs", "inputs of test op").AsDuplicable(); AddInput("k", "input of test op"); - AddOutput("ys", "outputs of test op").SetMultiple(); + AddOutput("ys", "outputs of test op").AsDuplicable(); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); @@ -193,8 +185,9 @@ class CPUKernalMultiInputsTest : public OpKernel { } // namespace framework } // namespace paddle -REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, - paddle::framework::OpKernelTestProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + op_with_kernel, paddle::framework::OpWithKernelTest, + paddle::framework::OpKernelTestProtoAndCheckerMaker); REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); @@ -202,8 +195,9 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel, TEST(OpKernel, all) { paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); - *op_desc.mutable_inputs()->Add() = "IN1"; - *op_desc.mutable_outputs()->Add() = "OUT1"; + BuildVar("x", {"IN1"}, op_desc.add_inputs()); + BuildVar("y", {"OUT1"}, op_desc.add_outputs()); + auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); @@ -218,8 +212,9 @@ TEST(OpKernel, all) { ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); } -REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, - paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, + paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, paddle::framework::CPUKernalMultiInputsTest); @@ -229,32 +224,15 @@ TEST(OpKernel, multi_inputs) { OpDesc op_desc; op_desc.set_type("op_multi_inputs_with_kernel"); - *op_desc.mutable_inputs()->Add() = "x0"; - *op_desc.mutable_inputs()->Add() = "x1"; - *op_desc.mutable_inputs()->Add() = "x2"; - *op_desc.mutable_inputs()->Add() = "k0"; - *op_desc.mutable_outputs()->Add() = "y0"; - *op_desc.mutable_outputs()->Add() = "y1"; + BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs()); + BuildVar("k", {"k0"}, op_desc.add_inputs()); + BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs()); + auto attr = op_desc.mutable_attrs()->Add(); attr->set_name("scale"); attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(3.14); - auto attr0 = op_desc.mutable_attrs()->Add(); - attr0->set_name("input_format"); - attr0->set_type(paddle::framework::AttrType::INTS); - auto input_format = attr0->mutable_ints(); - input_format->Add(0); // x0 - input_format->Add(3); // k - input_format->Add(4); // end - - auto attr1 = op_desc.mutable_attrs()->Add(); - attr1->set_name("output_format"); - attr1->set_type(paddle::framework::AttrType::INTS); - auto output_format = attr1->mutable_ints(); - output_format->Add(0); // y0 - output_format->Add(2); // y1 - paddle::platform::CPUDeviceContext cpu_device_context; paddle::framework::Scope scope; scope.NewVar("x0")->GetMutable(); @@ -267,3 +245,21 @@ TEST(OpKernel, multi_inputs) { auto op = paddle::framework::OpRegistry::CreateOp(op_desc); op->Run(scope, cpu_device_context); } + +class OperatorClone : public paddle::framework::OperatorBase { + public: + DEFINE_OP_CLONE_METHOD(OperatorClone); + OperatorClone(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const paddle::framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + void InferShape(const paddle::framework::Scope& scope) const override {} + void Run(const paddle::framework::Scope& scope, + const paddle::platform::DeviceContext& dev_ctx) const override {} +}; + +TEST(Operator, Clone) { + OperatorClone a("ABC", {}, {}, {}); + auto b = a.Clone(); + ASSERT_EQ(a.Type(), b->Type()); +} \ No newline at end of file diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 75cd5bcb38e1d864358314c1c15b6fb59e9c3752..de119e9e062bffa5e95929671c69001047772cf3 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/framework/op_registry.h" #include "paddle/framework/tensor_py.h" #include "paddle/operators/net_op.h" +#include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" #include "paddle/string/to_string.h" @@ -30,8 +31,8 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_CPU_OP(onehot_cross_entropy); -USE_NO_GRAD_OP(sgd); +USE_OP(onehot_cross_entropy); +USE_OP(sgd); USE_OP(mul); USE_OP(mean); USE_OP(sigmoid); @@ -47,41 +48,6 @@ namespace framework { using Tensor = framework::Tensor; -template -void ExposeOperator(ClassType &m) { - m.def("infer_shape", &ClassType::type::InferShape) - .def("run", &ClassType::type::Run) - .def("type", - [](const typename ClassType::type &op) -> std::string { - return op.type_; - }) - .def("outputs", - [](const typename ClassType::type &op) -> std::vector { - return op.outputs_; - }) - .def("inputs", - [](const typename ClassType::type &op) -> std::vector { - return op.inputs_; - }) - .def("support_gpu", &ClassType::type::SupportGPU) - .def("temp_outputs", - [](const typename ClassType::type &op) -> std::vector { - auto iter = op.attrs_.find("temporary_index"); - std::vector ret; - if (iter == op.attrs_.end()) { - return ret; - } else { - auto tmp_idx = boost::get>(iter->second); - ret.reserve(tmp_idx.size()); - for (auto &index : tmp_idx) { - ret.push_back(op.outputs_.at(index)); - } - return ret; - } - }) - .def("__str__", &ClassType::type::DebugString); -} - static size_t UniqueIntegerGenerator() { static std::atomic generator; return generator.fetch_add(1); @@ -172,13 +138,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &protos = OpRegistry::protos(); + auto &op_info_map = OpRegistry::op_info_map(); std::vector ret_values; - for (auto it = protos.begin(); it != protos.end(); ++it) { - PADDLE_ENFORCE(it->second.IsInitialized(), - "OpProto must all be initialized"); + for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) { + const OpProto *proto = it->second.proto_; + if (proto == nullptr) { + continue; + } + PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized"); std::string str; - PADDLE_ENFORCE(it->second.SerializeToString(&str), + PADDLE_ENFORCE(proto->SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); ret_values.push_back(py::bytes(str)); } @@ -215,47 +184,69 @@ All parameter, weight, gradient are variables in Paddle. .def(py::init<>()) .def("__str__", string::to_string); - py::class_> operator_base( - m, "Operator"); - - operator_base.def_static("create", [](py::bytes protobin) { - OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return OpRegistry::CreateOp(desc); - }); - - operator_base.def("backward", - [](const OperatorBase &forwardOp, - const std::unordered_set &no_grad_vars) { - return Backward(forwardOp, no_grad_vars); - }); - - ExposeOperator(operator_base); - - py::class_> net(m, "Net"); - - net.def_static("create", - []() -> std::shared_ptr { - auto retv = std::make_shared(); - retv->type_ = "plain_net"; - return retv; - }) - .def("add_op", &operators::NetOp::AddOp) - .def("add_op", - [](operators::NetOp &self, - const std::shared_ptr &net) -> void { - self.AddOp(std::static_pointer_cast(net)); + py::class_(m, "Operator") + .def_static("create", + [](py::bytes protobin) { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return OpRegistry::CreateOp(desc); + }) + .def("backward", + [](const OperatorBase &forwardOp, + const std::unordered_set &no_grad_vars) { + return Backward(forwardOp, no_grad_vars).release(); }) + .def("infer_shape", &OperatorBase::InferShape) + .def("run", &OperatorBase::Run) + .def("type", + [](const OperatorBase &op) -> std::string { return op.Type(); }) + .def("outputs", + [](const OperatorBase &op) + -> std::map> { + return op.Outputs(); + }) + .def("inputs", [](const OperatorBase &op) { return op.Inputs(); }) + .def("__str__", &OperatorBase::DebugString) + .def("no_intermediate_outputs", + [](const OperatorBase &op) { return op.OutputVars(false); }) + .def("support_gpu", &OperatorBase::SupportGPU); + + py::class_(m, "Net") + .def_static("create", + []() -> operators::NetOp * { + auto *retv = new operators::NetOp; + retv->SetType("plain_net"); + return retv; + }) + .def("append_op", [](operators::NetOp &self, + const OperatorBase &op) { self.AppendOp(op); }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); }); - ExposeOperator(net); + // recurrent_op + py::class_(m, "RecurrentOp") + .def_static( + "create", + [](py::bytes protobin) -> operators::RecurrentOp * { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + auto rnn_op = OpRegistry::CreateOp(desc); + return static_cast(rnn_op.release()); + }) + .def("set_stepnet", [](operators::RecurrentOp &self, + const operators::NetOp &net) -> void { + self.set_stepnet(net.Clone()); + }); m.def("unique_integer", UniqueIntegerGenerator); diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index cd1b4de426a49fa66dbbf8cf7d09990ac8d21227..b8c779f4e5fc7bc51298cdd35b26c2c8ac98edf6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -105,6 +105,8 @@ class Tensor { template inline Tensor Slice(const int& begin_idx, const int& end_idx) const; + platform::Place place() const { return holder_->place(); } + private: template inline void check_memory_size() const; diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index 7dfb6f61c50959f7269725a00dbc4f9c27474bdf..c572a9d433bc16e6733b8fc9367970bef28e699a 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -4,6 +4,10 @@ file(GLOB cpp_files . *Op.cpp) list(APPEND h_files Function.h) list(APPEND cpp_files Function.cpp) list(APPEND cpp_files BufferArg.cpp) +list(APPEND cpp_files GemmFunctor.cpp) +if(USE_EIGEN_FOR_BLAS) + list(APPEND cpp_files EigenGemm.cpp) +endif(USE_EIGEN_FOR_BLAS) if(WITH_GPU) file(GLOB cu_files . *OpGpu.cu) diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp index 490e8d546cbd460217abe95f6291b13fa207faa9..2f3112fe657cd381891dc53c7179e7520911e8c9 100644 --- a/paddle/function/DepthwiseConvOp.cpp +++ b/paddle/function/DepthwiseConvOp.cpp @@ -14,7 +14,6 @@ limitations under the License. */ #include "DepthwiseConvOp.h" #include "ConvOp.h" -#include "GemmFunctor.h" namespace paddle { diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu index 33463805cbd4746c05548028e0bc4a0e2a90453e..2d722dfcfca0f328edeecf185ea37b8512b91907 100644 --- a/paddle/function/DepthwiseConvOpGpu.cu +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "DepthwiseConvOp.h" -#include "GemmFunctor.h" #include "paddle/math/BaseMatrix.h" namespace paddle { diff --git a/paddle/function/EigenGemm.cpp b/paddle/function/EigenGemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..674141ed39b7f5573948348e3ba3bb526ae43c66 --- /dev/null +++ b/paddle/function/EigenGemm.cpp @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { + +template +struct EigenBlasGemm { + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + Eigen::array sizeA; + if (transA) { + sizeA[0] = K; + sizeA[1] = M; + CHECK_EQ(M, lda); + } else { + sizeA[0] = M; + sizeA[1] = K; + CHECK_EQ(K, lda); + } + Eigen::array sizeB; + if (transB) { + sizeB[0] = N; + sizeB[1] = K; + CHECK_EQ(K, ldb); + } else { + sizeB[0] = K; + sizeB[1] = N; + CHECK_EQ(N, ldb); + } + Eigen::array sizeC; + sizeC[0] = M; + sizeC[1] = N; + CHECK_EQ(N, ldc); + + const Matrix a(const_cast(A), sizeA); + const Matrix b(const_cast(B), sizeB); + Matrix c(C, sizeC); + + typedef typename Eigen::Tensor::DimensionPair DimPair; + Eigen::array dims; + dims[0] = DimPair(1, 0); + dims[0].first = transA ? 0 : 1; + dims[0].second = transB ? 1 : 0; + + Eigen::DefaultDevice device; + if (alpha == T(1) && beta == T(0)) { + c.device(device) = a.contract(b, dims); + } else if (alpha == T(1) && beta == T(1)) { + c.device(device) += a.contract(b, dims); + } else { + c.device(device) = alpha * a.contract(b, dims) + beta * c; + } + } +}; + +#ifdef PADDLE_TYPE_DOUBLE +template class EigenBlasGemm; +#else +template class EigenBlasGemm; +#endif + +} // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87..f8cf4ebea8d724f0291b981647622b63e3d84495 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -85,7 +85,6 @@ public: } Im2ColFunctor im2col; - GemmFunctor gemm; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -108,19 +107,19 @@ public: int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - K, - colData, - N, - beta, - outputData + g * outputOffset, - N); + BlasGemm::compute(false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + colData, + N, + beta, + outputData + g * outputOffset, + N); } inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; @@ -188,8 +187,6 @@ public: } Col2ImFunctor col2im; - GemmFunctor gemm; - size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -205,19 +202,19 @@ public: colData = inputGrad + g * inputOffset; scale = 1.0f; } - gemm(CblasTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - M, - outputGrad + g * outputOffset, - N, - scale, - colData, - N); + BlasGemm::compute(true, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset, + M, + outputGrad + g * outputOffset, + N, + scale, + colData, + N); if (needIm2col) { col2im(inputGrad + g * inputOffset, imShape, @@ -299,7 +296,6 @@ public: } Im2ColFunctor im2col; - GemmFunctor gemm; size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -321,19 +317,19 @@ public: int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasTrans, - M, - N, - K, - 1.0f, - outputGrad + g * outputOffset, - K, - colData, - K, - i == 0 ? beta : 1.0f, - filterGrad + g * filterOffset, - N); + BlasGemm::compute(false, + true, + M, + N, + K, + 1.0f, + outputGrad + g * outputOffset, + K, + colData, + K, + i == 0 ? beta : 1.0f, + filterGrad + g * filterOffset, + N); } inputData += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; diff --git a/paddle/function/GemmFunctor.cpp b/paddle/function/GemmFunctor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9e25ee58a12490a1454436b3fe4a89176478d5c0 --- /dev/null +++ b/paddle/function/GemmFunctor.cpp @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "GemmFunctor.h" +#include "paddle/math/MathFunctions.h" + +namespace paddle { + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { +#ifdef PADDLE_USE_EIGEN_FOR_BLAS + EigenBlasGemm::compute( + transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +#else + gemm(transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +#endif + } +}; + +template +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + hl_matrix_mul((T*)A, + transA == false ? HPPL_OP_N : HPPL_OP_T, + (T*)B, + transB == false ? HPPL_OP_N : HPPL_OP_T, + C, + M, + N, + K, + alpha, + beta, + lda, + ldb, + ldc); + } +}; + +template struct BlasGemm; +template struct BlasGemm; + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h index d5db5cf5e7a855d89b262fe8cf42aa2c55f419f1..0809953b4eb17c25eadcce7f474a3dab0469bba1 100644 --- a/paddle/function/GemmFunctor.h +++ b/paddle/function/GemmFunctor.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/math/MathFunctions.h" +#include "TensorType.h" namespace paddle { @@ -24,73 +24,42 @@ namespace paddle { // of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul // interface. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc); +struct BlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; +// TODO(hedaoyuan): Since the definition of the real type in the Paddle +// conflicts with the Eigen library, so compile the Eigen code can not +// include the Paddle header file. And need an EigenBlasGemm template class +// that does not contain the DeviceType parameter. +// I will fix this problem and merge BlasGemm and EigenBlasGemm into one. template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - gemm(transA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); - } -}; - -template -class GemmFunctor { -public: - void operator()(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE TransB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc) { - hl_matrix_mul((T*)A, - transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - (T*)B, - TransB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T, - C, - M, - N, - K, - alpha, - beta, - lda, - ldb, - ldc); - } +struct EigenBlasGemm { + static void compute(const bool transA, + const bool transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); }; } // namespace paddle diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index cfa80a89365af5111746eec9599d16e37532a9f7..26cff3e67710b2f38d93572c5d58849aa94a5135 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -202,7 +202,7 @@ void NeuralNetwork::prefetch(const std::vector& inArgs) { auto mat = dynamic_cast( para->getMat(PARAMETER_VALUE).get()); para->clearGradient(); - mat->clearIndices(); + if (mat) mat->clearIndices(); } } } diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index f98bf95064fa539b990309dfe0bff10c1e99d096..157b1ab45163a94a81d859dbcb7a52ae8edae439 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -184,7 +184,7 @@ public: } void backward(const UpdateCallback& callback) override { - if (biases_) { + if (biases_ && biases_->getWGrad()) { backwardActivation(); biases_->getWGrad()->collectBias(*getOutputGrad(), 1); biases_->getParameterPtr()->incUpdate(callback); diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index 30f567eaf8248a8fba1b461a2bdbf2aab13f9e08..d201fac65e0459050304195140e1aae081468f43 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -57,11 +57,14 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap, } void MKLDNNFcLayer::convertWeightsFromPaddle() { - if (FLAGS_use_mkldnn_wgt) { + if (hasInitedWgt_) { return; } - if (hasInitedWgt_) { + // TODO(TJ): dst format should get from wgtVal_ + int dstFmt = PARAM_FORMAT_MKLDNN_OI; + int srcFmt = weight_->getParameterPtr()->getHeaderFormat(); + if (srcFmt == dstFmt) { return; } @@ -78,6 +81,7 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { MatrixPtr paddleWgtT; paddleWgt->transpose(paddleWgtT, true); weight_->getW()->copyFrom(*paddleWgtT); + weight_->getParameterPtr()->setHeaderFormat(dstFmt); hasInitedWgt_ = true; } diff --git a/paddle/gserver/layers/ScaleShiftLayer.cpp b/paddle/gserver/layers/ScaleShiftLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35fd038ab43a8a8b08bc328b3d1b08a7bbedd0a1 --- /dev/null +++ b/paddle/gserver/layers/ScaleShiftLayer.cpp @@ -0,0 +1,107 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer applies a linear transformation to each element in each row of + * the input matrix. For each element, the layer first re-scale it and then + * adds a bias to it. + * + * \f[ + * y = wx + b + * \f] + * + * Here, w is the scale and b is the bias. Both w and b are trainable scalars. + * + */ + +class ScaleShiftLayer : public Layer { +protected: + std::unique_ptr scale_; + std::unique_ptr offset_; + +public: + explicit ScaleShiftLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(scale_shift, ScaleShiftLayer); + +bool ScaleShiftLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + CHECK_EQ(inputLayers_.size(), 1U); + scale_.reset(new Weight(1, 1, parameters_[0])); + if (biasParameter_.get() != NULL) { + offset_ = std::unique_ptr(new Weight(1, 1, biasParameter_)); + } + return true; +} + +void ScaleShiftLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + real scaleValue = scale_->getW()->getElement(0, 0); + outV->mulScalar(*inV, scaleValue); + if (offset_) { + real offsetValue = offset_->getW()->getElement(0, 0); + outV->add(offsetValue); + } +} + +void ScaleShiftLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + + /* Calculate the parameter gradient for the current layer */ + if (scale_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ij} * c_{ij} + rowSumMtx->sumOfProducts( + /* b= */ *inV, /* c= */ *outG, /* scaleSum= */ 1, /* scaleDest= */ 0.); + // this_i = scaleDest * this_i + scaleSum * \sum_j b_{ji} + scale_->getWGrad()->sumCols( + /* b= */ *rowSumMtx, /* scaleSum= */ 1., /* scaleDest= */ 1.); + scale_->getParameterPtr()->incUpdate(callback); + } + if (offset_ && offset_->getWGrad()) { + MatrixPtr rowSumMtx; + Matrix::resizeOrCreate(rowSumMtx, outG->getHeight(), 1, false, useGpu_); + rowSumMtx->sumRows(*outG, 1., 0.); + offset_->getWGrad()->sumCols(*rowSumMtx, 1., 1.); + offset_->getParameterPtr()->incUpdate(callback); + } + + /* Calculate the input layers error */ + if (inG) { + real scaleValue = scale_->getW()->getElement(0, 0); + inG->add(*outG, scaleValue); + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp index 99c8c4948c9b05ad15d1217ebb70026bbd48453f..de1635be2af37cd0ba49010199a417090865b0e4 100644 --- a/paddle/gserver/tests/MKLDNNTester.cpp +++ b/paddle/gserver/tests/MKLDNNTester.cpp @@ -330,9 +330,7 @@ void MKLDNNTester::run(const TestConfig& dnn, log_ = log; lvl_ = level; - // Firstly test FLAGS_use_mkldnn_wgt = false - FLAGS_use_mkldnn_wgt = false; - // reset and run once + // Firstly test mkldnn init from PARAM_FORMAT_ORIGINAL weight reset(dnn, ref, batchSize); randomWgtDatas(); clearWgtDiffs(); @@ -342,17 +340,32 @@ void MKLDNNTester::run(const TestConfig& dnn, runOnce(); } - // Then test FLAGS_use_mkldnn_wgt = true - FLAGS_use_mkldnn_wgt = true; - // after run once the mkldnn weight has been stored in dnnlayer + if (parameters_[DNN].empty()) { + // has no paramters + return; + } + + // After run some iterations, the mkldnn weight has been stored in dnnLayer + // and we can also get the mkldnn weight parameter header format. + // Weight parameter should always be index 0 (and bias index 1). + // TODO(TJ): should also consider mean and var format when batchnorm ready + int dnnWgtFmt = parameters_[DNN][0]->getHeaderFormat(); + int refWgtFmt = parameters_[REF][0]->getHeaderFormat(); + if (dnnWgtFmt == refWgtFmt) { + // weight format are equal, so no need check more + return; + } + // then save the weights and restart again vector dnnWgts, refWgts; CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size()); saveWgt(parameters_[DNN], dnnWgts); saveWgt(parameters_[REF], refWgts); - // restart again with flag true + // restart again with dnn weight format reset(dnn, ref, batchSize); + // TODO(TJ): should also considerate mean and var format when batchnorm ready + parameters_[DNN][0]->setHeaderFormat(dnnWgtFmt); // restore wgt restoreWgt(dnnWgts, parameters_[DNN]); diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h index 522eeaf24b1949abac057a1e59e9977610be23c0..e55e4493ffdfe45b8cfdee423febd1878b8b3d8a 100644 --- a/paddle/gserver/tests/MKLDNNTester.h +++ b/paddle/gserver/tests/MKLDNNTester.h @@ -108,7 +108,7 @@ private: * if many(>failRate) wrong(abs(dnn-ref)/abs(ref)>thres) points return the * max(diff/ref) * else return sum(abs(a-b)) / sum(abs(b)) - * The return value should smaller than eps when passing. + * The return value should be smaller than eps when passing. */ double getDelta(const real* d1, const real* d2, diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 0f312b6ca50bc1e6317251ba785f1c61a224b54e..dd2c955e6a4660a1811f205ec5c5861798291912 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2007,6 +2007,21 @@ TEST(Layer, RowL2NormLayer) { } } +TEST(Layer, ScaleShiftLayer) { + const size_t batchSize = 16; + const size_t size = 32; + TestConfig config; + config.layerConfig.set_type("scale_shift"); + config.layerConfig.set_size(size); + config.biasSize = 1; + config.inputDefs.push_back( + {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index f930c72fde3f5e0a6a45cb6bfd3507a4f48028fc..d36f72360f8ebd2033fb3e8c0e1b30911abba362 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -269,7 +269,8 @@ TEST(Compare, img_conv2) { bool useGpu = FLAGS_use_gpu; double eps = FLAGS_checkgrad_eps; FLAGS_use_gpu = true; - FLAGS_checkgrad_eps = 1e-2; + // Sometimes, this unit test will fail with 1e-2 + FLAGS_checkgrad_eps = 4e-2; compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; FLAGS_checkgrad_eps = eps; diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt index 8035d93bfec75b20a54c5af0521ab724cafba8ca..9cc4233e43267472d405c3e4e617f0782e1430ea 100644 --- a/paddle/memory/CMakeLists.txt +++ b/paddle/memory/CMakeLists.txt @@ -1,7 +1,7 @@ add_subdirectory(detail) cc_library(memory SRCS memory.cc) -cc_library(memcpy SRCS memcpy.cc DEPS device_context) +cc_library(memcpy SRCS memcpy.cc) cc_library(paddle_memory DEPS diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index f61e67a32906083881dd7f47433521876be9b355..a270bd59581520859d43cddd2fc0cfa72080f46d 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -27,7 +27,7 @@ limitations under the License. */ // between host and device. Allocates too much would reduce the amount // of memory available to the system for paging. So, by default, we // should set false to use_pinned_memory. -DEFINE_bool(use_pinned_memory, false, "If set, allocate cpu pinned memory."); +DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory."); namespace paddle { namespace memory { diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index aaab1142ca18d3319469a4d685fde9d30929113f..a19a3e3675e3e2e7cc0c3594f21191f932d6379f 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -16,8 +16,6 @@ limitations under the License. */ #include // for memcpy -#include "paddle/platform/device_context.h" - namespace paddle { namespace memory { diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 207025f9b1c64f0f8943f9fae5edefc9328a1d26..29bc26f9d3bca0e30896657431f9a9bb1dac0d1d 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -13,22 +13,38 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/memory/memory.h" + +#include // for transform +#include // for memcpy +#include // for unique_ptr +#include // for call_once + +#include "glog/logging.h" + #include "paddle/memory/detail/buddy_allocator.h" #include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/gpu_info.h" -#include // for memcpy +DECLARE_double(fraction_of_gpu_memory_to_use); namespace paddle { namespace memory { -detail::BuddyAllocator* GetCPUBuddyAllocator() { - static detail::BuddyAllocator* a = nullptr; - if (a == nullptr) { - a = new detail::BuddyAllocator(new detail::CPUAllocator, - platform::CpuMinChunkSize(), - platform::CpuMaxChunkSize()); - } - return a; +using BuddyAllocator = detail::BuddyAllocator; + +std::once_flag cpu_allocator_flag; +std::once_flag gpu_allocator_flag; + +BuddyAllocator* GetCPUBuddyAllocator() { + static std::unique_ptr a{nullptr}; + + std::call_once(cpu_allocator_flag, [&]() { + a.reset(new BuddyAllocator(new detail::CPUAllocator, + platform::CpuMinChunkSize(), + platform::CpuMaxChunkSize())); + }); + + return a.get(); } template <> @@ -48,20 +64,36 @@ size_t Used(platform::CPUPlace place) { #ifndef PADDLE_ONLY_CPU -detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { - static detail::BuddyAllocator** as = NULL; - if (as == NULL) { +BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { + using BuddyAllocVec = std::vector; + static std::unique_ptr as{ + new BuddyAllocVec, [](BuddyAllocVec* p) { + std::for_each(p->begin(), p->end(), + [](BuddyAllocator* p) { delete p; }); + }}; + + // GPU buddy allocators + auto& allocators = *as.get(); + + // GPU buddy allocator initialization + std::call_once(gpu_allocator_flag, [&]() { int gpu_num = platform::GetDeviceCount(); - as = new detail::BuddyAllocator*[gpu_num]; + allocators.reserve(gpu_num); for (int gpu = 0; gpu < gpu_num; gpu++) { platform::SetDeviceId(gpu); - as[gpu] = new detail::BuddyAllocator(new detail::GPUAllocator, - platform::GpuMinChunkSize(), - platform::GpuMaxChunkSize()); + allocators.emplace_back(new BuddyAllocator(new detail::GPUAllocator, + platform::GpuMinChunkSize(), + platform::GpuMaxChunkSize())); } - } + VLOG(3) << "\n\nNOTE: each GPU device use " + << FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n" + << "You can set environment variable '" + << platform::kEnvFractionGpuMemoryToUse + << "' to change the fraction of GPU usage.\n\n"; + }); + platform::SetDeviceId(gpu_id); - return as[gpu_id]; + return allocators[gpu_id]; } template <> diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 72351b9dfa63513713463bb47a3684f0dfd84ad3..11bbb881874ec50e1132547336fc6fb6b42bcc4f 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include "paddle/platform/gpu_info.h" #include "paddle/platform/place.h" namespace paddle { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index c181bd7b881c08dfd80d640b1ddce10b3c74d758..a7c89787e43df6173791bc54b3faffc034867f7d 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,8 +41,11 @@ function(op_library TARGET) endif() endfunction() +add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) +cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) + cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) @@ -50,7 +53,7 @@ op_library(add_op SRCS add_op.cc add_op.cu) op_library(mean_op SRCS mean_op.cc mean_op.cu) -op_library(mul_op SRCS mul_op.cc mul_op.cu) +op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) @@ -62,7 +65,6 @@ op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc - DEPS op_desc tensor op_registry operator net_op) -cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) + DEPS framework_proto tensor op_registry operator net_op) op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index b886ded9bbd97dc1942c87d7603521e8d72e3f6c..8ab748ed71e9a5dc0ee0259a78a2b886870bec5b 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -18,17 +18,15 @@ namespace paddle { namespace operators { class AddOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "Inputs of AddOp must all be set"); - PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, - "Outputs of AddOp must all be set"); - PADDLE_ENFORCE(ctx.Input(0)->dims() == ctx.Input(1)->dims(), - "Two input of Add Op's dimension must be same."); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + PADDLE_ENFORCE_EQ(ctx.Input("X")->dims(), + ctx.Input("Y")->dims(), + "Two input of Add Op's dimension must be same."); + ctx.Output("Out")->Resize(ctx.Input("X")->dims()); } }; @@ -48,7 +46,9 @@ The equation is: Out = X + Y }; class AddOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override {} }; @@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker); -REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad); +REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad); REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel); diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index d76c10957e943deb970b1d79a1507a36669314e3..a7307b6818aa3d10ff215d06281e2b53196fd101 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,9 +28,9 @@ template class AddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto input0 = context.Input(0); - auto input1 = context.Input(1); - auto output = context.Output(0); + auto* input0 = context.Input("X"); + auto* input1 = context.Input("Y"); + auto* output = context.Output("Out"); output->mutable_data(context.GetPlace()); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 09aa589d3caf7ed7b790850b515d49afdd3e1467..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -18,36 +18,31 @@ namespace paddle { namespace operators { class OnehotCrossEntropyOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, - "Input size of OnehotCrossEntropyOp must be two"); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, - "Output size of OnehotCrossEntropyOp must be one"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), - "0-th input of OnehotCrossEntropyOp should be set"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1), - "1-th input of OnehotCrossEntropyOp should be set"); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), - "Outputs of OnehotCrossEntropyOp must all be set"); - PADDLE_ENFORCE_EQ(ctx.Input(0)->dims().size(), 2); - PADDLE_ENFORCE_EQ(ctx.Output(0)->dims().size(), 1, - "label's dimension must be 1."); - ctx.Output(0)->Resize({ctx.Input(0)->dims()[0]}); + auto *X = ctx.Input("X"); + auto *label = ctx.Input("label"); + + PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2."); + PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1."); + PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]); + ctx.Output("Y")->Resize({X->dims()[0]}); } }; class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp, - framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto X_grad = ctx.Output(framework::GradVarName("X")); + auto dX = ctx.Output(framework::GradVarName("X")); auto X = ctx.Input("X"); - // TODO(superjom) add enforce here after helper functions ready - X_grad->Resize(X->dims()); + dX->Resize(X->dims()); } }; @@ -72,12 +67,9 @@ OnehotCrossEntropy Operator. namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); -REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOp); -REGISTER_OP_CPU_KERNEL( - onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOpKernel); + ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOp); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 4bbc8f093a794d46737a16488684a6a0cc25e285..d999bfce58c8a6db5c811aad677c07094b881841 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -12,10 +12,122 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU -#include "paddle/operators/cross_entropy_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/assert.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__host__ __device__ T clipping_log(const T x) { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + T v = log(x); + if (v == INFINITY) { + return kApproInf; + } + if (v == -INFINITY) { + return -kApproInf; + } + return v; +} + +template +__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, + const int N, const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + PADDLE_ASSERT(label[i] >= 0 && label[i] < D); + Y[i] = -clipping_log(X[i * D + label[i]]); + } +} + +// TODO(qingqing): make zero setting an common function. +template +__global__ void zero(T* X, const int N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + X[i] = 0.0; + } +} + +template +__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, + const int* label, const int N, + const int D) { + // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. + // CUDA_1D_KERNEL_LOOP(i, N) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + int idx = i * D + label[i]; + dX[idx] = -dY[i] / X[idx]; + } +} + +template +class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + const T* Xdata = X->data(); + const int* label_data = ctx.Input("label")->data(); + auto Y = ctx.Output("Y"); + Y->mutable_data(ctx.GetPlace()); + T* Ydata = Y->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N + block - 1) / block; + // TODO(qingqing) launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyKernel<<>>(Ydata, Xdata, label_data, N, D); + } +}; + +template +class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + auto X = ctx.Input("X"); + auto dX = ctx.Output(framework::GradVarName("X")); + auto dY = ctx.Input(framework::GradVarName("Y")); + auto label = ctx.Input("label"); + + auto* dXdata = dX->template mutable_data(ctx.GetPlace()); + auto* dYdata = dY->template data(); + auto* Xdata = X->template data(); + auto* label_data = label->data(); + + int N = X->dims()[0]; + int D = X->dims()[1]; + int block = 512; + int grid = (N * D + block - 1) / block; + zero<<>>(dXdata, N * D); + + grid = (N + block - 1) / block; + // TODO(qingqing): launch kernel on specified stream + // base on ExecutionContext. + CrossEntropyGradientKernel<<>>(dXdata, dYdata, Xdata, + label_data, N, D); + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL( - onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, + ops::OnehotCrossEntropyOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOpCUDAKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index d1bbc2cb66d6ce84ddcdcb87648f23c6ce77b748..eb4d1348de1d940e2648c83c8ba94b289f10c5b2 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -21,7 +21,7 @@ namespace operators { using Tensor = framework::Tensor; template -T tolerable_value(T x) { +inline T tolerable_value(const T x) { static_assert(std::is_floating_point::value, "tolerable_value works only on float, " "double and double double."); @@ -39,13 +39,16 @@ T tolerable_value(T x) { return x; } -template +template class OnehotCrossEntropyOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); const T* Xdata = X->data(); - const int* label_data = ctx.Input(1)->data(); + const int* label_data = ctx.Input("label")->data(); auto Y = ctx.Output("Y"); Y->mutable_data(ctx.GetPlace()); @@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel { } }; -template +template class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto X = ctx.Input("X"); auto dX = ctx.Output(framework::GradVarName("X")); auto dY = ctx.Input(framework::GradVarName("Y")); @@ -79,6 +85,8 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel { const int batch_size = X->dims()[0]; const int class_num = X->dims()[1]; + // TODO(qingqing): make zero setting an common function. + memset(dXdata, 0, sizeof(T) * batch_size * class_num); for (int i = 0; i < batch_size; ++i) { int index = i * class_num + label_data[i]; dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index eda23a0ccfacd3a620412876e18f4ec47652bf9d..9d51f6e3a16fe96125599bb440d40237aeb9a028 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -18,19 +18,13 @@ namespace paddle { namespace operators { class FillZerosLikeOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, - "Input size of FillZerosLikeOp must be one."); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL, - "Output size of AddOp must be one."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), - "Input of FillZerosLikeOp must be set."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), - "Output of FillZerosLikeOp must be set."); - ctx.Output(0)->Resize( - ctx.Input(0)->dims()); + ctx.Output("Dst")->Resize( + ctx.Input("Src")->dims()); } }; @@ -52,7 +46,8 @@ The output will have the same size with input. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp, + ops::FillZerosLikeOpMaker); REGISTER_OP_CPU_KERNEL( fill_zeros_like, ops::FillZerosLikeKernel); diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h index f846c7a8ab15e2cd997564edb36660a1360227a8..fd380ca8514b0ac50f39613368a4836bd485668b 100644 --- a/paddle/operators/fill_zeros_like_op.h +++ b/paddle/operators/fill_zeros_like_op.h @@ -23,7 +23,7 @@ template class FillZerosLikeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* output = context.Output(0); + auto* output = context.Output("Dst"); output->mutable_data(context.GetPlace()); auto t = framework::EigenVector::Flatten(*output); t.device(context.GetEigenDevice()) = t.constant(T(0)); diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index 0c73717d38aca9f3430e66cafc3ecccdd2eec776..d6e6990394e46ba06c4bacfe33ca522f3ff1413a 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size, const int index_size, T* output) { const size_t slice_bytes = slice_size * sizeof(T); - for (size_t i = 0; i < index_size; ++i) { + for (int i = 0; i < index_size; ++i) { int index_ = indices[i]; memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); } @@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, // slice size int slice_size = 1; - for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; // Gathering if (platform::is_cpu_place(place)) { diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 5de748ec461e4b1a34b75b57c9cd7d5bc9326059..0ae1e99452973feb6d085dd6ef51e2afca988f59 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -35,7 +35,7 @@ TEST(Gather, GatherData) { p_src = src->mutable_data(make_ddim({3, 4}), CPUPlace()); p_index = index->mutable_data(make_ddim({2}), CPUPlace()); - for (size_t i = 0; i < 12; ++i) p_src[i] = i; + for (int i = 0; i < 12; ++i) p_src[i] = i; p_index[0] = 1; p_index[1] = 0; @@ -43,6 +43,10 @@ TEST(Gather, GatherData) { Gather(CPUPlace(), src, index, output); - for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); - for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); + for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); + for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); + + delete src; + delete index; + delete output; } diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 893cf56e5cf0d99d3f3bfffe98734a868f9b7595..a85363ad81d2a23e7267026c067f74f8c94c4786 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,34 +16,36 @@ namespace paddle { namespace operators { template -class GaussianRandomKernel : public framework::OpKernel { +class CPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { float mean = context.op_.GetAttr("mean"); float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - // TODO(dzh): attribute does not support unsigned int. - // And we need a global random seed configuration. - int seed = context.op_.GetAttr("seed"); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); + std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } - std::mt19937 g(seed); - std::normal_distribution distribution(mean, std); + engine.seed(seed); + std::normal_distribution dist(mean, std); ssize_t size = framework::product(tensor->dims()); - for (int i = 0; i < size; ++i) { - data[i] = distribution(g); + for (ssize_t i = 0; i < size; ++i) { + data[i] = dist(engine); } } }; class GaussianRandomOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext& context) const override { - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); auto dims = GetAttr>("dims"); PADDLE_ENFORCE(dims.size() > 0UL, "dims can be one int or array. dims must be set."); @@ -66,8 +65,8 @@ Use to initialize tensor with gaussian random generator. )DOC"); AddAttr>("dims", "The dimension of random tensor."); - AddAttr("mean", "mean value of random.").SetDefault(.0f); - AddAttr("std", "minimum value of random value.").SetDefault(1.0f); + AddAttr("mean", "mean of random tensor.").SetDefault(.0f); + AddAttr("std", "std of random tensor.").SetDefault(1.0f); AddAttr("seed", "Random seed of generator." "0 means use system wide seed") @@ -79,5 +78,6 @@ Use to initialize tensor with gaussian random generator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); -REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, + ops::GaussianRandomOpMaker); +REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel); diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 1340b1e1e9f19fd96ced9e57fab75fe9d33bc84e..018a4bfcb26b9008c054000c91edf01e371fd82b 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -1,53 +1,65 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/platform/dynload/curand.h" -#include "paddle/platform/gpu_info.h" - +#include +#include +#include +#include #include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { namespace operators { template -class GaussianRandomKernel : public framework::OpKernel { +struct GaussianGenerator { + T mean_, std_; + unsigned int seed_; + + __host__ __device__ GaussianGenerator(T mean, T std, int seed) + : mean_(mean), std_(std), seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::normal_distribution dist(mean_, std_); + rng.discard(n); + return dist(rng); + } +}; + +template +class GPUGaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - float mean = context.op_.GetAttr("mean"); - float std = context.op_.GetAttr("std"); - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - - int seed = context.op_.GetAttr("seed"); + unsigned int seed = + static_cast(context.op_.GetAttr("seed")); if (seed == 0) { std::random_device rd; seed = rd(); } - curandGenerator_t g; - PADDLE_ENFORCE(platform::dynload::curandCreateGenerator( - &g, CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - platform::dynload::curandGenerateNormal( - g, data, framework::product(tensor->dims()), mean, std); + T mean = static_cast(context.op_.GetAttr("mean")); + T std = static_cast(context.op_.GetAttr("std")); + thrust::counting_iterator index_sequence_begin(0); + ssize_t N = framework::product(tensor->dims()); + thrust::transform(index_sequence_begin, index_sequence_begin + N, + thrust::device_ptr(data), + GaussianGenerator(mean, std, seed)); } }; } // namespace operators } // namespace paddle -namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); +REGISTER_OP_GPU_KERNEL(gaussian_random, + paddle::operators::GPUGaussianRandomKernel); diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed51d416ed9497eee45ba826ad672b8fb1ad3678 --- /dev/null +++ b/paddle/operators/math/CMakeLists.txt @@ -0,0 +1,8 @@ + +if(WITH_GPU) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) +else() + cc_library(math_function SRCS math_function.cc DEPS cblas device_context) +endif() + +nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e86fc3d166077265e0f433a6712b0665ea5a152 --- /dev/null +++ b/paddle/operators/math/math_function.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, + float beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), + "Matrix must all be in CPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), + "Matrix must all be in CPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu new file mode 100644 index 0000000000000000000000000000000000000000..da40b27c948918e4997f4a046d2145552296158b --- /dev/null +++ b/paddle/operators/math/math_function.cu @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE(platform::dynload::cublasSgemm( + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasDgemm( + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, + float beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in GPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in GPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h new file mode 100644 index 0000000000000000000000000000000000000000..155589fadb3ed9f59160a750d546dd8093a56cbe --- /dev/null +++ b/paddle/operators/math/math_function.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace math { + +// Support continuous memory now +// If transA = N, and transB = N +// Then matrixA: M * K, matrixB: K * N matrixC : M * N +// For more detailed info, please refer to +// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html +template +void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, + const int M, const int N, const int K, const T alpha, const T* A, + const T* B, const T beta, T* C, platform::DeviceContext* context); + +// matrix multiply with continuous memory +template +void matmul(const framework::Tensor& matrix_a, bool trans_a, + const framework::Tensor& matrix_b, bool trans_b, T alpha, + framework::Tensor* matrix_out, T beta, + platform::DeviceContext* context); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c020c4ff7285b43bc5836d80c173d3a068e72b3 --- /dev/null +++ b/paddle/operators/math/math_function_test.cc @@ -0,0 +1,75 @@ +#include "paddle/operators/math/math_function.h" +#include "gtest/gtest.h" + +#ifndef PADDLE_ONLY_CPU +TEST(math_function, notrans_mul_trans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({2, 2}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 5); + EXPECT_EQ(out_ptr[1], 14); + EXPECT_EQ(out_ptr[2], 14); + EXPECT_EQ(out_ptr[3], 50); +} + +TEST(math_function, trans_mul_notrans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({3, 3}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 9); + EXPECT_EQ(out_ptr[1], 12); + EXPECT_EQ(out_ptr[2], 15); + EXPECT_EQ(out_ptr[3], 12); + EXPECT_EQ(out_ptr[4], 17); + EXPECT_EQ(out_ptr[5], 22); + EXPECT_EQ(out_ptr[6], 15); + EXPECT_EQ(out_ptr[7], 22); + EXPECT_EQ(out_ptr[8], 29); +} +#endif diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index f6abba7ab45728f74dcea1363035a729b2cd1d03..d3d0e55a674587fb04f43f24d0790de4358f035a 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -18,14 +18,14 @@ namespace paddle { namespace operators { class MeanOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of AddOp must be one"); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of AddOp must be one"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "input should be set"); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), "output should be set"); - ctx.Output(0)->Resize(framework::make_ddim({1})); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input of MeanOp must be initialized."); + ctx.Output("Out")->Resize({1}); } }; @@ -34,13 +34,15 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").IgnoreGradient(); + AddOutput("Out", "The output of mean op").NotInGradient(); AddComment("Mean Operator"); } }; class MeanGradOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { ctx.Output(framework::GradVarName("X")) @@ -52,9 +54,8 @@ class MeanGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); -REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index e8595a14faa7c1b03734f814c78f9cbf1819fbb5..9848af280b62729bef9243052ceae0b7d8f4c6f5 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -31,14 +31,14 @@ template class MeanKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto input = context.Input(0); - auto output = context.Output(0); + auto* input = context.Input("X"); + auto* output = context.Output("Out"); output->mutable_data(context.GetPlace()); auto X = EigenVector::Flatten(*input); auto y = EigenScalar::From(*output); - auto place = context.GetEigenDevice(); + auto& place = context.GetEigenDevice(); y.device(place) = X.mean(); } @@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel { IG->mutable_data(context.GetPlace()); T ig_size = (T)framework::product(IG->dims()); + Eigen::DSizes bcast(ig_size); EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = - EigenScalar::From(*OG) / ig_size; + (EigenVector::From(*OG) / ig_size).broadcast(bcast); } }; diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 6115a3f3332dba419b56e74a737627483448a715..173cc3850ca9d97200e272ec59d1bd3fe09b5053 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -17,13 +17,16 @@ namespace paddle { namespace operators { +using framework::Tensor; + class MulOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); - auto dim0 = ctx.Input(0)->dims(); - auto dim1 = ctx.Input(1)->dims(); + auto dim0 = ctx.Input("X")->dims(); + auto dim1 = ctx.Input("Y")->dims(); PADDLE_ENFORCE_EQ(dim0.size(), 2, "input X(%s) should be a tensor with 2 dims, a matrix", ctx.op_.Input("X")); @@ -33,8 +36,7 @@ class MulOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( dim0[1], dim1[0], "First matrix's width must be equal with second matrix's height."); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "The mul op takes only one output"); - ctx.Output(0)->Resize({dim0[0], dim1[1]}); + ctx.Output("Out")->Resize({dim0[0], dim1[1]}); } }; @@ -54,12 +56,27 @@ The equation is: Out = X * Y }; class MulOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: - void InferShape(const framework::InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "MulGrad"; - return ""; + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + auto *y_grad = ctx.Output(framework::GradVarName("Y")); + PADDLE_ENFORCE(x_dims[0] == out_dims[0], + "Out@GRAD M X N must equal to X dims 0, M "); + PADDLE_ENFORCE(y_dims[1] == out_dims[1], + "Out@GRAD M X N must equal to Y dims 1, N "); + + x_grad->Resize(x_dims); + y_grad->Resize(y_dims); } }; @@ -67,7 +84,7 @@ class MulOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker); -REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad); - +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_CPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 43debbc21a365a15c914e60e151f7782b82080cb..a81444dbe63edeecedc5d822c65ff56c42b5db90 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -16,5 +16,6 @@ #include "paddle/operators/mul_op.h" namespace ops = paddle::operators; - REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_GPU_KERNEL(mul_grad, + ops::MulGradKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index ab12631c03453a18fbb067e2d12c2bc332acd567..8facc0281449785bf40726f23ca2fd5d166ff272 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -13,6 +13,9 @@ limitations under the License. */ #pragma once + +#include "paddle/operators/math/math_function.h" + #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -28,21 +31,34 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - - auto input0 = context.Input("X"); - auto input1 = context.Input("Y"); - auto output = context.Output(0); - - output->mutable_data(context.GetPlace()); + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + } +}; - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto Z = EigenMatrix::From(*output); - auto place = context.GetEigenDevice(); +template +class MulGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); - Z.device(place) = X.contract(Y, dim_pair); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dY = ctx.Output(framework::GradVarName("Y")); + dX->mutable_data(ctx.GetPlace()); + dY->mutable_data(ctx.GetPlace()); + auto* device_context = + const_cast(ctx.device_context_); + // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N + math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K + math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); } }; diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index a466c4f30fe87db4ad2a44518e083b57f3cbc2ed..a7d710511093dfbe13a13b1222b0230bba0398bd 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -15,48 +15,42 @@ */ #include "paddle/operators/net_op.h" +#include +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { +const char NetOp::kAll[] = "all"; + void NetOp::CompleteAddOp(bool calc) { add_op_done_ = true; if (!calc) return; - std::unordered_set input_set; - std::unordered_set output_set; - std::unordered_set temp_output; + std::set input_set; + std::set output_set; for (auto& op : ops_) { - for (auto& ipt : op->inputs_) { - if (!Contains(output_set, ipt)) { // Not other op's output - input_set.insert(ipt); - } else { - temp_output.insert(ipt); + for (auto& ipt : op->Inputs()) { + for (auto& var_name : ipt.second) { + if (!Contains(output_set, var_name)) { // Not other op's output + input_set.insert(var_name); + } else { + intermediate_outputs_.insert(var_name); + } } } - for (auto& opt : op->outputs_) { - output_set.insert(opt); - } - } - - inputs_.reserve(input_set.size()); - std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_)); - std::sort(inputs_.begin(), inputs_.end()); - - outputs_.reserve(output_set.size()); - std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs_)); - std::sort(outputs_.begin(), outputs_.end()); - - std::vector tmp_index; - tmp_index.reserve(temp_output.size()); - int output_len = static_cast(outputs_.size()); - for (int i = 0; i < output_len; ++i) { - if (Contains(temp_output, outputs_[i])) { - tmp_index.push_back(i); + for (auto& opt : op->Outputs()) { + for (auto& var_name : opt.second) { + output_set.insert(var_name); + } } } - - attrs_["temporary_index"] = tmp_index; + auto& inputs = inputs_[kAll]; + inputs.reserve(input_set.size()); + std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs)); + auto& outputs = outputs_[kAll]; + outputs.reserve(output_set.size()); + std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs)); } std::string NetOp::DebugString() const { @@ -73,5 +67,32 @@ std::string NetOp::DebugString() const { bool NetOp::IsNetOp() const { return true; } +std::vector NetOp::OutputVars(bool has_intermediate) const { + if (has_intermediate) { + return this->outputs_.at(kAll); + } + auto& all = this->outputs_.at(kAll); + std::vector ret_val; + for (auto& each : all) { + if (!Contains(intermediate_outputs_, each)) { + ret_val.push_back(each); + } + } + return ret_val; +} + +NetOp::NetOp(const std::string& type, + const framework::OperatorBase::VarNameMap& inputs, + const framework::OperatorBase::VarNameMap& outputs, + const framework::AttributeMap& attrs) + : framework::OperatorBase(type, inputs, outputs, attrs) {} + +std::unique_ptr NetOp::Clone() const { + PADDLE_ENFORCE( + add_op_done_, + "Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone"); + return std::unique_ptr(new NetOp(*this)); +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/net_op.h b/paddle/operators/net_op.h index 24c9e61c66933c6be5bf44b3537e00b70a33922f..3d3f996ef52b6c1136425ca9de0f60e7e155458f 100644 --- a/paddle/operators/net_op.h +++ b/paddle/operators/net_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/framework/framework.pb.h" #include "paddle/framework/op_registry.h" namespace paddle { @@ -35,7 +36,20 @@ namespace operators { */ class NetOp : public framework::OperatorBase { public: - DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase) + static const char kAll[]; + NetOp() : framework::OperatorBase("plain_net", {}, {}, {}) {} + NetOp(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const framework::AttributeMap& attrs); + + NetOp(const NetOp& o) : framework::OperatorBase(o.type_, {}, {}, o.attrs_) { + this->ops_.reserve(o.ops_.size()); + std::transform( + o.ops_.begin(), o.ops_.end(), std::back_inserter(this->ops_), + [](const std::unique_ptr& op) { + return std::unique_ptr(op->Clone()); + }); + this->CompleteAddOp(); + } /** * Infer all the operators' input and output variables' shapes, will be called @@ -70,21 +84,28 @@ class NetOp : public framework::OperatorBase { return true; } + void AppendOp(const framework::OperatorBase& op) { AppendOp(op.Clone()); } + /** * @brief Add an operator by ptr */ - void AddOp(const std::shared_ptr& op) { - PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + void AppendOp(std::unique_ptr op) { + PADDLE_ENFORCE(!add_op_done_, + "Cannot AppendOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); - ops_.push_back(op); + ops_.push_back(std::move(op)); } - void InsertOp(size_t pos, const std::shared_ptr& op) { + void InsertOp(size_t pos, std::unique_ptr op) { PADDLE_ENFORCE(!add_op_done_, "Cannot InsertOp when this network is sealed"); PADDLE_ENFORCE_NOT_NULL(op, "Cannot Insert Null op"); PADDLE_ENFORCE_LE(pos, ops_.size(), "Out of range"); - ops_.insert(ops_.begin() + pos, op); + ops_.insert(ops_.begin() + pos, std::move(op)); + } + + void InsertOp(size_t pos, const framework::OperatorBase& op) { + InsertOp(pos, op.Clone()); } void CompleteAddOp(bool calculate = true); @@ -92,11 +113,15 @@ class NetOp : public framework::OperatorBase { std::string DebugString() const override; bool IsNetOp() const override; + std::vector OutputVars(bool has_intermediate) const override; + + std::unique_ptr Clone() const override; - std::vector> ops_; + std::vector> ops_; private: bool add_op_done_{false}; + std::set intermediate_outputs_; template static bool Contains(T container, KeyType key) { diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 0d5c3de798d0b580860d24ea9a61a6a4ede5d0ab..99019754a965e5e7aeb74c6bfc10c9646289651b 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -12,8 +12,8 @@ static int run_cnt = 0; class TestOp : public framework::OperatorBase { public: - DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase) - + using framework::OperatorBase::OperatorBase; + DEFINE_OP_CLONE_METHOD(TestOp); void InferShape(const Scope& scope) const override { ++infer_shape_cnt; } void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override { @@ -21,14 +21,6 @@ class TestOp : public framework::OperatorBase { } }; -class EmptyOp : public framework::OperatorBase { - public: - DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase) - - void InferShape(const Scope& scope) const override {} - void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} -}; - template void AssertSameVectorWithoutOrder(const std::vector& expected, const std::vector& actual) { @@ -46,46 +38,51 @@ TEST(OpKernel, all) { auto net = std::make_shared(); ASSERT_NE(net, nullptr); - auto op1 = std::make_shared(); - op1->inputs_ = {"x", "w1", "b1"}; - op1->outputs_ = {"y"}; - net->AddOp(op1); - - auto op2 = std::make_shared(); - op2->inputs_ = {"y", "w2", "b2"}; - op2->outputs_ = {"z"}; - net->AddOp(op2); + net->AppendOp(std::unique_ptr( + new TestOp("test", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, + {{"Out", {"y"}}}, {}))); + net->AppendOp(std::unique_ptr( + new TestOp("test", {{"X", {"y"}}, {"W", {"w2"}}, {"b", {"b2"}}}, + {{"Out", {"z"}}}, {}))); net->CompleteAddOp(); - AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_); - AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_); - auto tmp_idx_iter = net->attrs_.find("temporary_index"); - ASSERT_NE(net->attrs_.end(), tmp_idx_iter); - auto& tmp_idx = boost::get>(tmp_idx_iter->second); - ASSERT_EQ(1UL, tmp_idx.size()); - ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); + AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, + net->Inputs(NetOp::kAll)); + AssertSameVectorWithoutOrder({"y", "z"}, net->Outputs(NetOp::kAll)); - Scope scope; - platform::CPUDeviceContext dev_ctx; + auto final_outs = net->OutputVars(false); - net->InferShape(scope); - net->Run(scope, dev_ctx); - ASSERT_EQ(2, infer_shape_cnt); - ASSERT_EQ(2, run_cnt); - ASSERT_THROW(net->AddOp(op2), platform::EnforceNotMet); + ASSERT_EQ(final_outs.size(), 1UL); + ASSERT_EQ(final_outs[0], "z"); } TEST(NetOp, insert_op) { NetOp net; - auto op1 = std::make_shared(); - op1->inputs_ = {"x", "w1", "b1"}; - op1->outputs_ = {"y"}; - net.AddOp(op1); - net.InsertOp(0, op1); + auto op1 = std::unique_ptr( + new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, + {{"Out", {"y"}}}, {})); + net.AppendOp(*op1); + net.InsertOp(0, *op1); ASSERT_EQ(2UL, net.ops_.size()); - net.InsertOp(2, op1); + net.InsertOp(2, std::move(op1)); ASSERT_EQ(3UL, net.ops_.size()); } +TEST(NetOp, Clone) { + NetOp net; + net.AppendOp( + std::unique_ptr(new framework::NOP{"empty", {}, {}, {}})); + net.AppendOp(std::unique_ptr( + new framework::NOP{"empty2", {}, {}, {}})); + net.CompleteAddOp(true); + auto new_net_op = net.Clone(); + ASSERT_NE(new_net_op, nullptr); + ASSERT_TRUE(new_net_op->IsNetOp()); + auto* new_net = static_cast(new_net_op.get()); + ASSERT_EQ(2, new_net->ops_.size()); + ASSERT_EQ(new_net->ops_[0]->Type(), "empty"); + ASSERT_EQ(new_net->ops_[1]->Type(), "empty2"); +} + } // namespace operators } // namespace paddle diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 243837420562634c3d99fd0acf234ebd53539735..78ce0ba3c0fa4fe380e49a848c2434fe593cd00b 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); InitMemories(step_scopes[0], true /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (size_t i = 0; i < seq_len_; i++) { if (i > 0) { rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } - net->GetMutable()->InferShape(*step_scopes[i]); + (*stepnet_)->InferShape(*step_scopes[i]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); @@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope, rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); InitMemories(step_scopes[0], false /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); for (size_t step_id = 0; step_id < seq_len_; step_id++) { // create output alias variables @@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope, rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); } - net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); @@ -78,25 +75,28 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { auto step_scopes = step_scopes_var->GetMutable>(); // Now all variables in scope must be created outside of op. - auto net_var = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope", - arg_->step_net); - auto net_op = net_var->GetMutable(); - PADDLE_ENFORCE(!net_op->outputs_.empty(), "net_op has no outputs"); + PADDLE_ENFORCE_NOT_NULL(stepnet_); + PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs"); + PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs"); if (seq_len_ > step_scopes->size()) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) { auto& step_scope = scope.NewScope(); // create step net's temp inputs - for (auto& input : net_op->inputs_) { + for (auto& input : (*stepnet_)->Inputs()) { // the weight are located in parent scope - if (!step_scope.FindVar(input)) - step_scope.NewVar(input)->GetMutable(); + for (auto& var_name : input.second) { + if (!step_scope.FindVar(var_name)) { + step_scope.NewVar(var_name)->GetMutable(); + } + } } // create stepnet's outputs - for (const auto& output : net_op->outputs_) { - step_scope.NewVar(output); + for (const auto& output : (*stepnet_)->Outputs()) { + for (auto& var_name : output.second) { + step_scope.NewVar(var_name); + } } step_scopes->emplace_back(&step_scope); } @@ -130,11 +130,13 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{ "inlink@grad", "inlink_alias", "outlink_alias", "memories", "pre_memories", "boot_memories@grad"}; -void RecurrentOp::Init() { - OperatorBase::Init(); - std::unique_ptr arg(new rnn::Argument()); - rnn::InitArgument(kArgName, arg.get(), *this); - alg_.Init(std::move(arg)); +RecurrentOp::RecurrentOp(const std::string& type, + const framework::OperatorBase::VarNameMap& inputs, + const framework::OperatorBase::VarNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) { + rnn::InitArgument(kArgName, &arg_, *this); + alg_.Init(&arg_, &stepnet_); } class RecurrentAlgorithmProtoAndCheckerMaker @@ -147,13 +149,12 @@ class RecurrentAlgorithmProtoAndCheckerMaker // inputs and outputs stored in proto AddInput(name.inlinks, "the inputs that need to be segmented for each step.") - .SetMultiple(); + .AsDuplicable(); AddInput(name.boot_memories, "variables to initialize memories.") - .SetMultiple(); - AddInput(name.step_net, "network shared by all steps."); + .AsDuplicable(); AddOutput(name.outlinks, "the outputs that need to concated for all steps.") - .SetMultiple(); + .AsDuplicable(); AddOutput(name.step_scopes, "step scopes"); // Attributes stored in AttributeMap @@ -172,14 +173,12 @@ void RecurrentGradientAlgorithm::Run( auto step_scopes = GetStepScopes(scope); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); } - net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } LinkBootMemoryGradients(step_scopes[0], false); rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, @@ -211,29 +210,30 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { auto step_scopes = GetStepScopes(scope); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); } - net->GetMutable()->InferShape(*step_scopes[step_id]); + (*stepnet_)->InferShape(*step_scopes[step_id]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); } -void RecurrentGradientOp::Init() { - OperatorBase::Init(); - std::unique_ptr arg(new rnn::Argument()); - rnn::InitArgument(kArgName, arg.get(), *this); - alg_.Init(std::move(arg)); +RecurrentGradientOp::RecurrentGradientOp( + const std::string& type, const framework::OperatorBase::VarNameMap& inputs, + const framework::OperatorBase::VarNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) { + rnn::InitArgument(kArgName, &arg_, *this); + alg_.Init(&arg_, &stepnet_); } } // namespace operators } // namespace paddle -REGISTER_OP(recurrent_op, paddle::operators::RecurrentOp, - paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + recurrent_op, paddle::operators::RecurrentOp, + paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index fdd9d005378e63b8d44803fb2b4be83d134c6a5b..bcfa817de8242153b164fa091309f19a6ad8a246 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/framework/operator.h" +#include "paddle/operators/net_op.h" #include "paddle/operators/rnn/recurrent_op_utils.h" namespace paddle { @@ -33,7 +34,12 @@ class RecurrentAlgorithm { void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; - void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + void Init(rnn::Argument* arg, + std::unique_ptr* stepnet) { + PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); + arg_ = arg; + stepnet_ = stepnet; + } /** * InferShape must be called before Run. @@ -58,7 +64,8 @@ class RecurrentAlgorithm { void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const; private: - std::unique_ptr arg_; + std::unique_ptr* stepnet_; + rnn::Argument* arg_; mutable size_t seq_len_; }; @@ -74,7 +81,12 @@ class RecurrentGradientAlgorithm { * operator. */ public: - void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + void Init(rnn::Argument* arg, + std::unique_ptr* stepnet) { + PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); + arg_ = std::move(arg); + stepnet_ = stepnet; + } void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; @@ -95,15 +107,22 @@ class RecurrentGradientAlgorithm { } private: - std::unique_ptr arg_; + rnn::Argument* arg_; mutable size_t seq_len_; + std::unique_ptr* stepnet_; }; -class RecurrentOp final : public framework::OperatorBase { - DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase) +class RecurrentOp : public framework::OperatorBase { public: - void Init() override; - + RecurrentOp(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, const framework::AttributeMap& attrs); + + RecurrentOp(const RecurrentOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement copy ctor well. + PADDLE_THROW("Not implemented"); + } /** * InferShape must be called before Run. */ @@ -116,15 +135,31 @@ class RecurrentOp final : public framework::OperatorBase { alg_.Run(scope, dev_ctx); } + void set_stepnet(std::unique_ptr net) { + stepnet_ = std::move(net); + } + const OperatorBase& stepnet() const { return *stepnet_; } + static const rnn::ArgumentName kArgName; private: RecurrentAlgorithm alg_; + rnn::Argument arg_; + std::unique_ptr stepnet_; }; -class RecurrentGradientOp final : public framework::OperatorBase { +class RecurrentGradientOp : public framework::OperatorBase { public: - void Init() override; + RecurrentGradientOp(const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const framework::AttributeMap& attrs); + + RecurrentGradientOp(const RecurrentGradientOp& o) + : framework::OperatorBase( + static_cast(o)) { + // TODO(yuyang18): Implement Copy ctor. + PADDLE_THROW("Not Implemented"); + } /** * InferShape must be called before Run. @@ -140,8 +175,15 @@ class RecurrentGradientOp final : public framework::OperatorBase { static const rnn::ArgumentName kArgName; + void set_stepnet(std::unique_ptr net) { + stepnet_ = std::move(net); + } + const OperatorBase& stepnet() const { return *stepnet_; } + private: RecurrentGradientAlgorithm alg_; + std::unique_ptr stepnet_; + rnn::Argument arg_; }; } // namespace operators diff --git a/paddle/operators/recurrent_op_test.cc b/paddle/operators/recurrent_op_test.cc deleted file mode 100644 index 0c9a343415835540c7543f15f40c53b78a6a55c4..0000000000000000000000000000000000000000 --- a/paddle/operators/recurrent_op_test.cc +++ /dev/null @@ -1,398 +0,0 @@ -/* - Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "paddle/operators/recurrent_op.h" - -#include -#include - -#include "paddle/framework/ddim.h" -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/tensor.h" -#include "paddle/operators/net_op.h" - -namespace paddle { -namespace operators { - -using framework::make_ddim; -using framework::DDim; -using framework::Tensor; -using framework::Variable; -using framework::Scope; -using framework::OpRegistry; - -class RecurrentOpTest : public ::testing::Test { - protected: - virtual void SetUp() override { - CreateGlobalVariables(); - CreateStepNet(); - CreateRNNOp(); - } - - virtual void TearDown() override {} - - void CreateGlobalVariables() { - // create input, and init content - LOG(INFO) << "create global variable x"; - for (auto inlink : std::vector{"x", "x0", "x1", "h"}) { - Variable* x = scope_.NewVar(inlink); - DDim dims = make_ddim(std::vector{ - 10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); - x->GetMutable()->mutable_data(dims, platform::CPUPlace()); - } - // create output alias just for test - for (auto inlink : std::vector{"h@alias"}) { - Variable* x = scope_.NewVar(inlink); - DDim dims = - make_ddim(std::vector{20 /*batch size*/, 30 /*input dim*/}); - x->GetMutable()->mutable_data(dims, platform::CPUPlace()); - } - - LOG(INFO) << "create global variable w"; - Variable* w = scope_.NewVar("rnn/w"); - w->GetMutable()->mutable_data( - make_ddim(std::vector{30, 30}), platform::CPUPlace()); - - for (auto boot : std::vector{"h_boot"}) { - LOG(INFO) << "create global variable " << boot; - Variable* h_boot = scope_.NewVar(boot); - h_boot->GetMutable()->mutable_data( - make_ddim(std::vector{20 /*batch size*/, 30 /*input dim*/}), - platform::CPUPlace()); - } - - LOG(INFO) << "create variable step_scopes"; - scope_.NewVar("step_scopes"); - - LOG(INFO) << "create variable h"; - scope_.NewVar("h"); - } - - void CreateRNNOp() { - framework::OpDesc op_desc; - - op_desc.set_type("recurrent_op"); - // inlinks 0 - op_desc.add_inputs("x"); - op_desc.add_inputs("x0"); - op_desc.add_inputs("x1"); - // boot_memories 3 - op_desc.add_inputs("h_boot"); - // step net 5 - op_desc.add_inputs("step_net"); - // outlinks 6 - op_desc.add_outputs("h"); - // step scopes 7 - op_desc.add_outputs("step_scopes"); - - auto _input_format = std::vector{ - 0, // in_link - 3, // memories - 4 // step_net - }; - auto input_format = op_desc.add_attrs(); - input_format->set_name("input_format"); - input_format->set_type(paddle::framework::AttrType::INTS); - for (auto i : _input_format) { - input_format->add_ints(i); - } - - auto output_format = op_desc.add_attrs(); - output_format->set_name("output_format"); - output_format->set_type(paddle::framework::AttrType::INTS); - for (auto i : std::vector{0, 1, 2}) { - output_format->add_ints(i); - } - - auto inlink_alias = op_desc.add_attrs(); - inlink_alias->set_name("inlink_alias"); - inlink_alias->set_type(paddle::framework::AttrType::STRINGS); - - auto outlink_alias = op_desc.add_attrs(); - outlink_alias->set_name("outlink_alias"); - outlink_alias->set_type(paddle::framework::AttrType::STRINGS); - - auto pre_memories = op_desc.add_attrs(); - pre_memories->set_name("pre_memories"); - pre_memories->set_type(paddle::framework::AttrType::STRINGS); - - auto memories = op_desc.add_attrs(); - memories->set_name("memories"); - memories->set_type(paddle::framework::AttrType::STRINGS); - - // create inlink_alias - for (const auto& item : - std::vector{"x@alias", "x0@alias", "x1@alias"}) { - inlink_alias->add_strings(item); - } - // pre memories - for (const auto& item : std::vector{"rnn/h@pre"}) { - pre_memories->add_strings(item); - } - // memories - for (const auto& item : std::vector{"rnn/h"}) { - memories->add_strings(item); - } - // output alias - for (const auto& item : std::vector{"h@alias"}) { - outlink_alias->add_strings(item); - } - - rnn_op_ = OpRegistry::CreateOp(op_desc); - - LOG(INFO) << "rnn_op finish init"; - } - - void CreateStepNet() { - LOG(INFO) << "create variable step_net"; - Variable* var = scope_.NewVar("step_net"); - auto net = var->GetMutable(); - net->AddOp( - OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); - - net->AddOp( - OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {})); - net->CompleteAddOp(); - } - - // father scope - Scope scope_; - std::shared_ptr rnn_op_; -}; - -TEST_F(RecurrentOpTest, Run) { - platform::CPUDeviceContext ctx; - rnn_op_->InferShape(scope_); - rnn_op_->Run(scope_, ctx); -} - -class RecurrentGradientAlgorithmTest : public ::testing::Test { - protected: - virtual void SetUp() override { - CreateGlobalVariables(); - CreateStepScopes(); - CreateStepNet(); - CreateRNNGradientAlgorithm(); - - // segment inputs - SegmentInputs(); - // link forward memories - LinkeMemories(); - } - - virtual void TearDown() override {} - - void CreateGlobalVariables() { - // inputs: x - LOG(INFO) << "create global variable x"; - Variable* x = scope_.NewVar("x"); - DDim dims = - make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); - x->GetMutable()->mutable_data(dims, platform::CPUPlace()); - // inputs: h_boot - LOG(INFO) << "create global variable h_boot"; - Variable* h_boot = scope_.NewVar("h_boot"); - h_boot->GetMutable()->mutable_data( - make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace()); - // inputs: w - LOG(INFO) << "create global variable w"; - Variable* w = scope_.NewVar("rnn/w"); - w->GetMutable()->mutable_data(make_ddim({30, 30}), - platform::CPUPlace()); - // inputs: h_grad - LOG(INFO) << "create variable h_grad"; - Variable* dh = scope_.NewVar("h_grad"); - dh->GetMutable()->mutable_data(make_ddim({10, 20, 30}), - platform::CPUPlace()); - // inputs: step_scopes - LOG(INFO) << "create variable step_scopes"; - scope_.NewVar("step_scopes"); - // inputs: step_net - LOG(INFO) << "create variable step_net"; - scope_.NewVar("step_net"); - // outputs: w_grad - LOG(INFO) << "create global variable w_grad"; - scope_.NewVar("rnn/w_grad"); - // outputs: x_grad - LOG(INFO) << "create global variable x_grad"; - scope_.NewVar("x_grad"); - // outputs: h_boot_grad - LOG(INFO) << "create global variable h_boot_grad"; - scope_.NewVar("h_boot_grad"); - } - - void CreateStepScopes() { - auto step_scopes = - scope_.FindVar("step_scopes")->GetMutable>(); - for (int i = 0; i < 10; ++i) { - auto& scope = scope_.NewScope(); - auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable(); - pre_t->mutable_data({20, 30}, platform::CPUPlace()); - auto tensor = scope.NewVar("rnn/h")->GetMutable(); - tensor->mutable_data({20, 30}, platform::CPUPlace()); - - // for unit test of ConcatOutputs - auto xg = scope.NewVar("rnn/x_grad")->GetMutable(); - xg->mutable_data({20, 30}, platform::CPUPlace()); - - step_scopes->emplace_back(&scope); - } - - // last time step - auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable(); - g->mutable_data({20, 30}, platform::CPUPlace()); - } - - void CreateRNNGradientAlgorithm() { - std::unique_ptr arg(new rnn::Argument()); - arg->step_net = "step_net"; - arg->step_scopes = "step_scopes"; - rnn::Link inlink; - inlink.external = "h_grad"; - inlink.internal = "rnn/h_grad"; - arg->inlinks = std::vector{inlink}; - - rnn::Link outlink; - outlink.external = "x_grad"; - outlink.internal = "rnn/x_grad"; - arg->outlinks = std::vector{outlink}; - - rnn::MemoryAttr mem_attr; - mem_attr.pre_var = "rnn/h_pre_grad"; - mem_attr.var = "rnn/h_grad"; - mem_attr.boot_var = "h_boot_grad"; - arg->memories = std::vector{mem_attr}; - - rnn_grad_algo_.Init(std::move(arg)); - } - - void CreateStepNet() { - LOG(INFO) << "create variable step_net"; - Variable* var = scope_.NewVar("step_net"); - auto net = var->GetMutable(); - net->AddOp(OpRegistry::CreateOp("mul", {"rnn/h_pre", "rnn/w", "rnn/s_grad"}, - {"rnn/h_pre_grad", "rnn/w_grad"}, {})); - - net->AddOp(OpRegistry::CreateOp("add_two", {"rnn/h_grad"}, - {"rnn/x_grad", "rnn/s_grad"}, {})); - net->CompleteAddOp(); - } - - void SegmentInputs() { - LOG(INFO) << "segment inputs"; - std::vector inlinks = {"x"}; - std::vector inlinks_alias = {"rnn/x"}; - - rnn::Link inlink; - inlink.external = "x"; - inlink.internal = "rnn/x"; - auto step_scopes = - scope_.FindVar("step_scopes")->GetMutable>(); - rnn::SegmentInputs(*step_scopes, std::vector{inlink}, 10, - true /*infer_shape_mode*/); - } - - void LinkeMemories() { - LOG(INFO) << "link memories"; - rnn::MemoryAttr mem_attr; - mem_attr.pre_var = "rnn/h_pre"; - mem_attr.var = "rnn/h"; - mem_attr.boot_var = "boot_h"; - std::vector memories; - memories.push_back(mem_attr); - auto step_scopes = - scope_.FindVar("step_scopes")->GetMutable>(); - for (int i = 1; i < 10; ++i) { - rnn::LinkMemories(*step_scopes, memories, i, -1, - true /*infer_shape_mode*/); - } - } - - Scope scope_; - RecurrentGradientAlgorithm rnn_grad_algo_; -}; - -// TEST_F(RecurrentGradientAlgorithmTest, Run) { -// platform::CPUDeviceContext ctx; -// rnn_grad_algo_.Run(scope_, ctx); -// } - -} // namespace operators -} // namespace paddle - -TEST(RecurrentOp, LinkMemories) { - using namespace paddle::framework; - using namespace paddle::platform; - using namespace paddle::operators; - - // create and init step scopes - size_t len = 10; - std::vector step_scopes; - for (size_t i = 0; i < len; ++i) { - auto scope = new Scope(); - scope->NewVar("pre_h"); - auto tensor = scope->NewVar("h")->GetMutable(); - float* data = tensor->mutable_data({15, 20}, CPUPlace()); - for (size_t j = 0; j < 15 * 20; ++j) { - data[j] = rand() * (1. / (double)RAND_MAX); - } - step_scopes.push_back(scope); - } - - // create MemoryAttr - rnn::MemoryAttr mem_attr; - mem_attr.pre_var = "pre_h"; - mem_attr.var = "h"; - mem_attr.boot_var = "boot_h"; - std::vector memories; - memories.push_back(mem_attr); - - for (size_t i = 1; i < len; ++i) { - rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/); - } - // check - for (size_t i = 0; i < len - 1; ++i) { - const float* a = - step_scopes[i]->FindVar("h")->GetMutable()->data(); - const float* b = step_scopes[i + 1] - ->FindVar("pre_h") - ->GetMutable() - ->data(); - for (size_t j = 0; j < 15 * 20; ++j) { - ASSERT_FLOAT_EQ(a[j], b[j]); - } - } - - for (int i = len - 2; i >= 0; --i) { - rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/); - } - // check - for (int i = len - 2; i >= 0; --i) { - const float* a = - step_scopes[i]->FindVar("pre_h")->GetMutable()->data(); - const float* b = - step_scopes[i + 1]->FindVar("h")->GetMutable()->data(); - for (size_t j = 0; j < 15 * 20; ++j) { - ASSERT_FLOAT_EQ(a[j], b[j]); - } - } - - for (auto s : step_scopes) { - delete s; - } -} - -USE_OP(add_two); -USE_OP(mul); -USE_OP_ITSELF(recurrent_op); diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc index 7e4770630ed2a49214194689aa489e6ab8e476da..a9b65c30f25554e54e9fd7103f240946a93566e2 100644 --- a/paddle/operators/rnn/recurrent_op_utils.cc +++ b/paddle/operators/rnn/recurrent_op_utils.cc @@ -106,7 +106,6 @@ void LinkMemories(const std::vector& scopes, void InitArgument(const ArgumentName& name, Argument* arg, const framework::OperatorBase& op) { - arg->step_net = op.Input(name.step_net); arg->step_scopes = op.Output(name.step_scopes); auto inlinks = op.Inputs(name.inlinks); diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 402f6340a04d9b423bb16431a99a2f2866d203bc..6825dce332adc0dc11dda187d1bd367875b8603e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -17,26 +17,28 @@ namespace paddle { namespace operators { -class RowWiseAddOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel) +using framework::Tensor; + +class RowwiseAddOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 2UL, - "Two inputs is needed by rowwise add"); - auto dim0 = ctx.Input(0)->dims(); - auto dim1 = ctx.Input(1)->dims(); + auto dim0 = ctx.Input("X")->dims(); + auto dim1 = ctx.Input("b")->dims(); PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); - PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1"); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1"); + ctx.Output("Out")->Resize(ctx.Input("X")->dims()); } }; -class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { +class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowWiseAddOpMaker(framework::OpProto *proto, + RowwiseAddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); @@ -49,11 +51,32 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowwiseAddGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dims0 = ctx.Input("X")->dims(); + auto dims1 = ctx.Input("b")->dims(); + PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") + ctx.Output(framework::GradVarName("X"))->Resize(dims0); + ctx.Output(framework::GradVarName("b"))->Resize(dims1); + } +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); +REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, + rowwise_add_grad, ops::RowwiseAddGradOp); +REGISTER_OP_CPU_KERNEL( + rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add_grad, + ops::RowwiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 86f80b81228a69ac4c05a4693901570f2b9966e0..cbc61ad3e117fc79a674ca21831d3fec59d1ec5b 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -17,4 +17,4 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add, ops::RowwiseAddKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 82e9d70e959441869b958c1241fa5f5beef4c50c..1cbd8bb31ad90a32d8a4e3bb59617d0b5384e470 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" @@ -28,14 +28,14 @@ template ; template -class RowWiseAddKernel : public framework::OpKernel { +class RowwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto out = context.Output(0); + auto out = context.Output("Out"); out->mutable_data(context.GetPlace()); - auto input = EigenMatrix::From(*context.Input(0)); - auto bias = EigenVector::From(*context.Input(1)); + auto input = EigenMatrix::From(*context.Input("X")); + auto bias = EigenVector::From(*context.Input("b")); auto output = EigenMatrix::From(*out); const int bias_size = bias.dimension(0); @@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel { } }; +template +class RowwiseAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dOut = context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + auto* db = context.Output(framework::GradVarName("b")); + dX->mutable_data(context.GetPlace()); + db->mutable_data(context.GetPlace()); + + auto OutGrad = EigenMatrix::From(*dOut); + auto place = context.GetEigenDevice(); + EigenMatrix::From(*dX).device(place) = OutGrad; + + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + // colwise add + Eigen::array dims{{0}}; /* dimension to reduce */ + EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..6b542675c291607b35f180123cf42fee6a783a85 --- /dev/null +++ b/paddle/operators/scatter.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +#include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +// Implementation of CPU copy +template +void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index, + const size_t index_size, + paddle::framework::Tensor* output) { + paddle::framework::DDim output_dims = output->dims(); + + for (size_t i = 0; i < index_size; ++i) { + int index_ = index[i]; + + paddle::framework::Tensor src_ = *src; + paddle::framework::Tensor output_ = *output; + if (index_size > 1) src_ = src->Slice(i, i + 1); + if (output_dims[0] > 1) output_ = output->Slice(index_, index_ + 1); + + auto X = EigenVector::Flatten(src_); + auto Y = EigenVector::Flatten(output_); + + Y = X + Y; + } +} + +// Implementation of GPU scatter: +template +void GPUScatterUpdate(const T* src, const int* index, const int slice_size, + const int index_size, T* output); + +/** + * Return a updated tensor from source tensor, scattered according to index: + * dst[i] += src[index[i]] + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ +template +void ScatterUpdate(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + auto dst_dims = output->dims(); + + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE(src_dims[i] == dst_dims[i]); + + // slice size + size_t slice_size = 1; + for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + if (platform::is_cpu_place(place)) { + CPUScatterUpdate(src, index->data(), index_size, output); + } else { + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..26fdaff1460a297fa638181641991f732533fe52 --- /dev/null +++ b/paddle/operators/scatter_test.cc @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/scatter.h" +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +#include +#include +#include + +TEST(scatter, ScatterUpdate) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + Tensor* src = new Tensor(); + Tensor* index = new Tensor(); + Tensor* output = new Tensor(); + + float* p_src = nullptr; + int* p_index = nullptr; + p_src = src->mutable_data(make_ddim({1, 4}), CPUPlace()); + p_index = index->mutable_data(make_ddim({1}), CPUPlace()); + + for (size_t i = 0; i < 4; ++i) p_src[i] = float(i); + p_index[0] = 1; + + float* p_output = output->mutable_data(make_ddim({4, 4}), CPUPlace()); + + ScatterUpdate(CPUPlace(), src, index, output); + + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data()[i], float(0)); + for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], float(i - 4)); + for (size_t i = 4; i < 8; ++i) + EXPECT_EQ(output->data()[i], float(i - 4)); + for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], float(0)); + for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data()[i], float(0)); + + delete src; + delete index; + delete output; +} diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 5b8093f0f77e0982a7ad25b42b299a6461712630..ad267e7f087943ff3b8326a7baf2ce3955fa51c2 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -18,17 +18,15 @@ namespace paddle { namespace operators { class SGDOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two"); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "Output size of SGDOp must be one"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), "inputs[0] mast be set"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(1), "inputs[1] mast be set"); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(0), "outputs[0] mast be set"); - PADDLE_ENFORCE(ctx.Input(0)->dims() == ctx.Input(1)->dims(), - "Two input of SGD Op's dimension must be same."); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + PADDLE_ENFORCE( + ctx.Input("param")->dims() == ctx.Input("grad")->dims(), + "Two input of SGD Op's dimension must be same."); + ctx.Output("param_out")->Resize(ctx.Input("param")->dims()); } }; @@ -53,6 +51,6 @@ param_out = param - learning_rate * grad; } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel); diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index bfb449d0b029409eda4177fc7643810ee6a1df3d..a0b5000ffbf54364e15f87870913926a071fa972 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -30,7 +30,7 @@ class SGDOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto param = ctx.Input("param"); auto grad = ctx.Input("grad"); - auto param_out = ctx.Output(0); + auto param_out = ctx.Output("param_out"); float lr = ctx.op_.GetAttr("learning_rate"); param_out->mutable_data(ctx.GetPlace()); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index a02e2dc39e8f0d3e31c22a5cafeff111d08aa905..761c6de8d4d2150b30b97b58da95da3d5f33db63 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -18,12 +18,12 @@ namespace paddle { namespace operators { class SigmoidOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); - PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output"); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output("Y")->Resize(ctx.Input("X")->dims()); } }; @@ -39,10 +39,13 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { }; class SigmoidOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("Y")->dims()); } }; @@ -50,9 +53,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker); -REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad); - +REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::SigmoidOpGrad); REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 7af879b2091e4a7f80a3a64be029394156650c23..b01a9b3f23283471f8846325075719ba0e75ed35 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -28,8 +28,8 @@ template class SigmoidKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto input = context.Input(0); - auto output = context.Output(0); + auto input = context.Input("X"); + auto output = context.Output("Y"); output->mutable_data(context.GetPlace()); // The clipping is used in Paddle's raw implenmention @@ -37,7 +37,7 @@ class SigmoidKernel : public framework::OpKernel { auto Y = EigenVector::Flatten(*output); auto place = context.GetEigenDevice(); - Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); + Y.device(place) = 1. / (1. + (-X).exp()); } }; diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 9b6a679642303a2cb34954ce16b4a5811acf0ec2..40c51a64c49bc064f55975ef6ced1d54070f1291 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -18,15 +18,13 @@ namespace paddle { namespace operators { class SoftmaxOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL, - "Only one input is need for softmax"); - PADDLE_ENFORCE_EQ(ctx.Input("X")->dims().size(), 2UL, - "The input of softmax op must be matrix"); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL, - "Only one output is need for softmax"); + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, + "The input of softmax op must be matrix"); ctx.Output("Y")->Resize(ctx.Input("X")->dims()); } }; @@ -43,14 +41,12 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { }; class SoftmaxOpGrad : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, - "Input of SoftmaxOpGrad should be 3, X, Y, YG"); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1UL, - "Output of SoftmaxOpGrad should be 1"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")), "Input(Y@GRAD) should not be null"); PADDLE_ENFORCE(ctx.Input("Y")->dims() == @@ -66,9 +62,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); +REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad, + ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); -REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL( softmax_grad, ops::SoftmaxGradKernel); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index ea81ec053f8b9029114f7c98d292a778dc50c3e4..29491137e6d8b4bfa2d0d07d48ffed1212a6131f 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -27,7 +24,7 @@ template class CPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.op_.GetAttr("seed")); @@ -39,19 +36,22 @@ class CPUUniformRandomKernel : public framework::OpKernel { std::uniform_real_distribution dist( static_cast(context.op_.GetAttr("min")), static_cast(context.op_.GetAttr("max"))); - for (ssize_t i = 0; i < framework::product(tensor->dims()); ++i) { + ssize_t size = framework::product(tensor->dims()); + for (ssize_t i = 0; i < size; ++i) { data[i] = dist(engine); } } }; class UniformRandomOp : public framework::OperatorWithKernel { - DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel) + public: + using framework::OperatorWithKernel::OperatorWithKernel; + protected: void InferShape(const framework::InferShapeContext& ctx) const override { PADDLE_ENFORCE(GetAttr("min") < GetAttr("max"), "uniform_random's min must less then max"); - auto* tensor = ctx.Output(0); + auto* tensor = ctx.Output("Out"); auto dims = GetAttr>("dims"); tensor->Resize(framework::make_ddim(dims)); } @@ -64,7 +64,6 @@ class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "The output tensor of uniform random op"); AddComment(R"DOC(Uniform random operator. - Used to initialize tensor with uniform random generator. )DOC"); AddAttr>("dims", "the dimension of random tensor"); @@ -79,7 +78,7 @@ Used to initialize tensor with uniform random generator. } // namespace operators } // namespace paddle -REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp, - paddle::operators::UniformRandomOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, + paddle::operators::UniformRandomOpMaker); REGISTER_OP_CPU_KERNEL(uniform_random, paddle::operators::CPUUniformRandomKernel); diff --git a/paddle/operators/uniform_random_op.cu b/paddle/operators/uniform_random_op.cu index b35ebe7b630be72a5856ec1d3cc32bfaf097aa8a..1d6709934cbbcf50265eabef87c857654f783ed8 100644 --- a/paddle/operators/uniform_random_op.cu +++ b/paddle/operators/uniform_random_op.cu @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -46,7 +43,7 @@ template class GPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* tensor = context.Output(0); + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.op_.GetAttr("seed")); diff --git a/paddle/parameter/Parameter.cpp b/paddle/parameter/Parameter.cpp index ebe36d49376882fe4c1013e19dcf71f452b3e501..f0311095012d944768d80abe423d4a9bfc0e97f5 100644 --- a/paddle/parameter/Parameter.cpp +++ b/paddle/parameter/Parameter.cpp @@ -48,7 +48,8 @@ Parameter::Parameter(const ParameterConfig& config, bool useGpu, bool doInit) deviceId_(-1), sharedCount_(0), updateCounter_(0), - updated_(false) { + updated_(false), + headerFormat_(PARAM_FORMAT_ORIGINAL) { setID(-1); /* capture uninitialized id */ if (useGpu_ && FLAGS_parallel_nn) { /* gpu environment is specified by device property */ @@ -285,7 +286,7 @@ bool Parameter::save(const std::string& filename) const { bool Parameter::save(std::ostream& s) const { CpuVector vec(*bufs_[PARAMETER_VALUE].get()); Header header; - header.version = kFormatVersion; + header.format = headerFormat_; header.valueSize = sizeof(real); header.size = getSize(); @@ -344,8 +345,9 @@ bool Parameter::load(std::istream& s) { Header header; CHECK(s.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameter " << getName(); - CHECK_EQ(header.version, kFormatVersion) << "Incorrect format version: " - << header.version; + CHECK(isHeaderFormatSupported(header.format)) << "Incorrect format version: " + << header.format; + headerFormat_ = header.format; CHECK_EQ(header.size, getSize()) << "The size (" << header.size << ") in the file does not match the size " << "(" << getSize() << ") of the parameter: " << getName(); diff --git a/paddle/parameter/Parameter.h b/paddle/parameter/Parameter.h index 0bac76f068ec22bec52766b43e331fe109a34188..321f4275d8e68d7d3fbbc19acf0afacf689474e5 100644 --- a/paddle/parameter/Parameter.h +++ b/paddle/parameter/Parameter.h @@ -34,6 +34,20 @@ limitations under the License. */ namespace paddle { +typedef enum { + /// The paddle original basic format + PARAM_FORMAT_ORIGINAL = 0, + + /// See mkldnn_memory_format_t in + /// https://github.com/01org/mkl-dnn/blob/master/include/mkldnn_types.h + /// for a detailed description. + /// 2D weights tensor in the format (output channels, input channels). + PARAM_FORMAT_MKLDNN_OI, + + /// The total format items numbers + PARAM_FORMAT_ITEMS, +} PARAM_FORMAT; + class SparsePrefetchRowCpuMatrix; class Parameter; @@ -51,7 +65,10 @@ public: size_t getSize() const { return config_.size(); } bool isFullSize() const { - return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + if (bufs_[PARAMETER_VALUE]) { + return this->getSize() == bufs_[PARAMETER_VALUE]->getSize(); + } + return false; } inline bool useGpu() const { return useGpu_; } @@ -242,14 +259,30 @@ public: /// Initialize the value to 0 void zeroMem(); - static const int kFormatVersion = 0; /// file header structure struct Header { - int32_t version; // = 0, file format version + int32_t format; // = PARAM_FORMAT uint32_t valueSize; // = sizeof(real) uint64_t size; // = getSize() }; + /** + * @brief Is the header format supported. + */ + static bool isHeaderFormatSupported(int32_t fmt) { + return fmt < PARAM_FORMAT_ITEMS; + } + + /** + * @brief Get the format in header. + */ + int getHeaderFormat() { return headerFormat_; } + + /** + * @brief Set the format in header. + */ + void setHeaderFormat(int32_t fmt) { headerFormat_ = fmt; } + /** * @brief Parameter Update Hook. * @@ -321,6 +354,9 @@ protected: bool updated_; SparseFormat format_; + /// The header format for saving or loading param + int32_t headerFormat_; + std::vector> updaterHooks_; public: diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 4154aad15c39119e2f155cb2c7b5177b5aa78022..120eb1e4af9cef43e76e27d4ad66acfbbd597a36 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,7 +1,7 @@ cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) -nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) +nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) @@ -9,6 +9,7 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags) add_subdirectory(dynload) cc_test(enforce_test SRCS enforce_test.cc DEPS stringpiece) +cc_test(environment_test SRCS environment_test.cc DEPS stringpiece) IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) @@ -16,5 +17,8 @@ ELSE() set(GPU_CTX_DEPS) ENDIF() -cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) +# memcpy deoends on device_context, here add deps individually for +# avoiding cycle dependencies +cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator + system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a928e097787db9deebe1c6eab263190caacac7eb..ad212c5b2c47312743362db4926c80bf056e100d 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/platform/device_context.h" +#include "paddle/memory/memory.h" namespace paddle { namespace platform { @@ -36,6 +37,59 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } #ifndef PADDLE_ONLY_CPU +class EigenCudaStreamDevice : public Eigen::StreamInterface { + public: + EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { + Eigen::initializeDeviceProp(); + } + ~EigenCudaStreamDevice() override {} + + void Reinitialize(const cudaStream_t* cuda_stream, GPUPlace place) { + stream_ = cuda_stream; + place_ = place; + device_prop_ = &Eigen::m_deviceProperties[place.device]; + } + + const cudaStream_t& stream() const override { return *stream_; } + + const cudaDeviceProp& deviceProperties() const override { + return *device_prop_; + } + + void* allocate(size_t num_bytes) const override { + return paddle::memory::Alloc(place_, num_bytes); + } + + void deallocate(void* buffer) const override { + paddle::memory::Free(place_, buffer); + } + + void* scratchpad() const override { + if (scratch_ == NULL) { + scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int)); + } + return scratch_; + } + + unsigned int* semaphore() const override { + if (semaphore_ == NULL) { + char* scratch = + static_cast(scratchpad()) + Eigen::kCudaScratchSize; + semaphore_ = reinterpret_cast(scratch); + PADDLE_ENFORCE( + cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_)); + } + return semaphore_; + } + + private: + GPUPlace place_; + const cudaStream_t* stream_; // not owned; + const cudaDeviceProp* device_prop_; // not owned; + mutable void* scratch_; + mutable unsigned int* semaphore_; +}; + template <> Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); @@ -43,19 +97,9 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); - // TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly - // here will cause segment fault. We must implement a class derived from - // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id - // later. Please refer to the implementation of class EigenCudaStreamDevice - // in TensorFlow. - // - // We find that CUDA 7 introduces a new option, the per-thread default stream, - // that has two effects. Please refer to https://devblogs.nvidia.com/ - // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/ - // - // So, we decide to use default stream and add –default-stream per-thread nvcc - // flag. Than, two threads with two CUDADeviceContexts will run parallelly. - eigen_stream_.reset(new Eigen::CudaStreamDevice()); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + eigen_stream_.reset(new EigenCudaStreamDevice()); + eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); } @@ -70,17 +114,15 @@ CUDADeviceContext::~CUDADeviceContext() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); } - if (curand_generator_) { - PADDLE_ENFORCE(dynload::curandDestroyGenerator(curand_generator_)); - } eigen_stream_.reset(); eigen_device_.reset(); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); } Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { - PADDLE_ENFORCE(cudaStreamSynchronize(0)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); } Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { @@ -91,6 +133,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() { if (!cublas_handle_) { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); + PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); } return cublas_handle_; } @@ -99,20 +142,12 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { if (!cudnn_handle_) { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); } return cudnn_handle_; } -curandGenerator_t CUDADeviceContext::curand_generator() { - if (!curand_generator_) { - SetDeviceId(place_.device); - PADDLE_ENFORCE(dynload::curandCreateGenerator(&curand_generator_, - CURAND_RNG_PSEUDO_DEFAULT)); - PADDLE_ENFORCE( - dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); - } - return curand_generator_; -} +cudaStream_t CUDADeviceContext::stream() { return stream_; } #endif // PADDLE_ONLY_CPU diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 08b5b2cff900cc4239a615fe7d7f6b5faa13510b..11528e1194e4516891034fa8febdac3ba6eed204 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -17,7 +17,6 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" -#include "paddle/platform/dynload/curand.h" #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif @@ -40,7 +39,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - explicit CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace place); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -52,10 +51,11 @@ class CPUDeviceContext : public DeviceContext { }; #ifndef PADDLE_ONLY_CPU +class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + explicit CUDADeviceContext(GPUPlace place); virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -74,24 +74,20 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle(); - /*! \brief Return curand handle in the device context. */ - curandGenerator_t curand_generator(); + /*! \brief Return cuda stream in the device context. */ + cudaStream_t stream(); // clang-format on private: GPUPlace place_; - private: std::unique_ptr eigen_device_; - std::unique_ptr eigen_stream_; - - private: - uint64_t seed_; + std::unique_ptr eigen_stream_; // clang-format off - cudnnHandle_t cudnn_handle_ = nullptr; - cublasHandle_t cublas_handle_ = nullptr; - curandGenerator_t curand_generator_ = nullptr; + cudaStream_t stream_{nullptr}; + cudnnHandle_t cudnn_handle_{nullptr}; + cublasHandle_t cublas_handle_{nullptr}; // clang-format on }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 65345c433c0a328e7f89038a39312edba35eb8c7..5883a55272f0f24c94d48bc43c62ddb7bef15465 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -43,8 +43,7 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); ASSERT_NE(nullptr, cublas_handle); - curandGenerator_t curand_handle = device_context->curand_generator(); - ASSERT_NE(nullptr, curand_handle); + ASSERT_NE(nullptr, device_context->stream()); delete device_context; } } diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index aad8097dbb33cbf6c0f2b4b3efb1376fbe96bc74..9d8343c0b5e200b390ccda760f09816959952e9d 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,12 +62,12 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSgemv); \ - __macro(cublasDgemv); \ - __macro(cublasSgemm); \ - __macro(cublasDgemm); \ - __macro(cublasSgeam); \ - __macro(cublasDgeam); \ + __macro(cublasSgemv_v2); \ + __macro(cublasDgemv_v2); \ + __macro(cublasSgemm_v2); \ + __macro(cublasDgemm_v2); \ + __macro(cublasSgeam_v2); \ + __macro(cublasDgeam_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 337a059fb1494d500be0fd2437e59c863ae1563c..81448897e95eb05f4ce7de8683d98e05bade77cb 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -14,14 +14,21 @@ limitations under the License. */ #pragma once -#include +#include // for dladdr +#include // for backtrace #include +#include #include #include #include + #include "paddle/string/printf.h" #include "paddle/string/to_string.h" +#ifdef __GNUC__ +#include // for __cxa_demangle +#endif + #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" @@ -39,6 +46,19 @@ limitations under the License. */ namespace paddle { namespace platform { +namespace { +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif +} + struct EnforceNotMet : public std::exception { std::exception_ptr exp_; std::string err_str_; @@ -48,15 +68,29 @@ struct EnforceNotMet : public std::exception { std::rethrow_exception(exp_); } catch (const std::exception& exp) { std::ostringstream sout; + sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; - sout << "Call Stacks: " << std::endl; + sout << "PaddlePaddle Call Stacks: " << std::endl; + void* call_stack[TRACE_STACK_LIMIT]; - int sz = backtrace(call_stack, TRACE_STACK_LIMIT); - auto line = backtrace_symbols(call_stack, sz); - for (int i = 0; i < sz; ++i) { - sout << line[i] << std::endl; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + + Dl_info info; + for (int i = 0; i < size; ++i) { + if (dladdr(call_stack[i], &info)) { + auto demangled = demangle(info.dli_sname); + auto addr_offset = static_cast(call_stack[i]) - + static_cast(info.dli_saddr); + sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, + 2 + sizeof(void*) * 2, call_stack[i], + demangled, addr_offset); + } else { + sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, + call_stack[i]); + } } - free(line); + free(symbols); err_str_ = sout.str(); } } @@ -170,7 +204,7 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise an expression described as follows: - * "enforce a == b failed, 1 != 2" with detailed stack infomation. + * "enforce a == b failed, 1 != 2" with detailed stack information. * * extra messages is also supported, for example: * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) diff --git a/paddle/platform/environment.h b/paddle/platform/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..4edcce932edc61453cef74f2c4ee0f72496b3677 --- /dev/null +++ b/paddle/platform/environment.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/platform/enforce.h" +#include "paddle/string/piece.h" + +extern char** environ; // for environment variables + +namespace paddle { +namespace platform { + +inline void SetEnvVariable(const std::string& name, const std::string& value) { + PADDLE_ENFORCE_NE(setenv(name.c_str(), value.c_str(), 1), -1, + "Failed to set environment variable %s=%s", name, value); +} + +inline void UnsetEnvVariable(const std::string& name) { + PADDLE_ENFORCE_NE(unsetenv(name.c_str()), -1, + "Failed to unset environment variable %s", name); +} + +inline bool IsEnvVarDefined(const std::string& name) { + return std::getenv(name.c_str()) != nullptr; +} + +inline std::string GetEnvValue(const std::string& name) { + PADDLE_ENFORCE(IsEnvVarDefined(name), + "Tried to access undefined environment variable %s", name); + return std::getenv(name.c_str()); +} + +inline std::vector GetAllEnvVariables() { + std::vector vars; + for (auto var = environ; *var != nullptr; ++var) { + auto tail = string::Index(*var, "="); + auto name = string::SubStr(*var, 0, tail).ToString(); + vars.push_back(name); + } + return vars; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/environment_test.cc b/paddle/platform/environment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f136527215d6a676cfa1a3b08f09dfd3ab24a90 --- /dev/null +++ b/paddle/platform/environment_test.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/environment.h" + +#include "glog/logging.h" +#include "gtest/gtest.h" + +TEST(ENVIRONMENT, ACCESS) { + namespace platform = paddle::platform; + namespace string = paddle::string; + + platform::SetEnvVariable("PADDLE_USE_ENV", "TRUE"); + + EXPECT_TRUE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + EXPECT_EQ(platform::GetEnvValue("PADDLE_USE_ENV"), "TRUE"); + + platform::UnsetEnvVariable("PADDLE_USE_ENV"); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV")); + + platform::SetEnvVariable("PADDLE_USE_ENV1", "Hello "); + platform::SetEnvVariable("PADDLE_USE_ENV2", "World, "); + platform::SetEnvVariable("PADDLE_USE_ENV3", "PaddlePaddle!"); + + std::string env_info; + auto vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_TRUE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + platform::UnsetEnvVariable("PADDLE_USE_ENV1"); + platform::UnsetEnvVariable("PADDLE_USE_ENV2"); + platform::UnsetEnvVariable("PADDLE_USE_ENV3"); + + env_info.clear(); + vars = platform::GetAllEnvVariables(); + for_each(vars.begin(), vars.end(), [&](const std::string& var) { + env_info += platform::GetEnvValue(var); + }); + + EXPECT_FALSE(string::Contains(env_info, "Hello World, PaddlePaddle!")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV1")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV2")); + EXPECT_FALSE(platform::IsEnvVarDefined("PADDLE_USE_ENV3")); +} diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index edeb3ecd7bf8b87333813eee5b40f71030f6609f..be381a4e26cf0eb41f5b3de88bd03ad8901683cc 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/platform/gpu_info.h" + #include "gflags/gflags.h" + #include "paddle/platform/enforce.h" +#include "paddle/platform/environment.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, "Default use 95% of GPU memory for PaddlePaddle," @@ -70,6 +73,13 @@ size_t GpuMaxChunkSize() { GpuMemoryUsage(available, total); + if (IsEnvVarDefined(kEnvFractionGpuMemoryToUse)) { + auto val = std::stod(GetEnvValue(kEnvFractionGpuMemoryToUse)); + PADDLE_ENFORCE_GT(val, 0.0); + PADDLE_ENFORCE_LE(val, 1.0); + FLAGS_fraction_of_gpu_memory_to_use = val; + } + // Reserving the rest memory for page tables, etc. size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total; diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index d3a5f5f13fdd3dd59eb43465da4a64b0d8d95e5b..ed2420b8740e583d307f6836a70fe7e1c780e28b 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -18,10 +18,15 @@ limitations under the License. */ #include #include +#include namespace paddle { namespace platform { +//! Environment variable: fraction of GPU memory to use on each device. +const std::string kEnvFractionGpuMemoryToUse = + "PADDLE_FRACTION_GPU_MEMORY_TO_USE"; + //! Get the total number of GPU devices in system. int GetDeviceCount(); diff --git a/paddle/pserver/ParameterClient2.cpp b/paddle/pserver/ParameterClient2.cpp index f7e391f76324a09c203dfbbb449feb050caa8fb4..54063a809a4f9e558f8d364f5c437f2b6d98925b 100644 --- a/paddle/pserver/ParameterClient2.cpp +++ b/paddle/pserver/ParameterClient2.cpp @@ -65,7 +65,6 @@ void ParameterClient2::initThreads() { LOG(INFO) << "parallel_thread_num dosent need to set"; } syncThreadPool_.reset(new SyncThreadPool(threadNum_)); - startThreads(); } @@ -224,6 +223,14 @@ void ParameterClient2::prepareSendData( request.set_cost(cost); request.set_batch_status(batchStatus); CHECK_EQ(request.blocks_size(), 0); + VLOG(10) << "request: trainer_id: " << request.trainer_id() + << " update_mode" << request.update_mode() + << " send_back_parameter: " << request.send_back_parameter() + << " send_back_parameter_type: " + << request.send_back_parameter_type() + << " num_samples: " << request.num_samples() + << " cost: " << request.cost() + << " batch_status: " << request.batch_status(); } for (const auto& segments : parameterSegments) { const auto it = parameterMap_.find(segments.id); @@ -251,11 +258,17 @@ void ParameterClient2::prepareSendData( CHECK(sendMat != nullptr) << "sendMat is nullptr"; syncThreadPool_->exec([&](int tid, size_t numThreads) { + std::lock_guard guard(sparseAutoGrowthMutex_); const auto& localIndices = prefetchMat->getLocalIndices(); /// num of sparse rows size_t nLocalBlocks = localIndices.size(); uint64_t beginDim = 0; uint64_t endDim = 0; + + // FIXME(typhoonzero): let it resize first + prefetchMat->getLocalRow(nLocalBlocks + 1); + sendMat->getLocalRow(nLocalBlocks + 1); + for (size_t row = 0; row < nLocalBlocks; ++row) { int64_t blockId = localIndices[row]; // local row -> sparse row int serverId = std::abs((blockId + nameHash) % serviceNum_); @@ -275,7 +288,6 @@ void ParameterClient2::prepareSendData( block->set_begin_pos(row * blockSize); /// block len block->set_block_size(endDim - beginDim); - if (sendingPara) { sendJob->parallelInputIovs[serverId].push_back( {sendMat->getLocalRow(row), sizeof(real) * (size_t)blockSize}); diff --git a/paddle/pserver/ParameterClient2.h b/paddle/pserver/ParameterClient2.h index 89b3ddd502151e537b81bdbb09f171dd6e13ba26..29b9eeacddf2945dd22b7b17fc87c7c74b868896 100644 --- a/paddle/pserver/ParameterClient2.h +++ b/paddle/pserver/ParameterClient2.h @@ -583,6 +583,7 @@ protected: #ifndef PADDLE_DISABLE_TIMER uint64_t forwardbackwordTime_; #endif + std::mutex sparseAutoGrowthMutex_; /// map id to parameter used for decoding protobuf data std::unordered_map parameterMap_; diff --git a/paddle/pserver/ParameterServer2.cpp b/paddle/pserver/ParameterServer2.cpp index d7c1d4f788f44c6bfcec040ba24bdc454348c911..54f5c4c0fb4994871edc7a1e52237c9f903ce63b 100644 --- a/paddle/pserver/ParameterServer2.cpp +++ b/paddle/pserver/ParameterServer2.cpp @@ -1032,8 +1032,8 @@ void ParameterServer2::loadValueVector(const LoadValueRequest& request, Parameter::Header header; CHECK(fs.read(reinterpret_cast(&header), sizeof(header))) << "Fail to read parameters in pserver"; - CHECK_EQ(header.version, Parameter::kFormatVersion) - << "Incorrect format version: " << header.version; + CHECK(Parameter::isHeaderFormatSupported(header.format)) + << "Incorrect format version: " << header.format; CHECK_EQ(header.size, (size_t)size_) << "The size (" << header.size << ") in the file does not match the size " << "(" << size_ << ") of the pserver: " << serverId_; @@ -1063,7 +1063,8 @@ void ParameterServer2::saveValueVector(const SaveValueRequest& request, CpuVector& vec = vectors_[PARAMETER_APPLY] ? *vectors_[PARAMETER_APPLY] : *vectors_[PARAMETER_VALUE]; Parameter::Header header; - header.version = Parameter::kFormatVersion; + // TODO(TJ): save param headerFormat_ + header.format = PARAM_FORMAT_ORIGINAL; header.valueSize = sizeof(real); header.size = size_; diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 2f0205b7702b6d73b5348430f39166ec78f6c143..2941662f349baf57d1fe8188e88ce21d5de07750 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -82,10 +82,6 @@ EOF fi -# To build documentation, we need to run cmake again after installing -# PaddlePaddle. This awkwardness is due to -# https://github.com/PaddlePaddle/Paddle/issues/1854. It also -# describes a solution. if [[ ${WITH_DOC:-OFF} == "ON" ]]; then cat <> /paddle/build/Dockerfile </dev/null) - BASEDIR=$(dirname "$0") - pip install ${BASEDIR}/../opt/paddle/share/wheels/*-${PYTHON_PADDLE_VERSION}-*.whl - if [ $? -ne 0 ]; then - echo "pip install wheels failed. " - echo "Please use 'sudo paddle' at the first time you use PaddlePaddle" - echo "PaddlePaddle will install some python dependencies automatically." - exit 1 - fi - echo "Python dependencies are installed." -fi case "$1" in "train") - ${DEBUGGER} $MYDIR/../opt/paddle/bin/paddle_trainer ${@:2} + ${DEBUGGER} $PADDLE_BIN_PATH/paddle_trainer ${@:2} ;; "merge_model") - ${DEBUGGER} $MYDIR/../opt/paddle/bin/paddle_merge_model ${@:2} + ${DEBUGGER} $PADDLE_BIN_PATH/paddle_merge_model ${@:2} ;; "pserver") - ${DEBUGGER} $MYDIR/../opt/paddle/bin/paddle_pserver_main ${@:2} + ${DEBUGGER} $PADDLE_BIN_PATH/paddle_pserver_main ${@:2} ;; "dump_config") python -m paddle.utils.dump_config ${@:2} @@ -127,7 +110,7 @@ case "$1" in python -m paddle.utils.make_model_diagram ${@:2} ;; "usage") - $MYDIR/../opt/paddle/bin/paddle_usage ${@:2} + $PADDLE_BIN_PATH/paddle_usage ${@:2} ;; "version") version diff --git a/paddle/trainer/TrainerConfigHelper.cpp b/paddle/trainer/TrainerConfigHelper.cpp index eba40862b926cfe863c569e73a6a3ceabcf1f3b4..a0a365aa0bb0ac26939a02c1cd626d0c17c6a9fe 100644 --- a/paddle/trainer/TrainerConfigHelper.cpp +++ b/paddle/trainer/TrainerConfigHelper.cpp @@ -29,7 +29,6 @@ DECLARE_bool(with_gpu); DECLARE_bool(parallel_nn); DECLARE_string(config_args); DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); const char *kConfigParserModuleName = "paddle.trainer.config_parser"; const char *kConfigParserFuncName = "parse_config_and_serialize"; @@ -47,7 +46,6 @@ TrainerConfigHelper::TrainerConfigHelper(const std::string &configFilePath) << ",with_cost=" << FLAGS_with_cost << ",use_gpu=" << FLAGS_use_gpu << ",parallel_nn=" << FLAGS_parallel_nn << ",use_mkldnn=" << FLAGS_use_mkldnn - << ",use_mkldnn_wgt=" << FLAGS_use_mkldnn_wgt << ",cudnn_version=" << hl_get_cudnn_lib_version(); if (!FLAGS_config_args.empty()) { configArgs << "," << FLAGS_config_args; diff --git a/paddle/utils/Flags.cpp b/paddle/utils/Flags.cpp index 600c83a8487191895de635dd8433f6c44e86ce77..ab1c181c62cdbee8cc5f804ec9aaf63ac5464ad6 100644 --- a/paddle/utils/Flags.cpp +++ b/paddle/utils/Flags.cpp @@ -27,7 +27,6 @@ DEFINE_bool(use_mkldnn, false, "Default still keep use CPU training"); DEFINE_bool(use_mkldnn, false, "Only support CPU training"); #endif -DEFINE_bool(use_mkldnn_wgt, false, "Init weight from CPU weight"); DEFINE_bool(parallel_nn, false, "Whether to use multi-threads to calculate one neural network." diff --git a/paddle/utils/Flags.h b/paddle/utils/Flags.h index 0aca4c0ee036ee8490c0ceca7279df876dc21947..1832bb515ec85df3d7733e01b063a01ad6a3b282 100644 --- a/paddle/utils/Flags.h +++ b/paddle/utils/Flags.h @@ -41,4 +41,3 @@ DECLARE_string(predict_file); DECLARE_bool(prev_batch_state); DECLARE_string(init_model_path); DECLARE_bool(use_mkldnn); -DECLARE_bool(use_mkldnn_wgt); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 16c519d45aa62694201379b8da1ca54d8a07ee9a..7bd6d59b0096c23bb791b9b50702130057628879 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -21,6 +21,18 @@ if(WITH_GOLANG) add_dependencies(copy_paddle_master paddle_master) endif(WITH_GOLANG) +set(MKL_SHARED_LIBS "") +set(MKL_DEPENDS "") +if(WITH_MKLML) + list(APPEND MKL_SHARED_LIBS ${MKLML_LIB} ${MKLML_IOMP_LIB}) + list(APPEND MKL_DEPENDS mklml) +endif() + +if(WITH_MKLDNN) + list(APPEND MKL_SHARED_LIBS "${MKLDNN_LIB}" "${MKLDNN_LIB}.0") + list(APPEND MKL_DEPENDS mkldnn) +endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ${CMAKE_CURRENT_BINARY_DIR}/setup.py) @@ -38,8 +50,11 @@ add_custom_command(OUTPUT ${PADDLE_PYTHON_BUILD_DIR}/.timestamp COMMAND ${CMAKE_COMMAND} -E copy_directory ${PADDLE_PYTHON_BUILD_DIR}/lib* ${PADDLE_PYTHON_BUILD_DIR}/lib-python DEPENDS gen_proto_py copy_paddle_pybind framework_py_proto ${PY_FILES} ${external_project_dependencies} ${COPY_PADDLE_MASTER}) -add_custom_target(paddle_python ALL DEPENDS - ${PADDLE_PYTHON_BUILD_DIR}/.timestamp paddle_pserver_main paddle_trainer paddle_merge_model python_api_wheel) +set(paddle_python_deps ${PADDLE_PYTHON_BUILD_DIR}/.timestamp paddle_pserver_main paddle_trainer paddle_merge_model ${MKL_DEPENDS}) +if(WITH_SWIG_PY) + list(APPEND paddle_python_deps python_api_wheel) +endif() +add_custom_target(paddle_python ALL DEPENDS ${paddle_python_deps}) set(PADDLE_PYTHON_PACKAGE_DIR ${CMAKE_CURRENT_BINARY_DIR}/dist/) diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index a24299787bfd6d9d1a9b01ba3117c3ec863f9552..7707ece819c9e684e13730e21c8d8c64649e2710 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -2247,6 +2247,20 @@ class ClipLayer(LayerBase): self.config.inputs[0].clip_conf.max = max +@config_layer('scale_shift') +class ScaleShiftLayer(LayerBase): + def __init__(self, name, inputs, bias=True, **xargs): + super(ScaleShiftLayer, self).__init__( + name, 'scale_shift', 0, inputs=inputs, **xargs) + config_assert( + len(self.inputs) == 1, + 'ScaleShiftLayer must have one and only one input.') + input_layer = self.get_input_layer(0) + self.set_layer_size(input_layer.size) + self.create_input_parameter(0, 1, [1, 1]) + self.create_bias_parameter(bias, 1) + + # key: cost type # value: cost class g_cost_map = {} diff --git a/python/paddle/trainer_config_helpers/evaluators.py b/python/paddle/trainer_config_helpers/evaluators.py index 44d52edfa7bae49bea196eba9387391b171840d8..57979db4de08989ab583b0ab41589c09789a0921 100644 --- a/python/paddle/trainer_config_helpers/evaluators.py +++ b/python/paddle/trainer_config_helpers/evaluators.py @@ -298,8 +298,8 @@ def pnpair_evaluator( input, label, info, - name=None, - weight=None, ): + weight=None, + name=None, ): """ Positive-negative pair rate Evaluator which adapts to rank task like learning to rank. This evaluator must contain at least three layers. @@ -308,27 +308,31 @@ def pnpair_evaluator( .. code-block:: python - eval = pnpair_evaluator(input, info, label) + eval = pnpair_evaluator(input, label, info) - :param name: Evaluator name. - :type name: None|basestring :param input: Input Layer name. The output prediction of network. :type input: LayerOutput :param label: Label layer name. :type label: LayerOutput - :param info: Label layer name. (TODO, explaination) + :param info: Info layer name. (TODO, explaination) :type info: LayerOutput :param weight: Weight Layer name. It should be a matrix with size [sample_num, 1]. (TODO, explaination) :type weight: LayerOutput + :param name: Evaluator name. + :type name: None|basestring """ + if not isinstance(input, list): + input = [input] + if label: + input.append(label) + if info: + input.append(info) evaluator_base( - name=name, - type="pnpair", input=input, - label=label, - info=info, - weight=weight) + type="pnpair", + weight=weight, + name=name, ) @evaluator(EvaluatorAttribute.FOR_CLASSIFICATION) @@ -429,12 +433,12 @@ def chunk_evaluator( .. code-block:: text - Scheme Description + Scheme Description plain Use the same label for the whole chunk. - IOB Two labels for chunk type X, B-X for chunk begining and I-X for chunk inside. + IOB Two labels for chunk type X, B-X for chunk begining and I-X for chunk inside. IOE Two labels for chunk type X, E-X for chunk ending and I-X for chunk inside. - IOBES Four labels for chunk type X, B-X for chunk begining, I-X for chunk inside, E-X for chunk end and S-X for single word chunk. - + IOBES Four labels for chunk type X, B-X for chunk begining, I-X for chunk inside, E-X for chunk end and S-X for single word chunk. + To make it clear, let's illustrate by an NER example. Assuming that there are three named entity types including ORG, PER and LOC which are called 'chunk type' here, if 'IOB' scheme were used, the label set will be extended to a set including B-ORG, I-ORG, B-PER, I-PER, B-LOC, I-LOC and O, @@ -451,7 +455,7 @@ def chunk_evaluator( tagType = label % numTagType chunkType = label / numTagType otherChunkType = numChunkTypes - + The following table shows the mapping rule between tagType and tag type in each scheme. .. code-block:: text @@ -475,7 +479,7 @@ def chunk_evaluator( O 6 In this example, chunkType has three values: 0 for ORG, 1 for PER, 2 for LOC, because the scheme is - "IOB" so tagType has two values: 0 for B and 1 for I. + "IOB" so tagType has two values: 0 for B and 1 for I. Here we will use I-LOC to explain the above mapping rules in detail. For I-LOC, the label id is 5, so we can get tagType=1 and chunkType=2, which means I-LOC is a part of NER chunk LOC and the tag is I. @@ -486,7 +490,7 @@ def chunk_evaluator( eval = chunk_evaluator(input, label, chunk_scheme, num_chunk_types) - + :param input: The input layers. :type input: LayerOutput :param label: An input layer containing the ground truth label. diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 2b01b6ad4d79031aa16a583937eb8444d91cbf3a..b027f84b5d576103b6e03ef6709a6c1f335aabe2 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -133,6 +133,7 @@ __all__ = [ 'clip_layer', 'slice_projection', 'kmax_sequence_score_layer', + 'scale_shift_layer', ] @@ -231,6 +232,7 @@ class LayerType(object): CLIP_LAYER = 'clip' KMAX_SEQ_SCORE = 'kmax_seq_score' + SCALE_SHIFT_LAYER = 'scale_shift' @staticmethod def is_layer_type(type_name): @@ -6238,3 +6240,43 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1): return LayerOutput( name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size) + + +@wrap_name_default("scale_shift") +@wrap_param_attr_default() +@wrap_bias_attr_default() +def scale_shift_layer(input, name=None, param_attr=None, bias_attr=None): + """ + A layer applies a linear transformation to each element in each row of + the input matrix. For each element, the layer first re-scale it and then + adds a bias to it. + + This layer is very like the SlopeInterceptLayer, except the scale and + bias are trainable. + + .. math:: + + y = w * x + b + + .. code-block:: python + + scale_shift = scale_shift_layer(input=input_layer, bias_attr=False) + + :param name: The Layer Name. + :type name: basestring + :param input: The input layer. + :type input: LayerOutput. + :param param_attr: The parameter attribute of scaling. + :type param_attr: ParameterAttribute + :param bias_attr: The parameter attribute of shifting. + :type bias_attr: ParameterAttribute + :return: LayerOutput object. + :rtype: LayerOutput + """ + Layer( + name=name, + type=LayerType.SCALE_SHIFT_LAYER, + inputs=Input(input.name, **param_attr.attr), + bias=ParamAttr.to_bias(bias_attr)) + return LayerOutput( + name, LayerType.SCALE_SHIFT_LAYER, parents=[input], size=input.size) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 130e6332a7cf58d0fe54dddcaf05eedd161fd112..76e89fa7058e4d6b8cf8056c4419bb739ebbfc00 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -8,6 +8,7 @@ test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer -test_kmax_seq_socre_layer test_seq_select_layers test_cross_entropy_over_beam) +test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer +test_cross_entropy_over_beam) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..35ade126a2586a8e3eee6f0ac3c7e49523c8f5c5 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_scale_shift_layer.protostr @@ -0,0 +1,72 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 100 + active_type: "" +} +layers { + name: "__scale_shift_0__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_0__.w0" + } +} +layers { + name: "__scale_shift_1__" + type: "scale_shift" + size: 100 + active_type: "" + inputs { + input_layer_name: "data" + input_parameter_name: "___scale_shift_1__.w0" + } + bias_parameter_name: "___scale_shift_1__.wbias" +} +parameters { + name: "___scale_shift_0__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_1__.w0" + size: 1 + initial_mean: 0.0 + initial_std: 1.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: true +} +parameters { + name: "___scale_shift_1__.wbias" + size: 1 + initial_mean: 0.0 + initial_std: 0.0 + dims: 1 + dims: 1 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "data" +output_layer_names: "__scale_shift_0__" +output_layer_names: "__scale_shift_1__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__scale_shift_0__" + layer_names: "__scale_shift_1__" + input_layer_names: "data" + output_layer_names: "__scale_shift_0__" + output_layer_names: "__scale_shift_1__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..dd589116fa9932144ca066d3fa4c929d1433a7f1 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_scale_shift_layer.py @@ -0,0 +1,9 @@ +from paddle.trainer_config_helpers import * + +data = data_layer(name='data', size=100) + +scale = scale_shift_layer(input=data, bias_attr=False) + +scale_shift = scale_shift_layer(input=data) + +outputs(scale, scale_shift) diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py index 7fd8b55a5d167294d3270c79f7b64da03443afd3..6ac656321e72f5b0c91008091753ee50ac8200a6 100644 --- a/python/paddle/v2/framework/op.py +++ b/python/paddle/v2/framework/op.py @@ -1,7 +1,5 @@ import paddle.v2.framework.core as core -import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 -import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 -import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2 +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 def get_all_op_protos(): @@ -12,22 +10,26 @@ def get_all_op_protos(): protostrs = core.get_all_op_protos() ret_values = [] for pbstr in protostrs: - op_proto = op_proto_pb2.OpProto.FromString(str(pbstr)) + op_proto = framework_pb2.OpProto.FromString(str(pbstr)) ret_values.append(op_proto) return ret_values +def is_str(s): + return isinstance(s, str) or isinstance(s, unicode) + + class OpDescCreationMethod(object): """ A Functor object to convert user input(use key word args) to OpDesc based on OpProto. - + :param op_proto: The OpProto object. :type op_proto: op_proto_pb2.OpProto """ def __init__(self, op_proto): - if not isinstance(op_proto, op_proto_pb2.OpProto): + if not isinstance(op_proto, framework_pb2.OpProto): raise TypeError("Argument should be OpProto") self.__op_proto__ = op_proto @@ -39,26 +41,34 @@ class OpDescCreationMethod(object): """ if len(args) != 0: raise ValueError("Only keyword arguments is supported by Paddle") - op_desc = op_desc_pb2.OpDesc() - - # Inputs - ipts, ipt_format, _ = OpDescCreationMethod.extract_input_or_output( - "input", kwargs, self.__op_proto__.inputs) - op_desc.inputs.extend(ipts) - if ipt_format is not None: - op_desc.attrs.extend([ipt_format]) - - # Outputs - outs, out_format, tmp_index = OpDescCreationMethod.extract_input_or_output( - "output", kwargs, self.__op_proto__.outputs) - op_desc.outputs.extend(outs) - if out_format is not None: - op_desc.attrs.extend([out_format]) - if len(tmp_index) != 0: - tmp_index_attr = op_desc.attrs.add() - tmp_index_attr.type = attribute_pb2.INTS - tmp_index_attr.name = "temporary_index" - tmp_index_attr.ints.extend(tmp_index) + op_desc = framework_pb2.OpDesc() + + for input_parameter in self.__op_proto__.inputs: + input_arguments = kwargs.get(input_parameter.name, []) + if is_str(input_arguments): + input_arguments = [input_arguments] + + if not input_parameter.duplicable and len(input_arguments) > 1: + raise ValueError("Input %s only accepts one input, but give %d" + % (input_parameter.name, len(input_arguments))) + + ipt = op_desc.inputs.add() + ipt.parameter = input_parameter.name + ipt.arguments.extend(input_arguments) + + for output_parameter in self.__op_proto__.outputs: + output_arguments = kwargs.get(output_parameter.name, []) + if is_str(output_arguments): + output_arguments = [output_arguments] + + if not output_parameter.duplicable and len(output_arguments) > 1: + raise ValueError( + "Output %s only accepts one output, but give %d" % + (output_parameter.name, len(output_arguments))) + + out = op_desc.outputs.add() + out.parameter = output_parameter.name + out.arguments.extend(output_arguments) # Types op_desc.type = self.__op_proto__.type @@ -72,17 +82,17 @@ class OpDescCreationMethod(object): new_attr = op_desc.attrs.add() new_attr.name = attr.name new_attr.type = attr.type - if attr.type == attribute_pb2.INT: + if attr.type == framework_pb2.INT: new_attr.i = user_defined_attr - elif attr.type == attribute_pb2.FLOAT: + elif attr.type == framework_pb2.FLOAT: new_attr.f = user_defined_attr - elif attr.type == attribute_pb2.STRING: + elif attr.type == framework_pb2.STRING: new_attr.s = user_defined_attr - elif attr.type == attribute_pb2.INTS: + elif attr.type == framework_pb2.INTS: new_attr.ints.extend(user_defined_attr) - elif attr.type == attribute_pb2.FLOATS: + elif attr.type == framework_pb2.FLOATS: new_attr.floats.extend(user_defined_attr) - elif attr.type == attribute_pb2.STRINGS: + elif attr.type == framework_pb2.STRINGS: new_attr.strings.extend(user_defined_attr) else: raise NotImplementedError("Not support attribute type " + @@ -90,50 +100,6 @@ class OpDescCreationMethod(object): return op_desc - @staticmethod - def extract_input_or_output(in_out, kwargs, meta): - """ - Extract input variable names or output variable names from key-word - arguments, which base on VarProtos. - - :param in_out: "input" or "output" - :param kwargs: key-word arguments that user inputted. - :param meta: a list of VarProto - :return: The three object will be return. The variable names. The - input_format or output_format attribute(None if the input or output is - not multiple). The temporary variable index list. - """ - multiple = OpDescCreationMethod.any_is_true((m.multiple for m in meta)) - tmp_index = [] - retv = [] - if multiple: - var_format = op_desc_pb2.AttrDesc() - var_format.type = attribute_pb2.INTS - var_format.name = "%s_format" % in_out - var_format.ints.append(0) - - for var in meta: - var_name = var.name - - if var.temporary: - var_name = [core.var_names.temp()] - tmp_index.append(len(retv)) - else: - var_name = kwargs.get(var_name, []) - if not isinstance(var_name, list): - var_name = [var_name] - retv.extend(var_name) - var_format.ints.append(len(var_name) + var_format.ints[-1]) - return retv, var_format, tmp_index - else: - for var in meta: - if var.temporary: - retv.append(kwargs.get(var.name, core.var_names.temp())) - tmp_index.append(len(retv)) - else: - retv.append(kwargs.get(var.name, core.var_names.empty())) - return retv, None, tmp_index - @staticmethod def any_is_true(generator): """ @@ -146,13 +112,12 @@ class OpDescCreationMethod(object): class OpInfo(object): - def __init__(self, name, method, inputs, outputs, attrs, no_temp_outputs): + def __init__(self, name, method, inputs, outputs, attrs): self.name = name self.method = method self.inputs = inputs self.outputs = outputs self.attrs = attrs - self.no_temp_outputs = no_temp_outputs def create_op_creation_method(op_proto): @@ -170,10 +135,7 @@ def create_op_creation_method(op_proto): name=op_proto.type, inputs=[var.name for var in op_proto.inputs], outputs=[var.name for var in op_proto.outputs], - attrs=[attr.name for attr in op_proto.attrs], - no_temp_outputs=[ - var.name for var in op_proto.outputs if not var.temporary - ]) + attrs=[attr.name for attr in op_proto.attrs]) class OperatorFactory(object): @@ -214,8 +176,27 @@ class OperatorFactory(object): def get_op_attr_names(self, type): return self.get_op_info(type).attrs - def get_op_no_temp_output_names(self, type): - return self.get_op_info(type).no_temp_outputs + +class __RecurrentOp__(object): + __proto__ = None + type = 'recurrent_op' + + def __init__(self): + # cache recurrent_op's proto + if self.__proto__ is None: + for op_proto in get_all_op_protos(): + if op_proto.type == self.type: + self.__proto__ = op_proto + + def __call__(self, *args, **kwargs): + if self.type not in args and 'type' not in kwargs: + kwargs['type'] = self.type + # create proto + create_method = OpDescCreationMethod(self.__proto__) + proto = create_method(*args, **kwargs) + # create rnnop + return core.RecurrentOp.create(proto.SerializeToString()) Operator = OperatorFactory() # Default global factory +RecurrentOp = __RecurrentOp__() diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 96fad9b42e04a88fdcbda093683b57451b2a3e41..b07a65f4d1fed12d82c638ee59f9de72379cfcbe 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -22,6 +22,8 @@ py_test(test_rowwise_add_op SRCS test_rowwise_add_op.py) py_test(test_default_scope_funcs SRCS test_default_scope_funcs.py) py_test(test_operator SRCS test_operator.py) -# py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) +py_test(test_gaussian_random_op SRCS test_gaussian_random_op.py) py_test(test_uniform_random_op SRCS test_uniform_random_op.py) py_test(test_recurrent_op SRCS test_recurrent_op.py) +py_test(test_sgd_op SRCS test_sgd_op.py) +py_test(test_gradient_checker SRCS test_gradient_checker.py) diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py index 015e832e82560bb8b3518cbdf605c705d77cdd99..8b8e2f444be1169c23784321721c5d8154541fcf 100644 --- a/python/paddle/v2/framework/tests/gradient_checker.py +++ b/python/paddle/v2/framework/tests/gradient_checker.py @@ -1,6 +1,7 @@ import unittest import numpy +import itertools import paddle.v2.framework.core as core from paddle.v2.framework.op import Operator @@ -8,6 +9,7 @@ __all__ = ['get_numeric_gradient'] def create_op(op_type): + # TODO need to set attrs kwargs = dict() for in_name in Operator.get_op_input_names(op_type): kwargs[in_name] = in_name @@ -53,17 +55,19 @@ def get_numeric_gradient(op, tensor.set(input_values[var_name], core.CPUPlace()) # Create all output variable in local_scope - for output in op.outputs(): - if local_scope.find_var(output) is None: - local_scope.new_var(output).get_tensor() - + opts = op.outputs() + for key in opts: + for output in opts[key]: + if local_scope.find_var(output) is None: + local_scope.new_var(output).get_tensor() op.infer_shape(local_scope) # allocate output memory - for output in op.outputs(): - local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace()) + for key in opts: + for output in opts[key]: + local_scope.find_var(output).get_tensor().alloc_float(core.CPUPlace( + )) - # TODO(yuyang18): Only CPU is support now. cpu_ctx = core.DeviceContext.create(core.CPUPlace()) def get_output(): @@ -106,12 +110,110 @@ def get_numeric_gradient(op, class GradientChecker(unittest.TestCase): - def assert_is_close(self, numeric_grads, scope, max_relative_error, - msg_prefix): - for name in numeric_grads: - b = numpy.array(scope.find_var(grad_var_name(name)).get_tensor()) - a = numeric_grads[name] - + def __get_gradient(self, forward_op, backward_op, input_value, grad_names, + place): + """Get the input gradients after running forward and backward operators + on the given places. + + :param forward_op: forward operator + :type forward_op: Operator + :param backward_op: backward operator + :type backward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :param grad_names: the names of returned input gradients. + :type input_value: a list of string + :param place: the device type. + :type place: CPUPlace or GPUPlace + :return: the input grdients of given grad_names. + :rtype: a list of numpy.array + """ + scope = core.Scope() + ctx = core.DeviceContext.create(place) + + inputs = forward_op.inputs() + in_names = [item for k in inputs for item in inputs[k]] + outputs = forward_op.outputs() + out_names = [item for k in outputs for item in outputs[k]] + + # create input var and set value + for name, value in input_value.iteritems(): + if name not in in_names: + raise ValueError(name + "does not exist in Op's inputs.") + var = scope.new_var(name).get_tensor() + var.set_dims(value.shape) + var.set(value, place) + + # run forward op + for out_name in out_names: + scope.new_var(out_name) + forward_op.infer_shape(scope) + forward_op.run(scope, ctx) + + # set output var's shape + # set output grad to ones + for name in out_names: + out_tensor = scope.find_var(name).get_tensor() + grad_tensor = scope.new_var(grad_var_name(name)).get_tensor() + grad_tensor.set_dims(out_tensor.shape()) + data = numpy.ones(out_tensor.shape(), dtype=numpy.float32) + grad_tensor.set(data, place) + + # run backward op + for name in backward_op.outputs(): + scope.new_var(name) + backward_op.infer_shape(scope) + backward_op.run(scope, ctx) + + outs = [ + numpy.array(scope.find_var(name).get_tensor()) + for name in grad_names + ] + return outs + + def compare_grad(self, forward_op, input_value): + """ Compare the input gradients between CPU and GPU for the given forward + operator. + + :param forward_op: forward operator + :type forward_op: Operator + :param input_value: input values. + :type input_value: dict{string:numpy.array} + :raises: AssertionError, there is different gradient value. + """ + backward_op = core.Operator.backward(forward_op, set()) + # return if not compile with GPU or not implementing GPU kernel + if not (core.is_compile_gpu() and backward_op.support_gpu()): + return + + outputs = backward_op.outputs() + out_names = [item for k in outputs for item in outputs[k]] + cpu_grads = self.__get_gradient(forward_op, backward_op, input_value, + out_names, core.CPUPlace()) + gpu_grads = self.__get_gradient(forward_op, backward_op, input_value, + out_names, core.GPUPlace(0)) + + for c_grad, g_grad, name in itertools.izip(cpu_grads, gpu_grads, + out_names): + self.assertTrue( + numpy.allclose( + c_grad, g_grad, atol=1e-4), + "output name: " + name + " has diff") + + def __assert_is_close(self, numeric_grads, analytic_grads, names, + max_relative_error, msg_prefix): + """Use relative error for the comparison. + + :param numeric_grads: the numerical graidents. + :type numeric_grads: a list of numpy.array + :param analytic_grads: the analytical graidents. + :type analytic_grads: a list of numpy.array + :param name: the names of gradients, used to print for debug. + :type names: a list of string + :param msg_prefix: string info, used to print for debug. + :type msf_prefix: string + """ + for a, b, name in itertools.izip(numeric_grads, analytic_grads, names): abs_a = numpy.abs(a) # if abs_a is nearly zero, then use abs error for a, not relative # error. @@ -150,107 +252,32 @@ class GradientChecker(unittest.TestCase): if no_grad_set is None: no_grad_set = set() - tmp_outs = forward_op.temp_outputs() - no_tmp_out = filter(lambda name: name not in tmp_outs, - forward_op.outputs()) + no_tmp_out = forward_op.no_intermediate_outputs() if len(no_tmp_out) != 1: raise ValueError("non temp out_names should be 1") - in_names = forward_op.inputs() + inputs = forward_op.inputs() + in_names = [item for k in inputs for item in inputs[k]] for no_grad in no_grad_set: if no_grad not in in_names: raise ValueError("no_grad should be in in_names") - backward_op = core.Operator.backward(forward_op, no_grad_set) places = [core.CPUPlace()] if not only_cpu and core.is_compile_gpu() and backward_op.support_gpu(): places.append(core.GPUPlace(0)) - numeric_grad = dict() - # get numeric gradient - for check_name in inputs_to_check: - numeric_grad[check_name] = \ - get_numeric_gradient(forward_op, input_vars, output_name, - check_name) + # get numerical gradients + numeric_grads = [ + get_numeric_gradient(forward_op, input_vars, output_name, name) + for name in inputs_to_check + ] - # get operator gradient according to different device + check_names = [grad_var_name(name) for name in inputs_to_check] for place in places: - scope = core.Scope() - ctx = core.DeviceContext.create(place) - - # create input var and set value - for name, value in input_vars.iteritems(): - if name not in in_names: - raise ValueError(name + " not in op.inputs_") - var = scope.new_var(name).get_tensor() - var.set_dims(value.shape) - var.set(value, place) - - # create output var - for out_name in forward_op.outputs(): - scope.new_var(out_name).get_tensor() - - # infer the shape of output var and compute/set value of output var - forward_op.infer_shape(scope) - forward_op.run(scope, ctx) - - # create output grad var - # set shape as the output var - # set value of this grad to ones - for name in forward_op.outputs(): - out_tensor = scope.find_var(name).get_tensor() - grad_tensor = scope.new_var(grad_var_name(name)).get_tensor() - grad_tensor.set_dims(out_tensor.shape()) - data = 1.0 * numpy.ones(out_tensor.shape()) - grad_tensor.set(data, place) - - # create input grad var - for name in backward_op.outputs(): - scope.new_var(name).get_tensor() - - # infer the shape of input gradient var and compute/set it's value - # with backward op - backward_op.infer_shape(scope) - backward_op.run(scope, ctx) - - self.assert_is_close(numeric_grad, scope, max_relative_error, - "Gradient Check On %s" % str(place)) - - -if __name__ == '__main__': - - class GetNumericGradientTest(unittest.TestCase): - def test_add_op(self): - add_op = Operator('add_two', X="X", Y="Y", Out="Z") - x = numpy.random.random((10, 1)).astype("float32") - y = numpy.random.random((10, 1)).astype("float32") - - arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') - self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-2) - - def test_softmax_op(self): - def stable_softmax(x): - """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - numpy.max(x) - exps = numpy.exp(shiftx) - return exps / numpy.sum(exps) - - def label_softmax_grad(Y, dY): - dX = Y * 0.0 - for i in range(Y.shape[0]): - d = numpy.dot(Y[i, :], dY[i, :]) - dX[i, :] = Y[i, :] * (dY[i, :] - d) - return dX - - softmax_op = Operator("softmax", X="X", Y="Y") - - X = numpy.random.random((2, 2)).astype("float32") - Y = numpy.apply_along_axis(stable_softmax, 1, X) - dY = numpy.ones(Y.shape) - dX = label_softmax_grad(Y, dY) - - arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X') - numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) - - unittest.main() + # get analytical gradients according to different device + analytic_grads = self.__get_gradient(forward_op, backward_op, + input_vars, check_names, place) + self.__assert_is_close(numeric_grads, analytic_grads, check_names, + max_relative_error, + "Gradient Check On %s" % str(place)) diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index dd65e0f2dc23d3f657ff16c55fb297dae210b2d7..3bc05a0feccbbd3d5e7852d85bd3dc8edaccfd07 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ b/python/paddle/v2/framework/tests/op_test_util.py @@ -64,7 +64,8 @@ class OpTestMeta(type): actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] self.assertTrue( - numpy.allclose(actual, expect), + numpy.allclose( + actual, expect, atol=1e-05), "output name: " + out_name + "has diff") obj.test_all = test_all diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py index c0237830647371e14b755953345965a3eac7bfd2..0def484eddb88604398ee10390d3f28058714a57 100644 --- a/python/paddle/v2/framework/tests/test_add_two_op.py +++ b/python/paddle/v2/framework/tests/test_add_two_op.py @@ -19,14 +19,5 @@ class TestAddOp(unittest.TestCase): self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']} -class TestAddGradOp(unittest.TestCase): - def test_add_grad(self): - op = Operator('add_two', X="X", Y="Y", Out="Out") - backward_op = core.Operator.backward(op, set()) - self.assertEqual(backward_op.type(), "add_two_grad") - expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).''' - self.assertEqual(expected, str(backward_op)) - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index 4815192e255c6e0429db3f50918a76a773b30131..d4277f2a42ce2e66e37405ccd3b2ee444d403d1a 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -8,9 +8,8 @@ class TestCrossEntropy(unittest.TestCase): __metaclass__ = OpTestMeta def setUp(self): - # TODO this unit test is not passed self.type = "onehot_cross_entropy" - batch_size = 100 + batch_size = 30 class_num = 10 X = numpy.random.random((batch_size, class_num)).astype("float32") label = 5 * numpy.ones(batch_size).astype("int32") @@ -22,9 +21,9 @@ class TestCrossEntropy(unittest.TestCase): class CrossEntropyGradOpTest(GradientChecker): - def test_softmax_grad(self): + def test_check_grad(self): op = create_op("onehot_cross_entropy") - batch_size = 100 + batch_size = 30 class_num = 10 inputs = { "X": numpy.random.uniform( diff --git a/python/paddle/v2/framework/tests/test_gradient_checker.py b/python/paddle/v2/framework/tests/test_gradient_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b315120862bea284e067070492dcdfbb661081 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gradient_checker.py @@ -0,0 +1,43 @@ +import unittest +import numpy +from paddle.v2.framework.op import Operator +from gradient_checker import GradientChecker +from gradient_checker import get_numeric_gradient + + +class GetNumericGradientTest(unittest.TestCase): + def test_add_op(self): + add_op = Operator('add_two', X="X", Y="Y", Out="Z") + x = numpy.random.random((10, 1)).astype("float32") + y = numpy.random.random((10, 1)).astype("float32") + + arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X') + self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4) + + def test_softmax_op(self): + def stable_softmax(x): + """Compute the softmax of vector x in a numerically stable way.""" + shiftx = x - numpy.max(x) + exps = numpy.exp(shiftx) + return exps / numpy.sum(exps) + + def label_softmax_grad(Y, dY): + dX = Y * 0.0 + for i in range(Y.shape[0]): + d = numpy.dot(Y[i, :], dY[i, :]) + dX[i, :] = Y[i, :] * (dY[i, :] - d) + return dX + + softmax_op = Operator("softmax", X="X", Y="Y") + + X = numpy.random.random((2, 2)).astype("float32") + Y = numpy.apply_along_axis(stable_softmax, 1, X) + dY = numpy.ones(Y.shape) + dX = label_softmax_grad(Y, dY) + + arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X') + numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py index b5d52b90567bcd0c9f376147145d8638049f7bab..f32b3160d651a290823223c46c45bb3b6950a505 100644 --- a/python/paddle/v2/framework/tests/test_mean_op.py +++ b/python/paddle/v2/framework/tests/test_mean_op.py @@ -1,5 +1,6 @@ import unittest from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op import numpy as np @@ -12,5 +13,12 @@ class TestMeanOp(unittest.TestCase): self.outputs = {'Out': np.mean(self.inputs['X'])} +class MeanGradOpTest(GradientChecker): + def test_normal(self): + op = create_op("mean") + inputs = {"X": np.random.random((10, 10)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index ec0ac99156a546dd3fb7b27778032bece38ab5a9..ee0d81a64efcb81bae8b11b856c201a86da274e9 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta class TestMulOp(unittest.TestCase): @@ -15,5 +16,19 @@ class TestMulOp(unittest.TestCase): self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} +class MulGradOpTest(GradientChecker): + def test_mul(self): + op = create_op("mul") + inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + # mul op will enlarge the relative error + self.check_grad( + op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5) + + +# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/framework/tests/test_net.py index cc7f09e7155f5b1afa47fc4133b71ae3676b7436..9339cf28dabc95b46b958777200fb1db9dcf284f 100644 --- a/python/paddle/v2/framework/tests/test_net.py +++ b/python/paddle/v2/framework/tests/test_net.py @@ -6,8 +6,8 @@ import unittest def fc(X, W, Y): ret_v = core.Net.create() - ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation")) - ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y)) + ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation")) + ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y)) ret_v.complete_add_op(True) return ret_v @@ -16,21 +16,21 @@ class TestNet(unittest.TestCase): def test_net_all(self): net = core.Net.create() op1 = Operator("add_two", X="X", Y="Y", Out="Out") - net.add_op(op1) + net.append_op(op1) net2 = core.Net.create() - net2.add_op(fc(X="X", W="w", Y="fc.out")) + net2.append_op(fc(X="X", W="w", Y="fc.out")) net2.complete_add_op(True) - net.add_op(net2) + net.append_op(net2) net.complete_add_op(True) expected = ''' -Op(plain_net), inputs:(W, X, Y), outputs:(Out, fc.out, pre_activation). - Op(add_two), inputs:(X, Y), outputs:(Out). - Op(plain_net), inputs:(W, X), outputs:(fc.out, pre_activation). - Op(plain_net), inputs:(W, X), outputs:(fc.out, pre_activation). - Op(mul), inputs:(X, W), outputs:(pre_activation). - Op(sigmoid), inputs:(pre_activation), outputs:(fc.out). +Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}. + Op(add_two), inputs:{X[X], Y[Y]}, outputs:{Out[Out]}. + Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}. + Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}. + Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}. + Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Y[fc.out]}. ''' self.assertEqual(expected, "\n" + str(net)) diff --git a/python/paddle/v2/framework/tests/test_operator.py b/python/paddle/v2/framework/tests/test_operator.py index 4f164e1a69e3fd0409f9b575a8bd9b4e423b486b..1abc4eeb57bcedc81e34b0e156048ee4f5cfdc2d 100644 --- a/python/paddle/v2/framework/tests/test_operator.py +++ b/python/paddle/v2/framework/tests/test_operator.py @@ -1,9 +1,7 @@ import unittest import paddle.v2.framework.op as op import paddle.v2.framework.core as core -import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 -import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 -import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2 +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 class TestGetAllProtos(unittest.TestCase): @@ -17,7 +15,7 @@ class TestGetAllProtos(unittest.TestCase): class TestOpDescCreationMethod(unittest.TestCase): def test_plain_input_output(self): - op_proto = op_proto_pb2.OpProto() + op_proto = framework_pb2.OpProto() op_proto.type = "test" ipt = op_proto.inputs.add() ipt.name = "X" @@ -37,25 +35,32 @@ class TestOpDescCreationMethod(unittest.TestCase): method = op.OpDescCreationMethod(op_proto) output = method(X="a", Y="b", Z="c") - - expected = op_desc_pb2.OpDesc() + expected = framework_pb2.OpDesc() expected.type = "test" - expected.inputs.extend(["a", "b"]) - expected.outputs.append("c") + ipt_0 = expected.inputs.add() + ipt_0.parameter = "X" + ipt_0.arguments.extend(["a"]) + ipt_1 = expected.inputs.add() + ipt_1.parameter = 'Y' + ipt_1.arguments.extend(['b']) + opt = expected.outputs.add() + opt.parameter = "Z" + opt.arguments.extend(["c"]) + self.assertEqual(expected, output) def test_multiple_input_plain_output(self): - op_proto = op_proto_pb2.OpProto() + op_proto = framework_pb2.OpProto() op_proto.type = "fc" ipt = op_proto.inputs.add() ipt.name = "X" ipt.comment = "" - ipt.multiple = True + ipt.duplicable = True ipt = op_proto.inputs.add() ipt.name = "W" ipt.comment = "" - ipt.multiple = True + ipt.duplicable = True ipt = op_proto.inputs.add() ipt.name = "b" @@ -70,30 +75,50 @@ class TestOpDescCreationMethod(unittest.TestCase): method = op.OpDescCreationMethod(op_proto) generated1 = method(X="x", W="w", b="b", Y="y") - expected1 = op_desc_pb2.OpDesc() - expected1.inputs.extend(['x', 'w', 'b']) - expected1.outputs.extend(['y']) + expected1 = framework_pb2.OpDesc() + tmp = expected1.inputs.add() + tmp.parameter = "X" + tmp.arguments.extend(['x']) + + tmp = expected1.inputs.add() + tmp.parameter = 'W' + tmp.arguments.extend(['w']) + + tmp = expected1.inputs.add() + tmp.parameter = 'b' + tmp.arguments.extend(['b']) + + tmp = expected1.outputs.add() + tmp.parameter = 'Y' + tmp.arguments.extend(['y']) expected1.type = 'fc' - attr = expected1.attrs.add() - attr.name = 'input_format' - attr.type = attribute_pb2.INTS - attr.ints.extend([0, 1, 2, 3]) self.assertEqual(expected1, generated1) generated2 = method( X=['x1', 'x2', 'x3'], b='b', W=['w1', 'w2', 'w3'], Y='y') - expected2 = op_desc_pb2.OpDesc() - expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b']) - expected2.outputs.extend(['y']) + expected2 = framework_pb2.OpDesc() + + tmp = expected2.inputs.add() + tmp.parameter = "X" + tmp.arguments.extend(['x1', 'x2', 'x3']) + + tmp = expected2.inputs.add() + tmp.parameter = 'W' + tmp.arguments.extend(['w1', 'w2', 'w3']) + + tmp = expected2.inputs.add() + tmp.parameter = 'b' + tmp.arguments.extend(['b']) + + tmp = expected2.outputs.add() + tmp.parameter = 'Y' + tmp.arguments.extend(['y']) + expected2.type = 'fc' - attr = expected2.attrs.add() - attr.name = 'input_format' - attr.type = attribute_pb2.INTS - attr.ints.extend([0, 3, 6, 7]) self.assertEqual(expected2, generated2) def test_attrs(self): - op_proto = op_proto_pb2.OpProto() + op_proto = framework_pb2.OpProto() op_proto.type = "test" ipt = op_proto.inputs.add() ipt.name = 'X' @@ -105,12 +130,12 @@ class TestOpDescCreationMethod(unittest.TestCase): attr.comment = "" attr.type = type - __add_attr__("int_attr", attribute_pb2.INT) - __add_attr__("float_attr", attribute_pb2.FLOAT) - __add_attr__("string_attr", attribute_pb2.STRING) - __add_attr__("ints_attr", attribute_pb2.INTS) - __add_attr__("floats_attr", attribute_pb2.FLOATS) - __add_attr__("strings_attr", attribute_pb2.STRINGS) + __add_attr__("int_attr", framework_pb2.INT) + __add_attr__("float_attr", framework_pb2.FLOAT) + __add_attr__("string_attr", framework_pb2.STRING) + __add_attr__("ints_attr", framework_pb2.INTS) + __add_attr__("floats_attr", framework_pb2.FLOATS) + __add_attr__("strings_attr", framework_pb2.STRINGS) op_proto.comment = "" self.assertTrue(op_proto.IsInitialized()) @@ -126,76 +151,52 @@ class TestOpDescCreationMethod(unittest.TestCase): floats_attr=[0.2, 3.2, 4.5], strings_attr=["a", "b", "c"]) - expected = op_desc_pb2.OpDesc() + expected = framework_pb2.OpDesc() expected.type = "test" - expected.inputs.extend(['a']) + + ipt = expected.inputs.add() + ipt.parameter = "X" + ipt.arguments.extend(['a']) + attr = expected.attrs.add() attr.name = "int_attr" - attr.type = attribute_pb2.INT + attr.type = framework_pb2.INT attr.i = 10 attr = expected.attrs.add() attr.name = "float_attr" - attr.type = attribute_pb2.FLOAT + attr.type = framework_pb2.FLOAT attr.f = 3.2 attr = expected.attrs.add() attr.name = "string_attr" - attr.type = attribute_pb2.STRING + attr.type = framework_pb2.STRING attr.s = "test_str" attr = expected.attrs.add() attr.name = "ints_attr" - attr.type = attribute_pb2.INTS + attr.type = framework_pb2.INTS attr.ints.extend([0, 1, 2, 3, 4]) attr = expected.attrs.add() attr.name = 'floats_attr' - attr.type = attribute_pb2.FLOATS + attr.type = framework_pb2.FLOATS attr.floats.extend([0.2, 3.2, 4.5]) attr = expected.attrs.add() attr.name = 'strings_attr' - attr.type = attribute_pb2.STRINGS + attr.type = framework_pb2.STRINGS attr.strings.extend(['a', 'b', 'c']) self.assertEqual(expected, generated) - def test_input_temporary_output(self): - op_proto = op_proto_pb2.OpProto() - op_proto.type = "test" - out = op_proto.outputs.add() - out.name = "OUT" - out.comment = "" - - out = op_proto.outputs.add() - out.name = "TMP" - out.comment = "" - out.temporary = True - - out = op_proto.outputs.add() - out.name = "OUT2" - out.comment = "" - op_proto.comment = "" - - method = op.OpDescCreationMethod(op_proto) - generated = method(OUT="a", OUT2="b") - desc = op_desc_pb2.OpDesc() - desc.outputs.extend(["a", core.var_names.temp(), "b"]) - desc.type = "test" - attr = desc.attrs.add() - attr.name = "temporary_index" - attr.type = attribute_pb2.INTS - attr.ints.append(2) - self.assertEqual(generated, desc) - class TestOpCreations(unittest.TestCase): def test_all(self): add_op = op.Operator("add_two", X="a", Y="b", Out="z") self.assertIsNotNone(add_op) # Invoke C++ DebugString() - self.assertEqual('Op(add_two), inputs:(a, b), outputs:(z).', + self.assertEqual('Op(add_two), inputs:{X[a], Y[b]}, outputs:{Out[z]}.', str(add_op)) diff --git a/python/paddle/v2/framework/tests/test_protobuf.py b/python/paddle/v2/framework/tests/test_protobuf.py index 69e98e2f250a9df23b25e7e2043af29f87c996a0..848a396b3b6eec57d500b464780b64f339b09e94 100644 --- a/python/paddle/v2/framework/tests/test_protobuf.py +++ b/python/paddle/v2/framework/tests/test_protobuf.py @@ -1,11 +1,10 @@ -import paddle.v2.framework.proto.op_proto_pb2 as op_proto_lib -import paddle.v2.framework.proto.attribute_pb2 as attr_type_lib +import paddle.v2.framework.proto.framework_pb2 as framework_pb2 import unittest class TestFrameworkProto(unittest.TestCase): def test_all(self): - op_proto = op_proto_lib.OpProto() + op_proto = framework_pb2.OpProto() ipt0 = op_proto.inputs.add() ipt0.name = "a" ipt0.comment = "the input of cosine op" @@ -19,7 +18,7 @@ class TestFrameworkProto(unittest.TestCase): attr = op_proto.attrs.add() attr.name = "scale" attr.comment = "scale of cosine op" - attr.type = attr_type_lib.FLOAT + attr.type = framework_pb2.FLOAT op_proto.type = "cos" self.assertTrue(op_proto.IsInitialized()) diff --git a/python/paddle/v2/framework/tests/test_recurrent_op.py b/python/paddle/v2/framework/tests/test_recurrent_op.py index 0db66cc4e181fde10f161a323ea749fd84a5f963..d6000ab9f9d5b969f96128b183f48d49000c8a5e 100644 --- a/python/paddle/v2/framework/tests/test_recurrent_op.py +++ b/python/paddle/v2/framework/tests/test_recurrent_op.py @@ -2,7 +2,7 @@ import logging import paddle.v2.framework.core as core import unittest import numpy as np -from paddle.v2.framework.op import Operator +from paddle.v2.framework.op import Operator, RecurrentOp def py_sigmoid(x): @@ -98,11 +98,11 @@ class TestRecurrentOp(unittest.TestCase): def forward(self): self.scope = core.Scope() self.create_global_variables() + self.create_rnn_op() self.create_step_net() - rnn_op = self.create_rnn_op() ctx = core.DeviceContext.create(core.CPUPlace()) - rnn_op.infer_shape(self.scope) - rnn_op.run(self.scope, ctx) + self.rnnop.infer_shape(self.scope) + self.rnnop.run(self.scope, ctx) return np.array(self.scope.find_var("h").get_tensor()) def create_global_variables(self): @@ -128,8 +128,7 @@ class TestRecurrentOp(unittest.TestCase): def create_rnn_op(self): # create RNNOp - rnnop = Operator( - "recurrent_op", + self.rnnop = RecurrentOp( # inputs inlinks=["x"], boot_memories=["h_boot"], @@ -142,22 +141,18 @@ class TestRecurrentOp(unittest.TestCase): outlink_alias=["h@alias"], pre_memories=["h@pre"], memories=["h@alias"]) - return rnnop def create_step_net(self): - var = self.scope.new_var("stepnet") - stepnet = var.get_net() - - # x_fc_op = Operator("fc", X="x@alias", W="W", Y="Wx") - # h_fc_op = Operator("fc", X="h@pre", W="U", Y="Uh") + stepnet = core.Net.create() x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx") h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh") sum_op = Operator("add_two", X="Wx", Y="Uh", Out="sum") sig_op = Operator("sigmoid", X="sum", Y="h@alias") for op in [x_fc_op, h_fc_op, sum_op, sig_op]: - stepnet.add_op(op) + stepnet.append_op(op) stepnet.complete_add_op(True) + self.rnnop.set_stepnet(stepnet) def test_forward(self): print 'test recurrent op forward' diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index f8521eb517057fbeb104b28af7da4fffe54f37de..45d569da29d13cf8e2a3cb9d67c2d01e8b365453 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestRowwiseAddOp(unittest.TestCase): @@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} +class RowwiseAddGradOpTest(GradientChecker): + def test_rowwise_add(self): + op = create_op("rowwise_add") + inputs = { + "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"), + "b": np.random.uniform(0.1, 1, [10]).astype("float32") + } + self.check_grad(op, inputs, set(["X", "b"]), "Out") + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 2a57a41ed8b718fd420062ba68e853a4861b7359..273c2e5ab1a84d12621fe9568c4cf22073b6aed4 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestSigmoidOp(unittest.TestCase): @@ -8,12 +9,20 @@ class TestSigmoidOp(unittest.TestCase): def setUp(self): self.type = "sigmoid" - self.inputs = {'X': np.random.random((32, 100)).astype("float32")} + self.inputs = {'X': np.random.random((15, 31)).astype("float32")} self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} -#class TestSigmoidGradOp(unittest.TestCase): -#TODO(qingqing) add unit test +class TestSigmoidGradOp(GradientChecker): + def test_grad(self): + op = create_op("sigmoid") + inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")} + # compare gpu and cpu results for backward op. + # this test will be skiped if only compiling CPU version. + self.compare_grad(op, inputs) + # check gradients + self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index d0f18e4b6611fa56654e7f2a0144758339cb9e19..97e844b92c77a7c58105dc5df2b4092fa5571d6f 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -57,7 +57,7 @@ def text_file(path): return reader -def recordio_local(paths, buf_size=100): +def recordio(paths, buf_size=100): """ Creates a data reader from given RecordIO file paths separated by ",", glob pattern is supported. @@ -67,15 +67,19 @@ def recordio_local(paths, buf_size=100): import recordio as rec import paddle.v2.reader.decorator as dec + import cPickle as pickle def reader(): - a = ','.join(paths) - f = rec.reader(a) + if isinstance(paths, basestring): + path = paths + else: + path = ",".join(paths) + f = rec.reader(path) while True: r = f.read() if r is None: break - yield r + yield pickle.loads(r) f.close() return dec.buffered(reader, buf_size) diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index 359f3eeefbe8efeb343cc875c707c9251a7087fb..cf190aa6645f9a5bed891a3a47c03efa03813d65 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -34,5 +34,27 @@ class TestTextFile(unittest.TestCase): self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) +class TestRecordIO(unittest.TestCase): + def do_test(self, path): + reader = paddle.v2.reader.creator.recordio(path) + idx = 0 + for e in reader(): + if idx == 0: + self.assertEqual(e, (1, 2, 3)) + elif idx == 1: + self.assertEqual(e, (4, 5, 6)) + idx += 1 + self.assertEqual(idx, 2) + + def test_recordIO(self): + self.do_test( + os.path.join( + os.path.dirname(__file__), "test_reader_recordio.dat")) + self.do_test([ + os.path.join( + os.path.dirname(__file__), "test_reader_recordio.dat") + ]) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/reader/tests/test_reader_recordio.dat b/python/paddle/v2/reader/tests/test_reader_recordio.dat new file mode 100644 index 0000000000000000000000000000000000000000..a99a35bb829e066c4845d0b85b96cd1eb3a12491 Binary files /dev/null and b/python/paddle/v2/reader/tests/test_reader_recordio.dat differ diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 9c4dd5f25083d210bbd218a85d8dbb3cce2c3d0e..0654a301049dcb347b79879076a869a0c14a07ae 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -27,16 +27,24 @@ class SGD(object): SGD Trainer combines data reader, network topolopy and update_equation together to train/test a neural network. - :param update_equation: The optimizer object. - :type update_equation: paddle.v2.optimizer.Optimizer :param cost: Target cost that neural network should be optimized. :type cost: paddle.v2.config_base.Layer :param parameters: The parameters dictionary. :type parameters: paddle.v2.parameters.Parameters + :param update_equation: The optimizer object. + :type update_equation: paddle.v2.optimizer.Optimizer :param extra_layers: Some layers in the neural network graph are not in the path of cost layer. - :param pserver_spec: pserver location, eg: localhost:3000 :type extra_layers: paddle.v2.config_base.Layer + :param is_local: Whether trainning locally + :type is_local: bool + :param pserver_spec: comma string for pserver location, + eg:127.10.0.10:3000,127.10.0.11:3000, + and this parameter is only used for fault + tolerant mode cluster training. + :type pserver_spec: string + :param use_etcd: Whether using etcd pserver. + :param use_etcd: bool """ def __init__(self, diff --git a/python/requirements.txt b/python/requirements.txt index 3df822bd76d2a64a0a35f84b4ec309ce7150c221..e19453c25da1ec78773c00a72b8e517b0d798fff 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,7 +1,7 @@ requests==2.9.2 numpy>=1.12 protobuf==3.1 -recordio +recordio>=0.1.0 matplotlib rarfile scipy>=0.19.0 diff --git a/python/setup.py.in b/python/setup.py.in index 38728aa2fd77cf3c882479ed83e99688b9ffa541..87b3823e52604b889cdee76bc696a1ae9b9de802 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -23,6 +23,19 @@ with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: setup_requires+=["opencv-python"] +# the prefix is sys.prefix which should always be usr +paddle_bin_dir = 'opt/paddle/bin' +paddle_bins = ['${PADDLE_BINARY_DIR}/paddle/scripts/paddle_usage', + '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_trainer', + '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_merge_model', + '${PADDLE_BINARY_DIR}/paddle/pserver/paddle_pserver_main', + '${PADDLE_BINARY_DIR}/paddle/scripts/paddle'] + +paddle_rt_lib_dir = 'lib' +paddle_rt_libs = ['${WARPCTC_LIBRARIES}'] +if '${MKL_SHARED_LIBS}'!= '': + paddle_rt_libs += '${MKL_SHARED_LIBS}'.split(';') + setup(name='paddlepaddle', version='${PADDLE_VERSION}', description='Parallel Distributed Deep Learning', @@ -40,11 +53,7 @@ setup(name='paddlepaddle', 'paddle.v2.framework.proto': '${PADDLE_BINARY_DIR}/paddle/framework', 'py_paddle': '${PADDLE_SOURCE_DIR}/paddle/py_paddle' }, - scripts=['${PADDLE_BINARY_DIR}/paddle/scripts/paddle'], + scripts=paddle_bins, distclass=BinaryDistribution, - data_files=[('/usr/local/opt/paddle/bin', - ['${PADDLE_BINARY_DIR}/paddle/scripts/paddle_usage', - '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_trainer', - '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_merge_model', - '${PADDLE_BINARY_DIR}/paddle/pserver/paddle_pserver_main'])] + data_files=[(paddle_rt_lib_dir, paddle_rt_libs)] )