diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 980a97a07c996eca2e8c126a6ad5ab7f340fa1e5..2ca988c406ae2987e26ca37dbc17cc0a2af43743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,10 +17,14 @@ - id: detect-private-key files: (?!.*third_party)^.*$ | (?!.*book)^.*$ - id: end-of-file-fixer -- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29 +- repo: local hooks: - - id: clang-formater + - id: clang-format + name: clang-format + description: Format files with ClangFormat. + entry: clang-format -i + language: system + files: \.(c|cc|cxx|cpp|h|hpp|hxx)$ - repo: https://github.com/PaddlePaddle/pre-commit-golang sha: 8337620115c25ff8333f1b1a493bd031049bd7c0 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index c7d743e193e7d32dbc0b56f3bcb05b6c61f85f1d..b174831109372cb014741d63032fa6a470e74042 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,8 +36,8 @@ include(simd) ################################ Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) -option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) -option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) +option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND}) +option(WITH_MKLML "Compile PaddlePaddle with mklml package." 
${AVX_FOUND}) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) diff --git a/Dockerfile b/Dockerfile index 8cfb16928c95dcbfac08383d32562ff67933d873..5dd9b0be4f7e0a304108abfdfb089fea4faa4d38 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ RUN apt-get update && \ git python-pip python-dev openssh-server bison \ wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ - python-numpy python-matplotlib gcc g++ \ + python-numpy python-matplotlib gcc-4.8 g++-4.8 \ automake locales clang-format-3.8 swig doxygen cmake \ liblapack-dev liblapacke-dev libboost-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev \ diff --git a/README.md b/README.md index 2a6beeb342b34f8e91ef509d7d41f286a666480c..b9793c3eab5d40c28f01cc67ad607b97261b3235 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ We provide [English](http://doc.paddlepaddle.org/develop/doc/) and - [Deep Learning 101](http://book.paddlepaddle.org/index.html) - You might want to start from the this online interactive book that can run in Jupyter Notebook. + You might want to start from this online interactive book that can run in Jupyter Notebook. 
- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index eff15de73f23db6dea3a7b79006bfec90d712ae5..25c6b4ef52d3f8ebff1572ae8d348be7c577c08c 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -20,34 +20,30 @@ INCLUDE(ExternalProject) SET(MKLDNN_PROJECT "extern_mkldnn") SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn) -SET(MKLDNN_INSTALL_ROOT ${CMAKE_INSTALL_PREFIX}) -IF(NOT "$ENV{HOME}" STREQUAL "/root") - SET(MKLDNN_INSTALL_ROOT "$ENV{HOME}") -ENDIF() - -SET(MKLDNN_INSTALL_DIR "${MKLDNN_INSTALL_ROOT}/opt/paddle/third_party/mkldnn") -SET(MKLDNN_INCLUDE_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) +SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) +SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) -IF(WIN32) - MESSAGE(WARNING "It is not supported compiling with mkldnn in windows Paddle yet." - "Force WITH_MKLDNN=OFF") - SET(WITH_MKLDNN OFF) +IF(WIN32 OR APPLE) + MESSAGE(WARNING + "Windows or Mac is not supported with MKLDNN in Paddle yet." + "Force WITH_MKLDNN=OFF") + SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in Windows and MacOS" FORCE) return() -ELSE(WIN32) - SET(MKLDNN_LIBRARY "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE) - MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") - SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) - #SET(CMAKE_MACOSX_RPATH 1) # hold for MacOS - SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") -ENDIF(WIN32) +ENDIF() + +SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." 
FORCE) +MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") +SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") -INCLUDE_DIRECTORIES(${MKLDNN_INCLUDE_DIR}) +INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) IF(${CBLAS_PROVIDER} STREQUAL "MKLML") SET(MKLDNN_DEPENDS ${MKLML_PROJECT}) SET(MKLDNN_MKLROOT ${MKLML_ROOT}) SET(MKLDNN_IOMP_LIB ${MKLML_IOMP_LIB}) SET(MKLDNN_IOMP_DIR ${MKLML_LIB_DIR}) + MESSAGE(STATUS "Build MKLDNN with ${MKLDNN_MKLROOT}") ENDIF() ExternalProject_Add( @@ -57,16 +53,15 @@ ExternalProject_Add( GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" GIT_TAG "v0.9" PREFIX ${MKLDNN_SOURCES_DIR} - CONFIGURE_COMMAND mkdir -p /build - BUILD_COMMAND cd /build - && cmake .. -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} -DMKLROOT=${MKLDNN_MKLROOT} - && $(MAKE) - INSTALL_COMMAND cd /build && $(MAKE) install UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} + CMAKE_ARGS -DMKLROOT=${MKLDNN_MKLROOT} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} + -DMKLROOT:PATH=${MKLDNN_MKLROOT} ) ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIBRARY}) +SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB}) ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT}) -MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIBRARY}") +MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIB}") LIST(APPEND external_project_dependencies mkldnn) diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index 3f940756a4abb79aba7d3561db19db8532a0b673..17a1ca4ed04dce85ae3c7fdd5f22d6eeed03db59 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -16,19 +16,23 @@ IF(NOT ${WITH_MKLML}) return() ENDIF(NOT ${WITH_MKLML}) +IF(WIN32 OR APPLE) + MESSAGE(WARNING + "Windows or Mac is not supported with MKLML in Paddle yet." 
+ "Force WITH_MKLML=OFF") + SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE) + return() +ENDIF() + INCLUDE(ExternalProject) SET(MKLML_PROJECT "extern_mklml") -SET(MKLML_VER "mklml_lnx_2018.0.20170425") +SET(MKLML_VER "mklml_lnx_2018.0.20170720") SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz") SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml") SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}") -SET(MKLML_DST_DIR "opt/paddle/third_party/mklml") -SET(MKLML_INSTALL_ROOT "${CMAKE_INSTALL_PREFIX}") -IF(NOT "$ENV{HOME}" STREQUAL "/root") - SET(MKLML_INSTALL_ROOT "$ENV{HOME}") -ENDIF() - +SET(MKLML_DST_DIR "mklml") +SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER}) SET(MKLML_INC_DIR ${MKLML_ROOT}/include) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index ef31c252038ce18655913c0f41343fe6dc7dbb86..d00a9bb3a30cfb16623e073414088059481c3e1a 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -9,6 +9,11 @@ function(CheckCompilerCXX11Flag) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") endif() + # TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem. + # Use Debug mode instead for now. + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9) + set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE) + endif() elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # Apple Clang is a different compiler than upstream Clang which havs different version numbers. 
diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index ec7f1446cfb74842af7d0c7152bebf58619f3861..372272a53c12c314fc80eebbce5eae9fcabc55ba 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -104,6 +104,11 @@ cross_channel_norm ------------------ .. autoclass:: paddle.v2.layer.cross_channel_norm :noindex: + +row_l2_norm +----------- +.. autoclass:: paddle.v2.layer.row_l2_norm + :noindex: Recurrent Layers ================ @@ -320,6 +325,11 @@ scaling .. autoclass:: paddle.v2.layer.scaling :noindex: +clip +---- +.. autoclass:: paddle.v2.layer.clip + :noindex: + slope_intercept --------------- .. autoclass:: paddle.v2.layer.slope_intercept diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 4b06966fba2bc9f92756be0cb8110bbcd5272423..f8a88cf317aee6c5dd25e4cc25d588c6c50fcbce 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,7 +15,6 @@ if(Boost_FOUND) add_subdirectory(platform) add_subdirectory(framework) add_subdirectory(operators) - add_subdirectory(pybind) endif() if(WITH_C_API) diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index c53a5636829cab9d575f58cc2326cb3efe383e1c..7ad8a39768a064140a08c912a5a467bc24a12adf 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, real alpha = 1.0f; real beta = 1.0f; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; + + int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size; + if (batch_size > 1024 && g_cudnn_lib_version < 6000) { + LOG(INFO) << " To process current batch data with size " << batch_size + << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000." 
+ << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED," + << " just recompile PaddlePaddle with cuDNN >= 6000, replacing" + << " current version " << g_cudnn_lib_version; + } CHECK_CUDNN( dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, mode, diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 12a3a00bba35d476fca9c9fb47ac20b87e6f53f2..9c39430835d37d5dfbe4031f29e5a6216ed8b67f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -31,8 +31,14 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) -cc_library(net SRCS net.cc DEPS op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net) - -cc_library(backward SRCS backward.cc DEPS net) +cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward) +cc_library(paddle_pybind SHARED + SRCS pybind.cc + DEPS pybind python backward + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_op) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0da11b91a7fe4a98e0832f70095c3200956ff001..9730fdd18bcf2f5011657876811a98cc4cbca859 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -14,8 +14,8 @@ #include "paddle/framework/backward.h" #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace framework { @@ -32,7 +32,7 @@ static bool AllInSet(const std::vector& names, } static std::shared_ptr NOP() { - auto net_op = std::make_shared(); + auto net_op = std::make_shared(); net_op->type_ = "@NOP@"; net_op->CompleteAddOp(); return net_op; @@ -77,11 +77,11 @@ std::shared_ptr BackwardRecursive( } // Returned gradient network - auto net = 
std::make_shared(); + auto net = std::make_shared(); if (forwardOp.IsNetOp()) { // Because forwardOp is a net op, it can static_cast. - auto& forwardNet = static_cast(forwardOp); + auto& forwardNet = static_cast(forwardOp); // Map from output gradient variable name to operator's indices in backward // net. That operator generates that variable. @@ -168,6 +168,9 @@ std::shared_ptr Backward( std::unordered_set no_grad_names; no_grad_names.reserve(no_grad_vars.size()); + no_grad_names.insert(OperatorBase::EMPTY_VAR_NAME() + + OperatorBase::GRAD_VAR_SUFFIX()); + for (auto& name : no_grad_vars) { no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); } diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index b095c2c3d5dbf21b5ea70e17475a4aaad9b1db44..8adf7e4365d6d044e551c9e66101c7ae023e7cf8 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -15,8 +15,9 @@ #include "paddle/framework/backward.h" #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" +#include "paddle/operators/type_alias.h" namespace paddle { namespace framework { @@ -70,7 +71,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { } }; -class FcOp : public NetOp { +class FcOp : public ops::NetOp { public: void Init() override { AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, @@ -182,7 +183,8 @@ TEST(Backward, simple_op_not_need_grad) { auto no_input_gop = f::Backward(*fwd, {"X", "b"}); ASSERT_NE(no_input_gop, nullptr); ASSERT_TRUE(no_input_gop->IsNetOp()); - ASSERT_EQ(0UL, std::static_pointer_cast(no_input_gop)->ops_.size()); + ASSERT_EQ(0UL, + std::static_pointer_cast(no_input_gop)->ops_.size()); } TEST(Backward, net_fc_backward_normal) { @@ -191,7 +193,7 @@ TEST(Backward, net_fc_backward_normal) { ASSERT_NE(fwd, nullptr); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); - auto net = static_cast(gop.get()); + auto net = 
static_cast(gop.get()); ASSERT_NO_THROW(net->DebugString()); @@ -214,7 +216,7 @@ TEST(Backward, net_fc_backward_not_have_b) { ASSERT_NE(fwd, nullptr); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); - auto net = static_cast(gop.get()); + auto net = static_cast(gop.get()); ASSERT_NO_THROW(net->DebugString()); @@ -228,7 +230,7 @@ TEST(Backward, net_fc_backward_not_have_b) { } TEST(Backward, net_input_of_network_not_need_grad) { - f::NetOp net; + ops::NetOp net; net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"mul_tmp_0", "add_tmp_0", "hidden0"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, @@ -236,7 +238,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { net.CompleteAddOp(); auto bwd = Backward(net, {"X"}); // X@GRAD is not need. ASSERT_TRUE(bwd->IsNetOp()); - auto bwd_net = static_cast(bwd.get()); + auto bwd_net = static_cast(bwd.get()); std::unordered_set all_output = std::unordered_set( bwd_net->outputs_.begin(), bwd_net->outputs_.end()); @@ -253,7 +255,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); - auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); + auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ( f::OperatorBase::EMPTY_VAR_NAME(), @@ -261,14 +263,14 @@ TEST(Backward, net_input_of_network_not_need_grad) { } TEST(Backward, net_shared_weight) { - f::NetOp net; + ops::NetOp net; net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); net.CompleteAddOp(); auto bwd = f::Backward(net, {}); ASSERT_TRUE(bwd->IsNetOp()); - auto bwd_net = static_cast(bwd.get()); + auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); ASSERT_EQ("add", bwd_net->ops_[2]->type_); } @@ -285,7 +287,7 @@ TEST(Backward, op_all_input_are_not_need) { auto fwd 
= f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto backward = f::Backward(*fwd, {"X", "b"}); ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); + auto net = static_cast(backward.get()); ASSERT_TRUE(net->ops_.empty()); } @@ -293,7 +295,7 @@ TEST(Backward, op_all_output_are_not_need) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto backward = f::Backward(*fwd, {"Out"}); ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); + auto net = static_cast(backward.get()); ASSERT_TRUE(net->ops_.empty()); } @@ -301,7 +303,7 @@ TEST(Backward, op_part_of_output_are_not_need) { auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); auto backward = f::Backward(*fwd, {"Z"}); ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); + auto net = static_cast(backward.get()); ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; @@ -341,7 +343,7 @@ TEST(Backward, op_part_of_input_are_not_need) { } TEST(Backward, linear_net_intermediate_variable_has_no_grad) { - f::NetOp net; + ops::NetOp net; net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"mul_out1", "add_out1", "out1"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, @@ -351,7 +353,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { net.CompleteAddOp(); auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); ASSERT_TRUE(backward->IsNetOp()); - auto bwd_net = static_cast(backward.get()); + auto bwd_net = static_cast(backward.get()); ASSERT_EQ(bwd_net->ops_.size(), 3UL); auto &grad_fc = *bwd_net->ops_[0]; EXPECT_EQ(grad_fc.inputs_.size(), diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0..3e72e391266066de9e4114e68b43b066c15254db 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -400,6 +400,14 @@ class 
GradOpRegisterHelper { return 0; \ } +/** + * Macro to Forbid user register Gradient Operator. + */ +#define NO_GRADIENT(__op_type) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##__op_type##__op_type##_grad, \ + "NO_GRADIENT must be in global namespace") + /** * Macro to Register OperatorKernel. */ diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index e3c510b70346a2baf6ccd756eaf689c146efee5f..cb86e6be2be3624bf54ee28193ca5d4c7bafa0eb 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -20,16 +20,16 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* ExecutionContext::GetEigenDevice< +Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return device_context_.get_eigen_device(); + return *device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* +Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return device_context_.get_eigen_device(); + return *device_context_.get_eigen_device(); } #endif @@ -52,7 +52,8 @@ std::vector OperatorBase::Inputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); auto input_format = GetAttr>("input_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= (int)inputs_.size(), + PADDLE_ENFORCE(input_format.at(static_cast(offset) + 1) <= + static_cast(inputs_.size()), "Input Out Of Range"); return std::vector{ @@ -78,7 +79,8 @@ std::vector OperatorBase::Outputs(const std::string& name) const { PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto output_format = GetAttr>("output_format"); auto offset = in_out_idxs_->at(name); - PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(), + PADDLE_ENFORCE(output_format.at(static_cast(offset) + 1) <= + static_cast(outputs_.size()), "Output Out of Range"); 
return std::vector{ outputs_.begin() + output_format.at(offset), diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 6a9fe19b9b61333cf9db1cca3e34c72f3f9c99c5..0b588297169540417586d7c167a1265827b683ac 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -55,6 +55,10 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + static std::string GRAD_VAR_NAME(const std::string& name) { + return name + GRAD_VAR_SUFFIX(); + } + /// Variables with this suffix are supposed to be filled up with zeros. static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } @@ -161,22 +165,30 @@ class OperatorContext { template const T* Input(const size_t index) const { - return &(InputVar(index)->Get()); + auto var = InputVar(index); + PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index); + return &var->Get(); } template T* Output(const size_t index) const { - return OutputVar(index)->GetMutable(); + auto var = OutputVar(index); + PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index); + return var->GetMutable(); } template const T* Input(const std::string& name) const { - return &(InputVar(name)->Get()); + auto var = InputVar(name); + PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name); + return &var->Get(); } template T* Output(const std::string& name) const { - return OutputVar(name)->GetMutable(); + auto var = OutputVar(name); + PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name); + return var->GetMutable(); } template @@ -185,8 +197,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return &scope_.FindVar(name)->Get(); + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiInput(%s:%s) 
should not be nullptr", + name, sub_name); + return &var->Get(); }); return res; } @@ -197,8 +213,12 @@ class OperatorContext { std::vector res; res.reserve(names.size()); std::transform(names.begin(), names.end(), std::back_inserter(res), - [this](const std::string& name) { - return scope_.FindVar(name)->GetMutable(); + [&](const std::string& sub_name) { + auto var = scope_.FindVar(sub_name); + PADDLE_ENFORCE(var != nullptr, + "MultiOutput(%s:%s) should not be nullptr", + name, sub_name); + return var->GetMutable(); }); return res; } @@ -237,7 +257,7 @@ class ExecutionContext : public OperatorContext { template ::EigenDeviceType> - DeviceType* GetEigenDevice() const; + DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_.GetPlace(); } diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f0f3ef7e3a4230c09ea6f766c4017946ac0b5a --- /dev/null +++ b/paddle/framework/pybind.cc @@ -0,0 +1,229 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include + +#include "paddle/framework/backward.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/tensor_py.h" +#include "paddle/operators/net_op.h" +#include "paddle/operators/type_alias.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/place.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +USE_OP(add_two); +USE_OP(onehot_cross_entropy); +USE_OP_WITHOUT_KERNEL(fc); +USE_OP(sgd); +USE_OP(mul); +USE_OP(mean); +USE_OP(sigmoid); +USE_OP(softmax); +USE_OP(rowwise_add); +USE_OP_WITHOUT_KERNEL(recurrent_op); +namespace paddle { +namespace framework { +template +void ExposeOperator(ClassType &m) { + m.def("infer_shape", &ClassType::type::InferShape) + .def("run", &ClassType::type::Run) + .def("type", + [](const typename ClassType::type &op) -> std::string { + return op.type_; + }) + .def("outputs", + [](const typename ClassType::type &op) -> std::vector { + return op.outputs_; + }) + .def("__str__", &ClassType::type::DebugString); +} + +static size_t UniqueIntegerGenerator() { + static std::atomic generator; + return generator.fetch_add(1); +} + +bool IsCompileGPU() { +#ifdef PADDLE_ONLY_CPU + return false; +#else + return true; +#endif +} + +PYBIND11_PLUGIN(core) { + py::module m("core", "C++ core of PaddlePaddle"); + + py::class_(m, "Tensor", py::buffer_protocol()) + .def_buffer( + [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) + .def("get_dims", + [](const Tensor &self) { return vectorize(self.dims()); }) + .def("set_dims", + [](Tensor &self, const std::vector &dim) { + self.Resize(make_ddim(dim)); + }) + .def("alloc_float", + [](Tensor &self, paddle::platform::GPUPlace &place) { + self.mutable_data(place); + }) + .def("alloc_float", + [](Tensor &self, paddle::platform::CPUPlace &place) { + self.mutable_data(place); + }) + 
.def("alloc_int", + [](Tensor &self, paddle::platform::CPUPlace &place) { + self.mutable_data(place); + }) + .def("alloc_int", + [](Tensor &self, paddle::platform::GPUPlace &place) { + self.mutable_data(place); + }) + .def("set", PyCPUTensorSetFromArray) + .def("set", PyCPUTensorSetFromArray) +#ifndef PADDLE_ONLY_CPU + .def("set", PyCUDATensorSetFromArray) + .def("set", PyCUDATensorSetFromArray) +#endif + .def("shape", [](Tensor &self) { return vectorize(self.dims()); }); + + py::class_(m, "Variable", R"DOC(Variable Class. + +All parameter, weight, gradient are variables in Paddle. +)DOC") + .def("is_int", [](const Variable &var) { return var.IsType(); }) + .def("set_int", + [](Variable &var, int val) -> void { *var.GetMutable() = val; }) + .def("get_int", [](const Variable &var) -> int { return var.Get(); }) + .def("get_tensor", + [](Variable &self) -> Tensor * { return self.GetMutable(); }, + py::return_value_policy::reference) + .def("get_net", + [](Variable &self) -> ops::NetOp * { + return self.GetMutable(); + }, + py::return_value_policy::reference); + + py::class_(m, "Scope", "") + .def("new_var", + [](Scope &self, const std::string &name) -> Variable * { + return self.NewVar(name); + }, + py::return_value_policy::reference) + .def("find_var", &Scope::FindVar, py::return_value_policy::reference) + .def(py::init<>()) + .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); }, + py::return_value_policy::reference) + .def("drop_kids", &Scope::DropKids); + + //! @note: Be careful! PyBind will return std::string as an unicode, not + //! Python str. If you want a str object, you should cast them in Python. 
+ m.def("get_all_op_protos", []() -> std::vector { + auto &protos = OpRegistry::protos(); + std::vector ret_values; + for (auto it = protos.begin(); it != protos.end(); ++it) { + PADDLE_ENFORCE(it->second.IsInitialized(), + "OpProto must all be initialized"); + std::string str; + PADDLE_ENFORCE(it->second.SerializeToString(&str), + "Serialize OpProto Error. This could be a bug of Paddle."); + ret_values.push_back(py::bytes(str)); + } + return ret_values; + }); + m.def_submodule( + "var_names", + "The module will return special predefined variable name in Paddle") + .def("empty", OperatorBase::EMPTY_VAR_NAME) + .def("temp", OperatorBase::TMP_VAR_NAME); + // clang-format off + py::class_(m, "DeviceContext") + .def_static("create", + [](paddle::platform::CPUPlace& place) + -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }) + .def_static("create", + [](paddle::platform::GPUPlace& place) + -> paddle::platform::DeviceContext* { +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("GPUPlace is not supported in CPU device."); +#else + return new paddle::platform::CUDADeviceContext(place); +#endif + }); + // clang-format on + + py::class_(m, "GPUPlace").def(py::init()); + + py::class_(m, "CPUPlace").def(py::init<>()); + + py::class_> operator_base( + m, "Operator"); + + operator_base.def_static("create", [](py::bytes protobin) { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return OpRegistry::CreateOp(desc); + }); + + operator_base.def("backward", + [](const OperatorBase &forwardOp, + const std::unordered_set &no_grad_vars) { + return Backward(forwardOp, no_grad_vars); + }); + + ExposeOperator(operator_base); + + py::class_> net(m, "Net"); + + net.def_static("create", + []() -> std::shared_ptr { + auto retv = std::make_shared(); + retv->type_ = "plain_net"; + 
return retv; + }) + .def("add_op", &ops::NetOp::AddOp) + .def( + "add_op", + [](ops::NetOp &self, const std::shared_ptr &net) -> void { + self.AddOp(std::static_pointer_cast(net)); + }) + .def("complete_add_op", &ops::NetOp::CompleteAddOp) + .def("complete_add_op", + [](std::shared_ptr &self) { self->CompleteAddOp(); }); + + ExposeOperator(net); + + m.def("unique_integer", UniqueIntegerGenerator); + + m.def("is_compile_gpu", IsCompileGPU); + + return m.ptr(); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 76070f636b0971f4a136042e056c59adb5dc2d40..4c3b14b83d841e88683a13634c93f51c012128b6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -26,19 +26,17 @@ limitations under the License. */ #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { -namespace pybind { -namespace details { // forward declare -template -struct CastToPyBufferImpl; -} // namespace details -} // namespace pybind namespace framework { +namespace details { +template +struct CastToPyBufferImpl; +} class Tensor { public: template - friend struct paddle::pybind::details::CastToPyBufferImpl; + friend struct details::CastToPyBufferImpl; template friend struct EigenTensor; @@ -167,4 +165,4 @@ class Tensor { } // namespace framework } // namespace paddle -#include "paddle/framework/detail/tensor-inl.h" +#include "paddle/framework/tensor_impl.h" diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/tensor_impl.h similarity index 97% rename from paddle/framework/detail/tensor-inl.h rename to paddle/framework/tensor_impl.h index e7ff09dd5c954378afeca299e901277c3ebdb96a..92621f8c18ec0d03160a23c462830d14272c7f64 100644 --- a/paddle/framework/detail/tensor-inl.h +++ b/paddle/framework/tensor_impl.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once - #include "paddle/memory/memcpy.h" namespace paddle { @@ -62,9 +61,11 @@ inline T* Tensor::mutable_data(platform::Place place) { if (platform::is_cpu_place(place)) { holder_.reset(new PlaceholderImpl( boost::get(place), size)); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); } -#ifndef PADDLE_ONLY_CPU - else if (platform::is_gpu_place(place)) { +#else holder_.reset(new PlaceholderImpl( boost::get(place), size)); } diff --git a/paddle/pybind/tensor_bind.h b/paddle/framework/tensor_py.h similarity index 64% rename from paddle/pybind/tensor_bind.h rename to paddle/framework/tensor_py.h index 995e102bf9d342e1604f5ae704288d6cf68d97a4..4e1ab77b157fe1adaeac55c271c056236f2d40de 100644 --- a/paddle/pybind/tensor_bind.h +++ b/paddle/framework/tensor_py.h @@ -13,15 +13,17 @@ limitations under the License. */ #pragma once -#include -#include -#include +#include +#include "paddle/framework/tensor.h" +#include "paddle/memory/memcpy.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" namespace py = pybind11; namespace paddle { -namespace pybind { +namespace framework { namespace details { @@ -40,9 +42,6 @@ template struct CastToPyBufferImpl { using CUR_TYPE = typename std::tuple_element>::type; py::buffer_info operator()(framework::Tensor &tensor) { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()), - "Only CPU tensor can cast to numpy array"); - if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector dims_outside; @@ -56,14 +55,16 @@ struct CastToPyBufferImpl { strides[i - 1] = sizeof(CUR_TYPE) * prod; prod *= dims_outside[i - 1]; } - + framework::Tensor dst_tensor; + if (paddle::platform::is_gpu_place(tensor.holder_->place())) { + dst_tensor.CopyFrom(tensor, platform::CPUPlace()); + } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { + dst_tensor = 
tensor; + } return py::buffer_info( - tensor.mutable_data(tensor.holder_->place()), - sizeof(CUR_TYPE), - py::format_descriptor::format(), - (size_t)framework::arity(tensor.dims()), - dims_outside, - strides); + dst_tensor.mutable_data(dst_tensor.holder_->place()), + sizeof(CUR_TYPE), py::format_descriptor::format(), + (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); @@ -77,9 +78,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { } template -void PyTensorSetFromArray( +void PyCPUTensorSetFromArray( framework::Tensor &self, - py::array_t array) { + py::array_t array, + paddle::platform::CPUPlace &place) { std::vector dims; dims.reserve(array.ndim()); for (size_t i = 0; i < array.ndim(); ++i) { @@ -87,9 +89,28 @@ void PyTensorSetFromArray( } self.Resize(framework::make_ddim(dims)); - auto *dst = self.mutable_data(paddle::platform::CPUPlace()); + auto *dst = self.mutable_data(place); std::memcpy(dst, array.data(), sizeof(T) * array.size()); } +#ifndef PADDLE_ONLY_CPU +template +void PyCUDATensorSetFromArray( + framework::Tensor &self, + py::array_t array, + paddle::platform::GPUPlace &place) { + std::vector dims; + dims.reserve(array.ndim()); + for (size_t i = 0; i < array.ndim(); ++i) { + dims.push_back((int)array.shape()[i]); + } + + self.Resize(framework::make_ddim(dims)); + auto *dst = self.mutable_data(place); + paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(), + cudaMemcpyHostToDevice); +} +#endif + } // namespace pybind } // namespace paddle diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index bb4f48364b9b454af7d37fe4d3c340666e53285c..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -109,6 +109,13 @@ protected: return filter[filter.ndims() - 1]; } + // determine whether im2col needs to be performed + inline bool 
isNeedIm2col(const TensorShape& filter) const { + return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); + } + std::vector strides_; std::vector paddings_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 9deb2739fcfff935a98a0b5b31b5d11819d81227..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,16 +66,23 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -86,15 +93,18 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; @@ -159,19 +169,27 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad 
= outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Col2ImFunctor col2im; GemmFunctor gemm; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -182,6 +200,11 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; + real scale = 0.0f; + if (!needIm2col) { + colData = inputGrad + g * inputOffset; + scale = 1.0f; + } gemm(CblasTrans, CblasNoTrans, M, @@ -192,17 +215,19 @@ public: M, outputGrad + g * outputOffset, N, - 0.0f, + scale, colData, N); - col2im(inputGrad + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); + if (needIm2col) { + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -255,16 +280,23 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - 
outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -274,15 +306,18 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; diff --git a/paddle/gserver/layers/ClipLayer.cpp b/paddle/gserver/layers/ClipLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13f16c953793b82183237188b56eb61d76ecd2fd --- /dev/null +++ b/paddle/gserver/layers/ClipLayer.cpp @@ -0,0 +1,79 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer for clipping the input value by the threshold. + * \f[ + * out[i] = \min\left(\max\left(in[i],p_{1}\right),p_{2}\right) + * \f] + */ + +class ClipLayer : public Layer { +protected: + double min_; + double max_; + +public: + explicit ClipLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(clip, ClipLayer); + +bool ClipLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 1U); + auto layerConf = config_.inputs(0).clip_conf(); + min_ = layerConf.min(); + max_ = layerConf.max(); + CHECK_LT(min_, max_); + return true; +} + +void ClipLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + resetOutput(inV->getHeight(), inV->getWidth()); + MatrixPtr outV = getOutputValue(); + outV->copyFrom(*inV); + outV->clip(min_, max_); +} + +void ClipLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + if (inG) { + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + MatrixPtr tmpMtx; + Matrix::resizeOrCreate( + tmpMtx, outG->getHeight(), outG->getWidth(), false, useGpu_); + tmpMtx->clipDerivative(*inV, min_, max_); + inG->addDotMul(*outG, *tmpMtx, 1, 1); + } +} + +} // namespace paddle diff --git a/paddle/gserver/layers/RowL2NormLayer.cpp b/paddle/gserver/layers/RowL2NormLayer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d609be43b73a86d0d0f7b60be993836e2ea6fff --- /dev/null +++ b/paddle/gserver/layers/RowL2NormLayer.cpp @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "Layer.h" + +namespace paddle { + +/** + * A layer for L2 normalization in each row, + * \f[ + * out[i] = \frac{in[i]}{\sqrt{\sum_{k=1}^N in[k]^{2}}} + * \f] + * where the size of \f$in\f$ is (batchSize x dataDim), + * and the size of \f$out\f$ is (batchSize x dataDim). + */ + +class RowL2NormLayer : public Layer { +protected: + MatrixPtr inSquare_; + MatrixPtr l2NormReciprocal_; + MatrixPtr dotSum_; + +public: + explicit RowL2NormLayer(const LayerConfig& config) : Layer(config) {} + + bool init(const LayerMap& layerMap, + const ParameterMap& parameterMap) override; + + void forward(PassType passType) override; + void backward(const UpdateCallback& callback = nullptr) override; +}; + +REGISTER_LAYER(row_l2_norm, RowL2NormLayer); + +bool RowL2NormLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + Layer::init(layerMap, parameterMap); + + CHECK_EQ(inputLayers_.size(), 1U); + + return true; +} + +void RowL2NormLayer::forward(PassType passType) { + Layer::forward(passType); + + MatrixPtr inV = getInputValue(0); + + /* malloc memory for the output_ if necessary */ + size_t batchSize = inV->getHeight(); + size_t dataDim = getSize(); + CHECK_EQ(dataDim, inV->getWidth()); + resetOutput(batchSize, dataDim); + MatrixPtr outV = getOutputValue(); + + Matrix::resizeOrCreate(inSquare_, batchSize, dataDim, false, useGpu_); + inV->square2(*inSquare_); + Matrix::resizeOrCreate(l2NormReciprocal_, batchSize, 1, false, 
useGpu_); + inSquare_->rowSum(*l2NormReciprocal_); + l2NormReciprocal_->sqrt2(*l2NormReciprocal_); + l2NormReciprocal_->scalarDiv(*l2NormReciprocal_, 1.0); + outV->rowScale(0, *inV, *l2NormReciprocal_); +} + +void RowL2NormLayer::backward(const UpdateCallback& callback) { + MatrixPtr inV = getInputValue(0); + MatrixPtr inG = getInputGrad(0); + MatrixPtr outV = getOutputValue(); + MatrixPtr outG = getOutputGrad(); + size_t batchSize = inV->getHeight(); + + // inG[ij] += outG[ij] * l2NormReciprocal + // inG[ij] += -inV[ij] * l2NormReciprocal * l2NormReciprocal * DotMul(outG[i], + // outV[i]) + if (inG) { + Matrix::resizeOrCreate(dotSum_, batchSize, 1, false, useGpu_); + dotSum_->zeroMem(); + dotSum_->rowDotMul(0, *outG, *outV); + dotSum_->dotMul(*dotSum_, *l2NormReciprocal_); + dotSum_->dotMul(*dotSum_, *l2NormReciprocal_); + inSquare_->rowScale(0, *inV, *dotSum_); + inG->sub(*inSquare_); + inG->addRowScale(0, *outG, *l2NormReciprocal_); + } +} + +} // namespace paddle diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 8ce8600c6743779899b2685c1c12053922265411..fe11278f41c0118ee0bdb34f17fbf9602e0fa76b 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1899,6 +1899,36 @@ TEST(Layer, CropLayer) { } } +TEST(Layer, ClipLayer) { + const size_t batchSize = 128; + const size_t size = 512; + TestConfig config; + config.layerConfig.set_type("clip"); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ClipConfig* layerConf = input->mutable_clip_conf(); + double p1 = std::rand() / (double)RAND_MAX; + double p2 = std::rand() / (double)RAND_MAX; + layerConf->set_min(std::min(p1, p2)); + layerConf->set_max(std::max(p1, p2)); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "clip", batchSize, false, useGpu, false); + } +} + +TEST(Layer, RowL2NormLayer) { + const size_t batchSize = 128; + const 
size_t size = 512; + TestConfig config; + config.layerConfig.set_type("row_l2_norm"); + config.layerConfig.set_size(size); + config.inputDefs.push_back({INPUT_DATA, "input", size, 0}); + config.layerConfig.add_inputs(); + for (auto useGpu : {false, true}) { + testLayerGrad(config, "row_l2_norm", batchSize, false, useGpu, false); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index de48b6fac9c7d8125a552022c52353ef6bcef995..ba2b47d6cc6961a380b7db2781b4d214dea829db 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -442,6 +442,13 @@ DEFINE_MATRIX_UNARY_PARAMETER_OP(Clip, TWO_PARAMETER, template void BaseMatrixT::clip(T p1, T p2) { applyUnary(unary::Clip(p1, p2)); } +DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, + a = b < p1 ? 0 : (b > p2 ? 0 : 1)); +template +void BaseMatrixT::clipDerivative(BaseMatrixT& b, T p1, T p2) { + applyBinary(binary::ClipDerivative(p1, p2), b); +} + DEFINE_MATRIX_UNARY_PARAMETER_OP(BiggerThanScalar, ONE_PARAMETER, a = a > p ? 1.0f : 0.0f); template diff --git a/paddle/math/BaseMatrix.h b/paddle/math/BaseMatrix.h index 120d69f718b954925438fbd2119d69f0be13b3e9..12ad2d45a0bbff182e78da6efb3c5ff4c6b59b55 100644 --- a/paddle/math/BaseMatrix.h +++ b/paddle/math/BaseMatrix.h @@ -488,6 +488,13 @@ public: */ void clip(T p1, T p2); + /** + * this = (b < p1 || b > p2) ? 0 : 1 + * + * i.e. 0 where b lies outside [p1, p2], 1 where it lies inside + */ + void clipDerivative(BaseMatrixT& b, T p1, T p2); + /** * @code * a = a > p ? 
1.0f : 0.0f diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index b0bb79cbc02122e03e011115ca8fba8967edfb7e..96c76e22e9814682008f2e6c7ae98e2599d391c2 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,22 +41,27 @@ function(op_library TARGET) endif() endfunction() +cc_library(net_op SRCS net_op.cc DEPS op_registry) +cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) + op_library(add_op SRCS add_op.cc add_op.cu) cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) +op_library(mean_op SRCS mean_op.cc mean_op.cu) +cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op) + op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) -op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) + +op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) -op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op - softmax_op net) - op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) -op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc -tensor op_registry operator net) -cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS -recurrent_network_op gtest mul_op add_op) +op_library(fc_op + SRCS fc_op.cc + DEPS mul_op rowwise_add_op sigmoid_op softmax_op net_op) +op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net_op) +cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 3a43dbfbada87e458109d8ca22effdb4407b4c1d..85269a5f7445a1745d9be68417789e33eb725d5c 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -50,10 +50,6 @@ The equation is: Out = X + 
Y class AddOpGrad : public OperatorWithKernel { protected: void InferShape(const InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "AddOpGrad"; - return ""; - } }; } // namespace operators diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 79d8de6cd46e1c72b14b0554c7be7b4eee281f4c..f961b37565f400b5c26844b9e7a3cff5e682340b 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/add_op.h" diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index d2b649fcbd1e5cac1c8cfcfd4e522e41135f7d1f..54d2231425293f6cfb3adc9cb34d903a75fcdcd0 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,10 +28,13 @@ public: output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = - framework::EigenVector::Flatten(*input0) + - framework::EigenVector::Flatten(*input1); + auto X = EigenVector::Flatten(*input0); + auto Y = EigenVector::Flatten(*input1); + auto Z = EigenVector::Flatten(*output); + + auto place = context.GetEigenDevice(); + + Z.device(place) = X + Y; } }; diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 19e4b74596a0f59edd04db830ec6f6f481373465..926a0c616b957d8e542c1f3dee227a718fb29f07 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..78131b26808b183ee107313374493ae870f1b641 --- /dev/null +++ b/paddle/operators/mean_op.cc @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/mean_op.h" + +namespace paddle { +namespace operators { + +class MeanOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); + PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); + PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr, + "Input/Output of MeanOp must be initialized."); + ctx.Output(0)->Resize(framework::make_ddim({1})); + } +}; + +class MeanOpMaker : public OpProtoAndCheckerMaker { +public: + MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of mean op"); + AddOutput("Out", "The output of mean op").IgnoreGradient(); + AddComment("Mean Operator"); + } +}; + +class MeanGradOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + ctx.Output("X" + GRAD_VAR_SUFFIX()) + ->Resize(ctx.Input("X")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); +REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); +REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); +REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu new file mode 100644 index 
0000000000000000000000000000000000000000..e15de2fd0dd84e4015ee0e3b5343d7651b027a88 --- /dev/null +++ b/paddle/operators/mean_op.cu @@ -0,0 +1,6 @@ +#define EIGEN_USE_GPU + +#include "paddle/operators/mean_op.h" + +REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel); +REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel); \ No newline at end of file diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e712dee6a785749e51be7b233e85dbf39c835218 --- /dev/null +++ b/paddle/operators/mean_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/operators/type_alias.h" + +namespace paddle { +namespace operators { + +template +class MeanKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto input = context.Input(0); + auto output = context.Output(0); + + output->mutable_data(context.GetPlace()); + + auto X = EigenVector::Flatten(*input); + auto y = EigenScalar::From(*output); + auto place = context.GetEigenDevice(); + + y.device(place) = X.mean(); + } +}; + +template +class MeanGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto OG = context.Input("Out" + OperatorBase::GRAD_VAR_SUFFIX()); + PADDLE_ENFORCE(framework::product(OG->dims()) == 1, + "Mean Gradient should be scalar"); + auto IG = context.Output("X" + OperatorBase::GRAD_VAR_SUFFIX()); + IG->mutable_data(context.GetPlace()); + + T ig_size = (T)framework::product(IG->dims()); + + EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = + EigenScalar::From(*OG) / ig_size; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/mean_op_test.cc b/paddle/operators/mean_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..375dcd50e130355c60f82b9d39d1b94fb2c911b0 --- /dev/null +++ b/paddle/operators/mean_op_test.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include + +USE_OP(mean); + +TEST(MeanOp, GetOpProto) { + auto& protos = paddle::framework::OpRegistry::protos(); + auto it = protos.find("mean"); + ASSERT_NE(it, protos.end()); +} diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index c27fc886ce7238a13c8ef86bce673a2b54949a9d..dc9236701627dc9335b844d2a82e18eb1f7dfd42 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); \ No newline at end of file diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index eef72ab293e13a9d05ce0013be41ec4bb75d6077..c7b78ad39045d25d73bfc2c930063c255a514864 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -26,13 +26,18 @@ public: Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; + auto input0 = context.Input("X"); + auto input1 = context.Input("Y"); auto output = context.Output(0); + output->mutable_data(context.GetPlace()); - EigenMatrix::From(*output).device(*(context.GetEigenDevice())) = - EigenMatrix::From(*context.Input("X")) - .contract(EigenMatrix::From(*context.Input("Y")), - dim_pair); + auto X = EigenMatrix::From(*input0); + auto Y = EigenMatrix::From(*input1); + auto Z = EigenMatrix::From(*output); + auto place = context.GetEigenDevice(); + + Z.device(place) = X.contract(Y, dim_pair); } }; } // namespace operators diff --git a/paddle/framework/net.cc b/paddle/operators/net_op.cc similarity index 96% rename from paddle/framework/net.cc rename to paddle/operators/net_op.cc index 2cd378c6b21303d1a24206ba3010b0d035aaa766..fbc98e09923bda7f3baee04e02df9076247bff0b 100644 --- a/paddle/framework/net.cc +++ b/paddle/operators/net_op.cc @@ -14,11 +14,11 @@ limitations under the License. 
*/ -#include "paddle/framework/net.h" +#include "paddle/operators/net_op.h" #include "paddle/framework/op_registry.h" namespace paddle { -namespace framework { +namespace operators { void NetOp::CompleteAddOp(bool calc) { add_op_done_ = true; @@ -74,5 +74,5 @@ std::string NetOp::DebugString() const { bool NetOp::IsNetOp() const { return true; } -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/operators/net_op.h similarity index 89% rename from paddle/framework/net.h rename to paddle/operators/net_op.h index acf1a69da9fd8adce1bd89367c882eade052e725..13611e1ee83170db43e17d6088e4b04588ce6255 100644 --- a/paddle/framework/net.h +++ b/paddle/operators/net_op.h @@ -14,15 +14,17 @@ limitations under the License. */ #pragma once -#include -#include +#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/operators/type_alias.h" #include "paddle/platform/device_context.h" namespace paddle { -namespace framework { +namespace operators { + /** * @brief Network is also a type of Operator * @@ -37,13 +39,13 @@ namespace framework { * This is the base class of network, all the networks should implement the APIs * it defines. */ -class NetOp : public OperatorBase { - public: +class NetOp : public framework::OperatorBase { +public: /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ - void InferShape(const Scope& scope) const override { + void InferShape(const framework::Scope& scope) const override { for (auto& op : ops_) { op->InferShape(scope); } @@ -56,7 +58,7 @@ class NetOp : public OperatorBase { * scope will be used instead. If no OpContext is provicded, default context * will be used. 
*/ - void Run(const Scope& scope, + void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const override { for (auto& op : ops_) { op->Run(scope, dev_ctx); @@ -88,7 +90,7 @@ class NetOp : public OperatorBase { std::vector> ops_; - private: +private: bool add_op_done_{false}; template @@ -97,5 +99,5 @@ class NetOp : public OperatorBase { } }; -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/framework/net_design.md b/paddle/operators/net_op_design.md similarity index 100% rename from paddle/framework/net_design.md rename to paddle/operators/net_op_design.md diff --git a/paddle/framework/net_op_test.cc b/paddle/operators/net_op_test.cc similarity index 91% rename from paddle/framework/net_op_test.cc rename to paddle/operators/net_op_test.cc index f32e456e5d142bf8203f9ec03e8059772c4f5c99..18c5c60eb43250c23e2819a3c79ab8a96fec103e 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -1,16 +1,18 @@ +#include "paddle/operators/net_op.h" + #include -#include -#include -#include + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { -namespace framework { +namespace operators { static int infer_shape_cnt = 0; static int run_cnt = 0; class TestOp : public OperatorBase { - public: +public: void InferShape(const framework::Scope& scope) const override { ++infer_shape_cnt; } @@ -21,7 +23,7 @@ class TestOp : public OperatorBase { }; class EmptyOp : public OperatorBase { - public: +public: void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} @@ -73,7 +75,7 @@ TEST(OpKernel, all) { ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); } -TEST(Net, insert_op) { +TEST(NetOp, insert_op) { NetOp net; auto op1 = std::make_shared(); op1->inputs_ = {"x", "w1", "b1"}; @@ -85,5 +87,5 @@ TEST(Net, insert_op) { ASSERT_EQ(3UL, net.ops_.size()); } 
-} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/operators/recurrent_network_op.cc b/paddle/operators/recurrent_op.cc similarity index 67% rename from paddle/operators/recurrent_network_op.cc rename to paddle/operators/recurrent_op.cc index 60d065fc4789f76370840328870165579aa73b67..aeb95569b728f53b288a0c9a28220be8b5f7aaa4 100644 --- a/paddle/operators/recurrent_network_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -12,14 +12,14 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/recurrent_network_op.h" +#include "paddle/operators/recurrent_op.h" #include #include #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -29,11 +29,15 @@ namespace rnn { void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, - const size_t seq_len) { + const size_t seq_len, + bool infer_shape_mode) { PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); for (size_t i = 0; i < inlinks.size(); ++i) { - Tensor* input = - step_scopes[0]->FindVar(inlinks[i].external)->GetMutable(); + auto input_var = step_scopes[0]->FindVar(inlinks[i].external); + PADDLE_ENFORCE(input_var != nullptr, + "input link [%s] is not in scope.", + inlinks[i].external); + Tensor* input = input_var->GetMutable(); DDim dims = input->dims(); PADDLE_ENFORCE(static_cast(dims[0]) == seq_len, "all the inlinks must have same length"); @@ -41,7 +45,9 @@ void SegmentInputs(const std::vector& step_scopes, for (size_t j = 0; j < seq_len; j++) { Tensor* step_input = step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable(); - *step_input = input->Slice(j, j + 1); + if (!infer_shape_mode) { + *step_input = input->Slice(j, j + 1); + } step_input->Resize(step_dims); } } @@ -49,36 +55,41 @@ void SegmentInputs(const std::vector& step_scopes, void 
ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, - const size_t seq_len) { + const size_t seq_len, + bool infer_shape_mode) { for (size_t i = 0; i < outlinks.size(); i++) { - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); - - // TODO(qingiqng) remove following code after adding - // InferShape in RecurrentGradientOp - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable() - ->dims(); - std::vector dims_vec = vectorize(step_dims); - dims_vec.insert(dims_vec.begin(), seq_len); - output->mutable_data(make_ddim(dims_vec), platform::CPUPlace()); - - for (size_t j = 0; j < seq_len; j++) { - Tensor* step_output = - step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable(); - // TODO(luotao02) data type and platform::DeviceContext() should set - // correctly - (output->Slice(j, j + 1)) - .CopyFrom(*step_output, platform::CPUPlace()); + auto output_var = step_scopes[0]->FindVar(outlinks[i].external); + PADDLE_ENFORCE(output_var != nullptr, + "output link [%s] is not in scope.", + outlinks[i].external); + Tensor* output = output_var->GetMutable(); + if (infer_shape_mode) { + DDim step_dims = step_scopes[0] + ->FindVar(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + dims_vec.insert(dims_vec.begin(), seq_len); + output->Resize(make_ddim(dims_vec)); + } else { + output->mutable_data(platform::CPUPlace()); + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_output = + step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable(); + // TODO(luotao02) data type and platform::DeviceContext() should set + // correctly + (output->Slice(j, j + 1)) + .CopyFrom(*step_output, platform::CPUPlace()); + } } } } void LinkMemories(const std::vector& scopes, const std::vector& memories, - size_t step_id, - int offset) { + const size_t step_id, + const int offset, + bool infer_shape_mode) { PADDLE_ENFORCE(step_id < scopes.size(), "step [%d] is out of 
range of step scopes' size [%d]", step_id, @@ -95,18 +106,13 @@ void LinkMemories(const std::vector& scopes, auto scope = scopes[step_id]; auto linked_scope = scopes[step_id + offset]; for (auto& attr : memories) { - auto mem = scope->NewVar(attr.pre_var)->GetMutable(); - // maybe share variable is better? + auto mem = scope->FindVar(attr.pre_var)->GetMutable(); auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); - mem->ShareDataWith(*linked_mem); - - // TODO(qingqing) remove following code - // the memory of current step should be allocated in step net - auto m = scope->NewVar(attr.var)->GetMutable(); - // for unit test, as addOp and mulOp are null currently, if not - // mutable_data, mem.data() in output will be error. We will - // remove this line after merge the correct addOp and mulOp. - m->mutable_data(mem->dims(), platform::CPUPlace()); + if (infer_shape_mode) { + mem->Resize(linked_mem->dims()); + } else { + mem->ShareDataWith(*linked_mem); + } } } @@ -175,60 +181,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { ->dims()[0]; CreateScopes(scope); auto step_scopes = GetStepScopes(scope); - - // SegmentInputs is called in InferShape. The input must hold memory in - // SegmentInputs. But the other op only set dimension for the output in - // InferShape. That's a problem. Wether the RNN op needs InferShape or not? - // Wether the following functions (SegmentInputs, InitMemories, ...) need - // to rewrite for RNN op? 
- rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - - InitMemories(step_scopes[0]); - - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "stepnet [%s] is not in scope.", - arg_->step_net); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); + InitMemories(step_scopes[0], true /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); - // If the InferShape is called in OperatorBase's run function, - // the rnn op only needs to do InferShape for the first time step for (size_t i = 0; i < seq_len_; i++) { if (i > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, i, -1); + rnn::LinkMemories( + step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } net->GetMutable()->InferShape(*step_scopes[i]); } - - auto outlinks = arg_->outlinks; - for (size_t i = 0; i < outlinks.size(); i++) { - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable() - ->dims(); - std::vector dims_vec = vectorize(step_dims); - // now only support fixed length - dims_vec.insert(dims_vec.begin(), seq_len_); - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); - output->Resize(make_ddim(dims_vec)); - } + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); } void RecurrentAlgorithm::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); - + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); + InitMemories(step_scopes[0], false /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); + for (size_t step_id = 0; step_id < seq_len_; step_id++) { - // the link memory is done in InferShape - // maybe remove following code after testing if (step_id > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); + rnn::LinkMemories( + step_scopes, arg_->memories, 
step_id, -1, false /*infer_shape_mode*/); } net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); } - - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); } void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { @@ -244,18 +229,19 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { // Now all variables in scope must be created outside of op. auto net_op = scope.FindVar(arg_->step_net)->GetMutable(); for (auto& input : net_op->inputs_) { + // the weight are located in parent scope if (!step_scope.FindVar(input)) step_scope.NewVar(input); } for (auto& output : net_op->outputs_) { step_scope.NewVar(output); } - step_scopes->emplace_back(&step_scope); } } } -void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { +void RecurrentAlgorithm::InitMemories(Scope* step_scope, + bool infer_shape_mode) const { for (auto& attr : arg_->memories) { Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable(); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, @@ -263,13 +249,11 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { attr.var, attr.boot_var); Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable(); - pre_mem->ShareDataWith(*boot_mem); - - // TODO(qingqing) remove following code - // the memory of current step should be allocated in step net - // here for unit test - auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable(); - cur_step_mem->mutable_data(boot_mem->dims(), platform::CPUPlace()); + if (infer_shape_mode) { + pre_mem->Resize(boot_mem->dims()); + } else { + pre_mem->ShareDataWith(*boot_mem); + } } } @@ -307,13 +291,14 @@ public: : OpProtoAndCheckerMaker(proto, op_checker) { const auto& name = RecurrentOp::kArgName; // inputs and outputs stored in proto - AddInput(name.inlinks, "the input that need to be segmented for each step.") + AddInput(name.inlinks, + "the inputs that 
need to be segmented for each step.") .SetMultiple(); AddInput(name.boot_memories, "variables to initialize memories.") .SetMultiple(); AddInput(name.step_net, "network shared by all steps."); - AddOutput(name.outlinks, "the output that need to concated for all steps.") + AddOutput(name.outlinks, "the outputs that need to concated for all steps.") .SetMultiple(); AddOutput(name.step_scopes, "step scopes"); @@ -331,34 +316,39 @@ public: void RecurrentGradientAlgorithm::Run( const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "step net is not in scope."); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); } net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); } - LinkBootMemoryGradients(step_scopes[0]); - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + LinkBootMemoryGradients(step_scopes[0], false); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); } void RecurrentGradientAlgorithm::LinkBootMemoryGradients( - Scope* step_scope) const { + Scope* step_scope, bool infer_shape_mode) const { for (auto& attr : arg_->memories) { - Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); - PADDLE_ENFORCE(mem_grad != nullptr, - "boot_tensor should be retrieved before"); + PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr, + "memory variable [%s] does not exists", + attr.var); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, 
- "memory [%s]'s boot variable [%s] not exists", - attr.var, + "boot variable [%s] does not exists", attr.boot_var); + Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); Tensor* boot_mem_grad = step_scope->NewVar(attr.boot_var)->GetMutable(); - boot_mem_grad->ShareDataWith(*mem_grad); + if (infer_shape_mode) { + boot_mem_grad->Resize(mem_grad->dims()); + } else { + boot_mem_grad->ShareDataWith(*mem_grad); + } } } @@ -367,34 +357,20 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { ->GetMutable() ->dims()[0]; auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "step net is not in scope."); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); - for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); } net->GetMutable()->InferShape(*step_scopes[step_id]); } - - auto outlinks = arg_->outlinks; - for (size_t i = 0; i < outlinks.size(); i++) { - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable() - ->dims(); - std::vector dims_vec = vectorize(step_dims); - // now only support fixed length - dims_vec.insert(dims_vec.begin(), seq_len_); - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); - output->Resize(make_ddim(dims_vec)); - } - LinkBootMemoryGradients(step_scopes[0]); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); + LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); } void RecurrentGradientOp::Init() { diff --git a/paddle/operators/recurrent_network_op.h 
b/paddle/operators/recurrent_op.h similarity index 92% rename from paddle/operators/recurrent_network_op.h rename to paddle/operators/recurrent_op.h index d57a1a2e51cbed22549ab6ebce79223e2d4e3bcf..2a0964fff326500b6215dd4afac63c75d64c4a06 100644 --- a/paddle/operators/recurrent_network_op.h +++ b/paddle/operators/recurrent_op.h @@ -72,19 +72,22 @@ struct ArgumentName { */ void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, - const size_t seq_len); + const size_t seq_len, + bool infer_shape_mode); /** * Process outputs of step nets and merge to variables. */ void ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, - const size_t seq_len); + const size_t seq_len, + bool infer_shape_mode); void LinkMemories(const std::vector& step_scopes, const std::vector& memories, - size_t step_id, - int offset); + const size_t step_id, + const int offset, + bool infer_shape_mode); void InitArgument(const ArgumentName& name, Argument* arg); @@ -122,7 +125,7 @@ protected: return *scope.FindVar(arg_->step_scopes)->GetMutable>(); } - void InitMemories(Scope* step_scopes) const; + void InitMemories(Scope* step_scopes, bool infer_shape_mode) const; private: std::unique_ptr arg_; @@ -145,7 +148,7 @@ public: void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const; - void LinkBootMemoryGradients(Scope* step_scopes) const; + void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const; /** * InferShape must be called before Run. diff --git a/paddle/operators/recurrent_network_op_test.cc b/paddle/operators/recurrent_op_test.cc similarity index 90% rename from paddle/operators/recurrent_network_op_test.cc rename to paddle/operators/recurrent_op_test.cc index b0e61fbee611744adb85b498b1c3540f059afc8c..08a6d9fe5681fdea180de2e9361734ade8564775 100644 --- a/paddle/operators/recurrent_network_op_test.cc +++ b/paddle/operators/recurrent_op_test.cc @@ -11,14 +11,15 @@ limitations under the License. 
*/ +#include "paddle/operators/recurrent_op.h" + #include #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" -#include "paddle/operators/recurrent_network_op.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace operators { @@ -55,7 +56,7 @@ protected: w->GetMutable()->mutable_data( make_ddim(std::vector{30, 30}), platform::CPUPlace()); - for (auto boot : std::vector{"x_boot", "h_boot"}) { + for (auto boot : std::vector{"h_boot"}) { LOG(INFO) << "create global variable " << boot; Variable* h_boot = scope_.NewVar(boot); h_boot->GetMutable()->mutable_data( @@ -79,7 +80,6 @@ protected: op_desc.add_inputs("x0"); op_desc.add_inputs("x1"); // boot_memories 3 - op_desc.add_inputs("x_boot"); op_desc.add_inputs("h_boot"); // step net 5 op_desc.add_inputs("step_net"); @@ -91,7 +91,7 @@ protected: auto _input_format = std::vector{ 0, // in_link 3, // memories - 5 // step_net + 4 // step_net }; auto input_format = op_desc.add_attrs(); input_format->set_name("input_format"); @@ -129,12 +129,11 @@ protected: inlink_alias->add_strings(item); } // pre memories - for (const auto& item : - std::vector{"rnn/x@pre", "rnn/h@pre"}) { + for (const auto& item : std::vector{"rnn/h@pre"}) { pre_memories->add_strings(item); } // memories - for (const auto& item : std::vector{"rnn/x", "rnn/h"}) { + for (const auto& item : std::vector{"rnn/h"}) { memories->add_strings(item); } // output alias @@ -151,14 +150,11 @@ protected: LOG(INFO) << "create variable step_net"; Variable* var = scope_.NewVar("step_net"); auto net = var->GetMutable(); - // rnn/s is net's input or output? 
- net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"}; - net->inputs_ = {"rnn/s", "rnn/h"}; net->AddOp( OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); net->AddOp( - OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {})); + OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {})); net->CompleteAddOp(); } @@ -297,7 +293,10 @@ protected: inlink.internal = "rnn/x"; auto step_scopes = scope_.FindVar("step_scopes")->GetMutable>(); - rnn::SegmentInputs(*step_scopes, std::vector{inlink}, 10); + rnn::SegmentInputs(*step_scopes, + std::vector{inlink}, + 10, + true /*infer_shape_mode*/); } void LinkeMemories() { @@ -311,7 +310,8 @@ protected: auto step_scopes = scope_.FindVar("step_scopes")->GetMutable>(); for (int i = 1; i < 10; ++i) { - rnn::LinkMemories(*step_scopes, memories, i, -1); + rnn::LinkMemories( + *step_scopes, memories, i, -1, true /*infer_shape_mode*/); } } @@ -333,14 +333,14 @@ TEST(RecurrentOp, LinkMemories) { using namespace paddle::operators; // create and init step scopes - int len = 10; + size_t len = 10; std::vector step_scopes; - for (int i = 0; i < len; ++i) { + for (size_t i = 0; i < len; ++i) { auto scope = new Scope(); scope->NewVar("pre_h"); auto tensor = scope->NewVar("h")->GetMutable(); float* data = tensor->mutable_data({15, 20}, CPUPlace()); - for (int j = 0; j < 15 * 20; ++j) { + for (size_t j = 0; j < 15 * 20; ++j) { data[j] = rand() * (1. 
/ (double)RAND_MAX); } step_scopes.push_back(scope); @@ -354,24 +354,24 @@ TEST(RecurrentOp, LinkMemories) { std::vector memories; memories.push_back(mem_attr); - for (int i = 1; i < len; ++i) { - rnn::LinkMemories(step_scopes, memories, i, -1); + for (size_t i = 1; i < len; ++i) { + rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/); } // check - for (int i = 0; i < len - 1; ++i) { + for (size_t i = 0; i < len - 1; ++i) { const float* a = step_scopes[i]->FindVar("h")->GetMutable()->data(); const float* b = step_scopes[i + 1] ->FindVar("pre_h") ->GetMutable() ->data(); - for (size_t i = 0; i < 15 * 20; ++i) { - ASSERT_FLOAT_EQ(a[i], b[i]); + for (size_t j = 0; j < 15 * 20; ++j) { + ASSERT_FLOAT_EQ(a[j], b[j]); } } for (int i = len - 2; i >= 0; --i) { - rnn::LinkMemories(step_scopes, memories, i, 1); + rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/); } // check for (int i = len - 2; i >= 0; --i) { @@ -379,8 +379,8 @@ TEST(RecurrentOp, LinkMemories) { step_scopes[i]->FindVar("pre_h")->GetMutable()->data(); const float* b = step_scopes[i + 1]->FindVar("h")->GetMutable()->data(); - for (size_t i = 0; i < 15 * 20; ++i) { - ASSERT_FLOAT_EQ(a[i], b[i]); + for (size_t j = 0; j < 15 * 20; ++j) { + ASSERT_FLOAT_EQ(a[j], b[j]); } } @@ -391,9 +391,3 @@ TEST(RecurrentOp, LinkMemories) { USE_OP(add_two); USE_OP(mul); - -// int main() { -// //! TODO(yuyang18): Temporary disable this unit-test because implementation -// //! error. 
-// return 0; -//} \ No newline at end of file diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 4b33e38ebabe853e179fe70ef7fde0a80b9050e2..82338ceccc06653791b26472e18d804f62735649 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/rowwise_add_op.h" REGISTER_OP_GPU_KERNEL(rowwise_add, diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index b86dd5463436bf521f9939b1c421b39f11102769..bd4d1128955fb718d3a84dfd96d8c68d7196e9cc 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -33,7 +33,7 @@ public: const int rest_size = input.size() / bias_size; Eigen::DSizes one_d(input.size()); Eigen::DSizes bcast(rest_size); - output.reshape(one_d).device(*(context.GetEigenDevice())) = + output.reshape(one_d).device(context.GetEigenDevice()) = input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d); } }; diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index f8f5b90cab460b4457cfb0a88bfc012bafe0fbc2..d79258cbf13c699cfb2afaee229cf96a3e377b5e 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/sgd_op.h" REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel); \ No newline at end of file diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index af1dfdd756ceb9991bee6b85c3281c05f0fb5a9f..0c3a240f9a4a5fc7bc4898e82786810cee2f7010 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -29,8 +29,12 @@ public: param_out->mutable_data(ctx.GetPlace()); - EigenVector::Flatten(*param_out).device(*(ctx.GetEigenDevice())) = - EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); + auto p = EigenVector::Flatten(*param); + auto g = EigenVector::Flatten(*grad); + auto o = EigenVector::Flatten(*param_out); + auto place = ctx.GetEigenDevice(); + + 
o.device(place) = p - lr * g; } }; diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu index f679b20418f04eff4310efe4e121963ce5a235e0..c9d11a2e1f9dcc563765c9e8cc1bae6beff57f18 100644 --- a/paddle/operators/sigmoid_op.cu +++ b/paddle/operators/sigmoid_op.cu @@ -1,3 +1,4 @@ +#define EIGEN_USE_GPU #include "paddle/operators/sigmoid_op.h" REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel); diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index 3dd23a9ebc7ac0972d6ee07b9ac051d59e66f62f..1412e4398440c8e946d3ab434a50e978079637ab 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -27,9 +27,11 @@ public: auto output = context.Output(0); output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device( - *(context.GetEigenDevice())) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); + auto X = EigenVector::Flatten(*input); + auto Y = EigenVector::Flatten(*output); + auto place = context.GetEigenDevice(); + + Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); } }; } // namespace operators diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 5b59fad7d5f9729b0862f8cd78cb32f94f87f513..5cbb96ab754467ea6ddab9380ca25987c9376980 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -1,16 +1,17 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ #include "paddle/operators/softmax_op.h" namespace paddle { @@ -19,12 +20,13 @@ namespace operators { class SoftmaxOp : public OperatorWithKernel { protected: void InferShape(const InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax"); - PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, + PADDLE_ENFORCE(ctx.InputSize() == 1UL, + "Only one input is need for softmax"); + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, "The input of softmax op must be matrix"); - PADDLE_ENFORCE(ctx.OutputSize() == 1, + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, "Only one output is need for softmax"); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output("Y")->Resize(ctx.Input("X")->dims()); } }; @@ -40,10 +42,19 @@ public: class SoftmaxOpGrad : public OperatorWithKernel { protected: - void InferShape(const InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "SoftmaxOpGrad"; - return ""; + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 3UL, + "Input of SoftmaxOpGrad should be 3, X, Y, YG"); + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, + "Output of SoftmaxOpGrad should 
be 1"); + PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr, + "Input(Y@GRAD) should not be null"); + PADDLE_ENFORCE(ctx.Input("Y")->dims() == + ctx.Input(GRAD_VAR_NAME("Y"))->dims(), + "the shape of Input(0) and Input(1) should be the same"); + ctx.Output(GRAD_VAR_NAME("X")) + ->Resize(ctx.Input("Y")->dims()); } }; @@ -51,5 +62,7 @@ protected: } // namespace paddle REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); -REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); +REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); +REGISTER_OP_CPU_KERNEL(softmax_grad, + ops::SoftmaxGradKernel); diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu index a1f6944a369fe5148ffcfeabf3bf7063dcbc2664..8c652213f2e4c0e0ea1a31987fcb37c86374cd2a 100644 --- a/paddle/operators/softmax_op.cu +++ b/paddle/operators/softmax_op.cu @@ -1,4 +1,6 @@ +#define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/softmax_op.h" REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel); +REGISTER_OP_GPU_KERNEL(softmax_grad, ops::SoftmaxGradKernel); diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index a5c19c5fc7c6f5909dbb355aff09bf15405b6957..13e74a79077982e9fba5d90f40986e699c1ed897 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -1,19 +1,22 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once +#include "paddle/framework/ddim.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" #include "paddle/operators/type_alias.h" namespace paddle { @@ -23,8 +26,8 @@ template class SoftmaxKernel : public OpKernel { public: void Compute(const ExecutionContext& context) const override { - auto input = context.Input(0); - auto output = context.Output(0); + auto input = context.Input("X"); + auto output = context.Output("Y"); output->mutable_data(context.GetPlace()); auto logits = EigenMatrix::From(*input); @@ -46,9 +49,9 @@ public: .reshape(batch_by_one) .broadcast(one_by_class)); - softmax.device(*(context.GetEigenDevice())) = shifted_logits.exp(); + softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(*(context.GetEigenDevice())) = + softmax.device(context.GetEigenDevice()) = (softmax * softmax.sum(along_class) .inverse() @@ -57,5 +60,38 @@ public: .broadcast(one_by_class)); } }; + +template +class SoftmaxGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + std::shared_ptr scale_ = std::make_shared(); + + auto Y = context.Input("Y"); + auto dY = context.Input(OperatorBase::GRAD_VAR_NAME("Y")); + 
auto dX = context.Output(OperatorBase::GRAD_VAR_NAME("X")); + dX->mutable_data(context.GetPlace()); + + const int batch_size = Y->dims()[0]; + const int class_num = Y->dims()[1]; + + Eigen::DSizes along_class(1); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, class_num); + + auto Y_eigen = EigenMatrix::From(*Y); + auto dY_eigen = EigenMatrix::From(*dY); + auto dX_eigen = EigenMatrix::From(*dX); + auto place = context.GetEigenDevice(); + + auto dot = (Y_eigen * dY_eigen) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 93b62cddc819e0d1fd48323e474a294ff0d327e1..931740e150946a939b8656be5a30185c6ee1cb8f 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -15,13 +15,14 @@ #pragma once #include "paddle/framework/eigen.h" -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace operators { using OpKernel = framework::OpKernel; +using OperatorBase = framework::OperatorBase; using InferShapeContext = framework::InferShapeContext; using ExecutionContext = framework::ExecutionContext; using Variable = framework::Variable; @@ -43,14 +44,16 @@ template using EigenTensor = framework::EigenTensor; using Tensor = framework::Tensor; +using Scope = framework::Scope; using OperatorWithKernel = framework::OperatorWithKernel; +using OperatorBase = framework::OperatorBase; using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker; using OpProto = framework::OpProto; using OpAttrChecker = framework::OpAttrChecker; using CPUPlace = platform::CPUPlace; using GPUPlace = platform::GPUPlace; -using NetOp = framework::NetOp; using OpRegistry = framework::OpRegistry; + } // namespace operators } // namespace paddle diff 
--git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 26c8eb78e614a68ec9728aad727d8fe3e08547ae..60a42c777d1c2ebbc22fdb77b1100cc6fcf7ff35 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -144,12 +144,12 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::platform::EnforceNotMet( \ - std::make_exception_ptr( \ - std::runtime_error(string::Sprintf(__VA_ARGS__))), \ - __FILE__, __LINE__); \ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::platform::EnforceNotMet( \ + std::make_exception_ptr( \ + std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ + __FILE__, __LINE__); \ } while (0) #define PADDLE_ENFORCE(...) \ diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 9bf2d6f72ea9c25919d91fe450c463bdd80c5e6d..8e6b258e00c0012876cda8ffc5b340322d51e894 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,10 @@ -cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op fc_op sgd_op cross_entropy_op recurrent_network_op fill_zeros_like_op) +cc_library(paddle_pybind SHARED + SRCS pybind.cc + DEPS pybind python backward + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_op + fill_zeros_like_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc deleted file mode 100644 index f2e9aa6b5d693f6b3706cd9a16fdfea1fec5ebea..0000000000000000000000000000000000000000 --- a/paddle/pybind/pybind.cc +++ /dev/null @@ -1,180 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/framework/net.h" -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/scope.h" -#include "paddle/pybind/tensor_bind.h" -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace py = pybind11; -namespace pd = paddle::framework; - -USE_OP(add_two); -USE_OP(onehot_cross_entropy); -USE_OP_WITHOUT_KERNEL(fc); -USE_OP(sgd); -USE_OP(mul); -USE_OP(sigmoid); -USE_OP(softmax); -USE_OP(rowwise_add); -USE_OP_WITHOUT_KERNEL(recurrent_op); -USE_OP(fill_zeros_like); - -template -void ExposeOperator(ClassType& m) { - m.def("infer_shape", &ClassType::type::InferShape) - .def("run", &ClassType::type::Run) - .def("outputs", - [](const typename ClassType::type& op) -> std::vector { - return op.outputs_; - }) - .def("__str__", &ClassType::type::DebugString); -} - -static size_t UniqueIntegerGenerator() { - static std::atomic generator; - return generator.fetch_add(1); -} - -PYBIND11_PLUGIN(core) { - py::module m("core", "C++ core of PaddlePaddle"); - - py::class_(m, "Tensor", py::buffer_protocol()) - .def_buffer([](pd::Tensor& self) -> py::buffer_info { - return paddle::pybind::CastToPyBuffer(self); - }) - .def("get_dims", - [](const pd::Tensor& self) { return pd::vectorize(self.dims()); }) - .def("set_dims", - [](pd::Tensor& self, const std::vector& dim) { - self.Resize(pd::make_ddim(dim)); - }) - .def("alloc_float", - [](pd::Tensor& self) { - self.mutable_data(paddle::platform::CPUPlace()); - }) - .def("alloc_int", - [](pd::Tensor& self) 
{ - self.mutable_data(paddle::platform::CPUPlace()); - }) - .def("set", paddle::pybind::PyTensorSetFromArray) - .def("set", paddle::pybind::PyTensorSetFromArray) - .def("shape", - [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); - - py::class_(m, "Variable", R"DOC(Variable Class. - -All parameter, weight, gradient are variables in Paddle. -)DOC") - .def("is_int", [](const pd::Variable& var) { return var.IsType(); }) - .def("set_int", - [](pd::Variable& var, int val) -> void { - *var.GetMutable() = val; - }) - .def("get_int", - [](const pd::Variable& var) -> int { return var.Get(); }) - .def("get_tensor", - [](pd::Variable& self) -> pd::Tensor* { - return self.GetMutable(); - }, - py::return_value_policy::reference) - .def("get_net", - [](pd::Variable& self) -> pd::NetOp* { - return self.GetMutable(); - }, - py::return_value_policy::reference); - - py::class_(m, "Scope", "") - .def("new_var", - [](pd::Scope& self, const std::string& name) -> pd::Variable* { - return self.NewVar(name); - }, - py::return_value_policy::reference) - .def("find_var", &pd::Scope::FindVar, py::return_value_policy::reference) - .def(py::init<>()) - .def("new_scope", - [](pd::Scope& self) -> pd::Scope* { return &self.NewScope(); }, - py::return_value_policy::reference) - .def("drop_kids", &pd::Scope::DropKids); - - //! @note: Be careful! PyBind will return std::string as an unicode, not - //! Python str. If you want a str object, you should cast them in Python. - m.def("get_all_op_protos", []() -> std::vector { - auto& protos = pd::OpRegistry::protos(); - std::vector ret_values; - for (auto it = protos.begin(); it != protos.end(); ++it) { - PADDLE_ENFORCE(it->second.IsInitialized(), - "OpProto must all be initialized"); - std::string str; - PADDLE_ENFORCE(it->second.SerializeToString(&str), - "Serialize OpProto Error. 
This could be a bug of Paddle."); - ret_values.push_back(py::bytes(str)); - } - return ret_values; - }); - m.def_submodule( - "var_names", - "The module will return special predefined variable name in Paddle") - .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) - .def("temp", pd::OperatorBase::TMP_VAR_NAME); - - py::class_(m, "DeviceContext") - .def_static("cpu_context", []() -> paddle::platform::DeviceContext* { - return new paddle::platform::CPUDeviceContext(); - }); - - py::class_> operator_base( - m, "Operator"); - - operator_base.def_static("create", [](py::bytes protobin) { - pd::OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); - }); - ExposeOperator(operator_base); - - py::class_> net(m, "Net"); - - net.def_static("create", - []() -> std::shared_ptr { - auto retv = std::make_shared(); - retv->type_ = "plain_net"; - return retv; - }) - .def("add_op", &pd::NetOp::AddOp) - .def("add_op", - [](pd::NetOp& self, const std::shared_ptr& net) -> void { - self.AddOp(std::static_pointer_cast(net)); - }) - .def("complete_add_op", &pd::NetOp::CompleteAddOp) - .def("complete_add_op", - [](std::shared_ptr& self) { self->CompleteAddOp(); }); - ExposeOperator(net); - - m.def("unique_integer", UniqueIntegerGenerator); - - return m.ptr(); -} diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 3860facb099950a5287d3f6b89c3de38f588f568..8de0e608c1f482e4553c07ff7ffd572d65a772aa 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -69,7 +69,7 @@ cat <> /paddle/build/Dockerfile <