diff --git a/README.md b/README.md index 2a6beeb342b34f8e91ef509d7d41f286a666480c..b9793c3eab5d40c28f01cc67ad607b97261b3235 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ We provide [English](http://doc.paddlepaddle.org/develop/doc/) and - [Deep Learning 101](http://book.paddlepaddle.org/index.html) - You might want to start from the this online interactive book that can run in Jupyter Notebook. + You might want to start from this online interactive book that can run in Jupyter Notebook. - [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html) diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index eff15de73f23db6dea3a7b79006bfec90d712ae5..25c6b4ef52d3f8ebff1572ae8d348be7c577c08c 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -20,34 +20,30 @@ INCLUDE(ExternalProject) SET(MKLDNN_PROJECT "extern_mkldnn") SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn) -SET(MKLDNN_INSTALL_ROOT ${CMAKE_INSTALL_PREFIX}) -IF(NOT "$ENV{HOME}" STREQUAL "/root") - SET(MKLDNN_INSTALL_ROOT "$ENV{HOME}") -ENDIF() - -SET(MKLDNN_INSTALL_DIR "${MKLDNN_INSTALL_ROOT}/opt/paddle/third_party/mkldnn") -SET(MKLDNN_INCLUDE_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) +SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn) +SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) -IF(WIN32) - MESSAGE(WARNING "It is not supported compiling with mkldnn in windows Paddle yet." - "Force WITH_MKLDNN=OFF") - SET(WITH_MKLDNN OFF) +IF(WIN32 OR APPLE) + MESSAGE(WARNING + "Windows or Mac is not supported with MKLDNN in Paddle yet." + "Force WITH_MKLDNN=OFF") + SET(WITH_MKLDNN OFF CACHE STRING "Disable MKLDNN in Windows and MacOS" FORCE) return() -ELSE(WIN32) - SET(MKLDNN_LIBRARY "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE) - MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") - SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) - #SET(CMAKE_MACOSX_RPATH 1) # hold for MacOS - SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") -ENDIF(WIN32) +ENDIF() + +SET(MKLDNN_LIB "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE) +MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") +SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") -INCLUDE_DIRECTORIES(${MKLDNN_INCLUDE_DIR}) +INCLUDE_DIRECTORIES(${MKLDNN_INC_DIR}) IF(${CBLAS_PROVIDER} STREQUAL "MKLML") SET(MKLDNN_DEPENDS ${MKLML_PROJECT}) SET(MKLDNN_MKLROOT ${MKLML_ROOT}) SET(MKLDNN_IOMP_LIB ${MKLML_IOMP_LIB}) SET(MKLDNN_IOMP_DIR ${MKLML_LIB_DIR}) + MESSAGE(STATUS "Build MKLDNN with ${MKLDNN_MKLROOT}") ENDIF() ExternalProject_Add( @@ -57,16 +53,15 @@ ExternalProject_Add( GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" GIT_TAG "v0.9" PREFIX ${MKLDNN_SOURCES_DIR} - CONFIGURE_COMMAND mkdir -p /build - BUILD_COMMAND cd /build - && cmake .. -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} -DMKLROOT=${MKLDNN_MKLROOT} - && $(MAKE) - INSTALL_COMMAND cd /build && $(MAKE) install UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} + CMAKE_ARGS -DMKLROOT=${MKLDNN_MKLROOT} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLDNN_INSTALL_DIR} + -DMKLROOT:PATH=${MKLDNN_MKLROOT} ) ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL) -SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIBRARY}) +SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIB}) ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT}) -MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIBRARY}") +MESSAGE(STATUS "Mkldnn library: ${MKLDNN_LIB}") LIST(APPEND external_project_dependencies mkldnn) diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index 3f940756a4abb79aba7d3561db19db8532a0b673..17a1ca4ed04dce85ae3c7fdd5f22d6eeed03db59 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -16,19 +16,23 @@ IF(NOT ${WITH_MKLML}) return() ENDIF(NOT ${WITH_MKLML}) +IF(WIN32 OR APPLE) + MESSAGE(WARNING + "Windows or Mac is not supported with MKLML in Paddle yet." + "Force WITH_MKLML=OFF") + SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in Windows and MacOS" FORCE) + return() +ENDIF() + INCLUDE(ExternalProject) SET(MKLML_PROJECT "extern_mklml") -SET(MKLML_VER "mklml_lnx_2018.0.20170425") +SET(MKLML_VER "mklml_lnx_2018.0.20170720") SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz") SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml") SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}") -SET(MKLML_DST_DIR "opt/paddle/third_party/mklml") -SET(MKLML_INSTALL_ROOT "${CMAKE_INSTALL_PREFIX}") -IF(NOT "$ENV{HOME}" STREQUAL "/root") - SET(MKLML_INSTALL_ROOT "$ENV{HOME}") -ENDIF() - +SET(MKLML_DST_DIR "mklml") +SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER}) SET(MKLML_INC_DIR ${MKLML_ROOT}/include) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 534be0abe246ac70950d85ad05441825c8ca768a..41b9b5928958ae31799c396a8d77fd7cff557905 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -187,7 +187,13 @@ function(cc_library TARGET_NAME) endif() # cpplint code style - add_style_check_target(${TARGET_NAME} ${cc_library_SRCS}) + foreach(source_file ${cc_library_SRCS}) + string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + endif() + endforeach() + add_style_check_target(${TARGET_NAME} ${cc_library_SRCS} ${cc_library_HEADERS}) else(cc_library_SRCS) if (cc_library_DEPS) @@ -239,6 +245,14 @@ function(nv_library TARGET_NAME) add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) target_link_libraries(${TARGET_NAME} ${nv_library_DEPS}) endif() + # cpplint code style + foreach(source_file ${nv_library_SRCS}) + string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file}) + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + list(APPEND cc_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h) + endif() + endforeach() + add_style_check_target(${TARGET_NAME} ${nv_library_SRCS} ${nv_library_HEADERS}) else(nv_library_SRCS) if (nv_library_DEPS) merge_static_libs(${TARGET_NAME} ${nv_library_DEPS}) diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 4b06966fba2bc9f92756be0cb8110bbcd5272423..f8a88cf317aee6c5dd25e4cc25d588c6c50fcbce 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -15,7 +15,6 @@ if(Boost_FOUND) add_subdirectory(platform) add_subdirectory(framework) add_subdirectory(operators) - add_subdirectory(pybind) endif() if(WITH_C_API) diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index c53a5636829cab9d575f58cc2326cb3efe383e1c..7ad8a39768a064140a08c912a5a467bc24a12adf 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -1022,6 +1022,15 @@ void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc, real alpha = 1.0f; real beta = 1.0f; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; + + int batch_size = ((cudnn_tensor_descriptor)inputDesc)->batch_size; + if (batch_size > 1024 && g_cudnn_lib_version < 6000) { + LOG(INFO) << " To process current batch data with size " << batch_size + << " (>1024), cudnnBatchNorm requires cuDNN version >= 6000." + << " If there is an error complaining CUDNN_STATUS_NOT_SUPPORTED," + << " just recompile PaddlePaddle with cuDNN >= 6000, replacing" + << " current version " << g_cudnn_lib_version; + } CHECK_CUDNN( dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle, mode, diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 12a3a00bba35d476fca9c9fb47ac20b87e6f53f2..9c39430835d37d5dfbe4031f29e5a6216ed8b67f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -31,8 +31,14 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) -cc_library(net SRCS net.cc DEPS op_registry) -cc_test(net_op_test SRCS net_op_test.cc DEPS net) - -cc_library(backward SRCS backward.cc DEPS net) +cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward) +cc_library(paddle_pybind SHARED + SRCS pybind.cc + DEPS pybind python backward + fc_op + sgd_op + add_op + mean_op + cross_entropy_op + recurrent_op) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0da11b91a7fe4a98e0832f70095c3200956ff001..c034e265fe4837ca22ab969b0e6952677904e05c 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -14,8 +14,8 @@ #include "paddle/framework/backward.h" #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace framework { @@ -32,7 +32,7 @@ static bool AllInSet(const std::vector& names, } static std::shared_ptr NOP() { - auto net_op = std::make_shared(); + auto net_op = std::make_shared(); net_op->type_ = "@NOP@"; net_op->CompleteAddOp(); return net_op; @@ -42,9 +42,9 @@ static std::shared_ptr NOP() { // // no_grad_names the gradient variable names without gradient calculating. // -// uniq_id is a unique index used inside recursively calling BackwardRecursive. -// use `uid = uniq_id++;` to get the unique index, and pass `uniq_id` through -// recursive calling. +// uniq_id is a unique index used inside recursively calling +// BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and +// pass `uniq_id` through recursive calling. // // returns The backward operator. For simple situation, it is a simple // operator. For complex situation, it is a NetOp. @@ -64,8 +64,8 @@ std::shared_ptr BackwardRecursive( return NOP(); } - // All output gradients of forwarding operator do not need to calculate. Then - // all input gradients cannot be computed at all, and we put them into + // All output gradients of forwarding operator do not need to calculate. + // Then all input gradients cannot be computed at all, and we put them into // `no_grad_names` set. Return an NOP. if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(), no_grad_names)) { @@ -77,14 +77,14 @@ std::shared_ptr BackwardRecursive( } // Returned gradient network - auto net = std::make_shared(); + auto net = std::make_shared(); if (forwardOp.IsNetOp()) { // Because forwardOp is a net op, it can static_cast. - auto& forwardNet = static_cast(forwardOp); + auto& forwardNet = static_cast(forwardOp); - // Map from output gradient variable name to operator's indices in backward - // net. That operator generates that variable. + // Map from output gradient variable name to operator's indices in + // backward net. That operator generates that variable. std::unordered_map> dup_output_ops; size_t local_op_id = 0; @@ -168,6 +168,9 @@ std::shared_ptr Backward( std::unordered_set no_grad_names; no_grad_names.reserve(no_grad_vars.size()); + no_grad_names.insert(OperatorBase::EMPTY_VAR_NAME() + + OperatorBase::GRAD_VAR_SUFFIX()); + for (auto& name : no_grad_vars) { no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX()); } diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index b095c2c3d5dbf21b5ea70e17475a4aaad9b1db44..8f437e68041188831a17217099e0b0c96432cda4 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -15,8 +15,9 @@ #include "paddle/framework/backward.h" #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" +#include "paddle/operators/type_alias.h" namespace paddle { namespace framework { @@ -70,7 +71,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker { } }; -class FcOp : public NetOp { +class FcOp : public ops::NetOp { public: void Init() override { AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, @@ -161,8 +162,8 @@ TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1UL, gop->inputs_.size()); - ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); + ASSERT_EQ(4UL, gop->inputs_.size()); + ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), gop->inputs_[0]); ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]); @@ -182,7 +183,8 @@ TEST(Backward, simple_op_not_need_grad) { auto no_input_gop = f::Backward(*fwd, {"X", "b"}); ASSERT_NE(no_input_gop, nullptr); ASSERT_TRUE(no_input_gop->IsNetOp()); - ASSERT_EQ(0UL, std::static_pointer_cast(no_input_gop)->ops_.size()); + ASSERT_EQ(0UL, + std::static_pointer_cast(no_input_gop)->ops_.size()); } TEST(Backward, net_fc_backward_normal) { @@ -191,7 +193,7 @@ TEST(Backward, net_fc_backward_normal) { ASSERT_NE(fwd, nullptr); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); - auto net = static_cast(gop.get()); + auto net = static_cast(gop.get()); ASSERT_NO_THROW(net->DebugString()); @@ -214,7 +216,7 @@ TEST(Backward, net_fc_backward_not_have_b) { ASSERT_NE(fwd, nullptr); std::shared_ptr gop = f::Backward(*fwd, {}); ASSERT_TRUE(gop->IsNetOp()); - auto net = static_cast(gop.get()); + auto net = static_cast(gop.get()); ASSERT_NO_THROW(net->DebugString()); @@ -228,7 +230,7 @@ TEST(Backward, net_fc_backward_not_have_b) { } TEST(Backward, net_input_of_network_not_need_grad) { - f::NetOp net; + ops::NetOp net; net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"mul_tmp_0", "add_tmp_0", "hidden0"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, @@ -236,7 +238,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { net.CompleteAddOp(); auto bwd = Backward(net, {"X"}); // X@GRAD is not need. ASSERT_TRUE(bwd->IsNetOp()); - auto bwd_net = static_cast(bwd.get()); + auto bwd_net = static_cast(bwd.get()); std::unordered_set all_output = std::unordered_set( bwd_net->outputs_.begin(), bwd_net->outputs_.end()); @@ -253,7 +255,7 @@ TEST(Backward, net_input_of_network_not_need_grad) { ASSERT_EQ(2UL, bwd_net->ops_.size()); ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp()); - auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); + auto first_fc_grad = static_cast(bwd_net->ops_[1].get()); ASSERT_EQ(3UL, first_fc_grad->ops_.size()); ASSERT_EQ( f::OperatorBase::EMPTY_VAR_NAME(), @@ -261,14 +263,14 @@ TEST(Backward, net_input_of_network_not_need_grad) { } TEST(Backward, net_shared_weight) { - f::NetOp net; + ops::NetOp net; net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {})); net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {})); net.CompleteAddOp(); auto bwd = f::Backward(net, {}); ASSERT_TRUE(bwd->IsNetOp()); - auto bwd_net = static_cast(bwd.get()); + auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); ASSERT_EQ("add", bwd_net->ops_[2]->type_); } @@ -285,7 +287,7 @@ TEST(Backward, op_all_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto backward = f::Backward(*fwd, {"X", "b"}); ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); + auto net = static_cast(backward.get()); ASSERT_TRUE(net->ops_.empty()); } @@ -293,7 +295,7 @@ TEST(Backward, op_all_output_are_not_need) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto backward = f::Backward(*fwd, {"Out"}); ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); + auto net = static_cast(backward.get()); ASSERT_TRUE(net->ops_.empty()); } @@ -301,7 +303,7 @@ TEST(Backward, op_part_of_output_are_not_need) { auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); auto backward = f::Backward(*fwd, {"Z"}); ASSERT_TRUE(backward->IsNetOp()); - auto net = static_cast(backward.get()); + auto net = static_cast(backward.get()); ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; @@ -341,7 +343,7 @@ TEST(Backward, op_part_of_input_are_not_need) { } TEST(Backward, linear_net_intermediate_variable_has_no_grad) { - f::NetOp net; + ops::NetOp net; net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"mul_out1", "add_out1", "out1"}, {})); net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, @@ -351,14 +353,13 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { net.CompleteAddOp(); auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"}); ASSERT_TRUE(backward->IsNetOp()); - auto bwd_net = static_cast(backward.get()); + auto bwd_net = static_cast(backward.get()); ASSERT_EQ(bwd_net->ops_.size(), 3UL); auto &grad_fc = *bwd_net->ops_[0]; EXPECT_EQ(grad_fc.inputs_.size(), 3UL /* external input number */ + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ - - 1UL /*ignoreGradient varable number*/ + 2U /* internal variable number*/); EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/ + 2UL /* input number of rowwise_add */ diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 9fcc657edcd5459d0a42a64d708603a4bcd53cf0..5aa5af0c19be5a209c760282cb1a090fc57a53ad 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -25,18 +25,15 @@ limitations under the License. */ namespace paddle { namespace framework { -namespace { -typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, - Dim<8>, Dim<9>> - DDimVar; -} - /** * \brief A dynamically sized dimension. * * The number of dimensions must be between [1, 9]. */ struct DDim { + typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, + Dim<8>, Dim<9>> + DDimVar; DDimVar var; DDim() : var(Dim<1>()) {} diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index dd686cc78246f06cdc3ec7d013086863d7e8fac0..ea5e939c6e26514c2f3c515da5581b29103f75b6 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -8,107 +8,97 @@ You may obtain a copy of the License at Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either +express or implied. See the License for the specific language governing +permissions and limitations under the License. */ #include "paddle/framework/grad_op_builder.h" +#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace framework { -OperatorBase* GradOpBuilder::Build() { - BuildOpInOutArgList(); - std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_); - OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); - grad_op->type_ = grad_op_type; - CompleteGradOp(grad_op); - return grad_op; -} +class OpRegistry; + +using VarIndexMap = std::unordered_map; -OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, - const VarIndexMap& var_map, - const std::vector& format, - InOutType type) { - int idx = var_map.at(var.name()); - int begin_idx = format.empty() ? idx : format.at(idx); - int end_idx = format.empty() ? idx + 1 : format.at(idx + 1); - return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx, - end_idx); +enum class OpArgType { IN, OUT }; + +static std::vector* GetOpFormat(OperatorBase* op, const OpArgType& type) { + std::string key = type == OpArgType::IN ? "input_format" : "output_format"; + return op->attrs_.count(key) + ? &boost::get>(op->attrs_.at(key)) + : nullptr; } -void GradOpBuilder::BuildOpInOutArgList() { - const OpProto& op_proto = OpRegistry::protos().at(op_.type_); - const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_)); - const std::vector& in_format = - op_.attrs_.count("input_format") - ? op_.GetAttr>("input_format") - : std::vector(); - const std::vector& out_format = - op_.attrs_.count("output_format") - ? op_.GetAttr>("output_format") - : std::vector(); - for (const auto& var : op_proto.inputs()) { - arg_list_.emplace_back( - std::shared_ptr(BuildArg(var, var_map, in_format, IN))); - } - for (const auto& var : op_proto.outputs()) { - arg_list_.emplace_back( - std::shared_ptr(BuildArg(var, var_map, out_format, OUT))); - } +static const std::vector* GetOpFormat(const OperatorBase* op, + const OpArgType& type) { + std::string key = type == OpArgType::IN ? "input_format" : "output_format"; + return op->attrs_.count(key) + ? &boost::get>(op->attrs_.at(key)) + : nullptr; } -void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, - std::vector& in_out, - std::vector& format, - VarIndexMap* varmap, int& idx, - bool is_grad) const { - std::string var_name = arg->proto_name_; - if (is_grad) { - var_name += OperatorBase::GRAD_VAR_SUFFIX(); - } - (*varmap)[var_name] = idx++; - size_t pre_sz = in_out.size(); - auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin(); - std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, - std::back_inserter(in_out)); - if (is_grad) { - for (size_t i = pre_sz; i < in_out.size(); ++i) { - in_out[i] += OperatorBase::GRAD_VAR_SUFFIX(); +static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op, + const OpArgType& src_type, const OpArgType& dst_type, + int& idx, bool is_grad) { + const std::vector& src_inout = + src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; + const std::vector* src_format = GetOpFormat(src_op, src_type); + + std::vector& dst_inout = + dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_; + std::vector* dst_format = GetOpFormat(dst_op, dst_type); + const OpProto& proto = OpRegistry::protos().at(src_op->type_); + const auto& src_arg_list = + src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); + + for (const auto& arg : src_arg_list) { + std::string src_name = arg.name(); + std::string dst_name = + is_grad ? src_name + OperatorBase::GRAD_VAR_SUFFIX() : src_name; + (*dst_op->in_out_idxs_)[dst_name] = idx++; + int src_arg_idx = src_op->in_out_idxs_->at(src_name); + int src_begin = + src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx); + int src_end = src_format == nullptr ? src_arg_idx + 1 + : src_format->at(src_arg_idx + 1); + for (int i = src_begin; i < src_end; ++i) { + std::string s = is_grad ? src_inout[i] + OperatorBase::GRAD_VAR_SUFFIX() + : arg.ignore_gradient() + ? OperatorBase::EMPTY_VAR_NAME() + : src_inout[i]; + dst_inout.emplace_back(s); + } + if (dst_format != nullptr) { + dst_format->push_back(dst_inout.size()); } } - format.push_back(in_out.size()); } -void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { - grad_op->attrs_ = op_.attrs_; +OperatorBase* BuildGradOp(const OperatorBase* op) { + std::string grad_op_type = OpRegistry::grad_ops().at(op->type_); + OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); + grad_op->type_ = grad_op_type; + grad_op->attrs_ = op->attrs_; grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("output_format"); - VarIndexMap* grad_varmap = new VarIndexMap(); + if (GetOpFormat(op, OpArgType::IN) != nullptr) { + grad_op->attrs_["output_format"] = std::vector({0}); + } + if (GetOpFormat(op, OpArgType::IN) != nullptr || + GetOpFormat(op, OpArgType::OUT) != nullptr) { + grad_op->attrs_["input_format"] = std::vector({0}); + } + grad_op->in_out_idxs_.reset(new VarIndexMap()); int in_idx = 0; int out_idx = 0; - std::vector in_format({0}); - std::vector out_format({0}); - for (const auto& arg : arg_list_) { - // op_'s inputs_ and outputs_ - if (arg->needed_in_grad_) { - AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap, - in_idx, false); - } - if (arg->type_ == IN) { - // gradients of op_'s inputs_ - AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap, - out_idx, true); - } else { - // gradients of op_'s outputs_ - AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap, - in_idx, true); - } - } - grad_op->attrs_["input_format"] = in_format; - grad_op->attrs_["output_format"] = out_format; - grad_op->in_out_idxs_.reset(grad_varmap); + TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false); // I + TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false); // G + TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true); // OG + TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG + return grad_op; } } // namespace framework diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h index cc7a76f3726e00a08fbe06bca4c9b9f5bad466b4..998f8ebbb5f2f4fb8b7e938b5916afd0f8a7930d 100644 --- a/paddle/framework/grad_op_builder.h +++ b/paddle/framework/grad_op_builder.h @@ -1,48 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once -#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/operator.h" namespace paddle { namespace framework { -class OpRegistry; - -enum InOutType { IN, OUT }; - -struct OpInOutArg { - OpInOutArg(const std::string& proto_name, const InOutType& type, - bool needed_in_grad, size_t begin_idx, size_t end_idx) - : proto_name_(proto_name), - type_(type), - needed_in_grad_(needed_in_grad), - begin_idx_(begin_idx), - end_idx_(end_idx) {} - - std::string proto_name_; - InOutType type_; - bool needed_in_grad_; - size_t begin_idx_; - size_t end_idx_; -}; - -class GradOpBuilder { - using VarIndexMap = std::unordered_map; - - public: - GradOpBuilder(const OperatorBase& op) : op_(op) {} - OperatorBase* Build(); - - private: - OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map, - const std::vector& format, InOutType type); - void BuildOpInOutArgList(); - void AddArgIntoGradOp(const OpInOutArg* arg, std::vector& in_out, - std::vector& format, VarIndexMap* varmap, int& idx, - bool is_grad) const; - void CompleteGradOp(OperatorBase* grad_op) const; - const OperatorBase& op_; - std::vector> arg_list_; -}; + +OperatorBase* BuildGradOp(const OperatorBase* op); } // namespace framework } // namespace paddle diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index e9cf3b9798db2cbfb8d26259ae9a6741fbae8278..96d7f309d67b15c000ab8ce3769931322fbca880 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -8,10 +8,49 @@ USE_OP(add_two); namespace paddle { namespace framework { +class NOP : public OperatorBase { + public: + void InferShape(const Scope &scope) const override {} + void Run(const Scope &scope, + const platform::DeviceContext &dev_ctx) const override {} +}; + +class MutiInOutOpMaker : public OpProtoAndCheckerMaker { + public: + MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("In1", "a single input"); + AddInput("In2_mult", "a multiple input").SetMultiple(); + AddInput("In3", "another single input"); + AddOutput("Out1", "a single output"); + AddOutput("Out2_mult", "a multiple output").SetMultiple(); + AddComment("test op with multiple inputs and outputs"); + } +}; + +class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { + public: + IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("In1", "a single input"); + AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient(); + AddInput("In3_mult", "another multiple input").SetMultiple(); + AddOutput("Out1_mult", "a multiple output").SetMultiple(); + AddOutput("Out2", "a single output").IgnoreGradient(); + AddComment("op with inputs and outputs ignored in gradient calculating"); + } +}; + +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; + TEST(GradOpBuilder, AddTwo) { - std::shared_ptr add_op( - OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); - std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(*add_op); + std::shared_ptr add_op( + f::OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); + std::shared_ptr grad_add_op = + f::OpRegistry::CreateGradOp(*add_op); EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Input("X"), "x"); @@ -22,5 +61,85 @@ TEST(GradOpBuilder, AddTwo) { EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD"); } -} // namespace framework -} // namespace paddle \ No newline at end of file +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); +REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker); +REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP); + +TEST(GradOpBuilder, MutiInOut) { + f::AttributeMap attrs{{"input_format", std::vector{0, 1, 4, 5}}, + {"output_format", std::vector{0, 1, 3}}}; + std::shared_ptr test_op(f::OpRegistry::CreateOp( + "mult_io", {"in1", "in2_1", "in2_2", "in2_3", "in3"}, + {"out1", "out2_1", "out2_2"}, attrs)); + std::shared_ptr grad_test_op = + f::OpRegistry::CreateGradOp(*test_op); + + ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + EXPECT_EQ(grad_test_op->Input("In1"), "in1"); + EXPECT_EQ(grad_test_op->Inputs("In2_mult"), + std::vector({"in2_1", "in2_2", "in2_3"})); + EXPECT_EQ(grad_test_op->Input("In3"), "in3"); + EXPECT_EQ(grad_test_op->Input("Out1"), "out1"); + EXPECT_EQ(grad_test_op->Inputs("Out2_mult"), + std::vector({"out2_1", "out2_2"})); + EXPECT_EQ(grad_test_op->Input("Out1" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out1" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ( + grad_test_op->Inputs("Out2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector( + {"out2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "out2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + + ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); + EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "in1" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ( + grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in2_3" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + EXPECT_EQ(grad_test_op->Output("In3" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "in3" + f::OperatorBase::GRAD_VAR_SUFFIX()); +} + +TEST(GradOpBuilder, IOIgnoredInGradient) { + f::AttributeMap attrs{{"input_format", std::vector{0, 1, 3, 5}}, + {"output_format", std::vector{0, 2, 3}}}; + std::shared_ptr test_op(f::OpRegistry::CreateOp( + "io_ignored", {"in1", "in2_1", "in2_2", "in3_1", "in3_2"}, + {"out1_1", "out1_2", "out2"}, attrs)); + std::shared_ptr grad_test_op = + f::OpRegistry::CreateGradOp(*test_op); + + // 'In2' and 'Out2' are ignored in gradient calculating + ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + EXPECT_EQ(grad_test_op->Input("In1"), "in1"); + EXPECT_EQ(grad_test_op->Inputs("In2_mult"), + std::vector({f::OperatorBase::EMPTY_VAR_NAME(), + f::OperatorBase::EMPTY_VAR_NAME()})); + EXPECT_EQ(grad_test_op->Inputs("In3_mult"), + std::vector({"in3_1", "in3_2"})); + EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), + std::vector({"out1_1", "out1_2"})); + EXPECT_EQ(grad_test_op->Input("Out2"), f::OperatorBase::EMPTY_VAR_NAME()); + EXPECT_EQ( + grad_test_op->Inputs("Out1_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector( + {"out1_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "out1_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + EXPECT_EQ(grad_test_op->Input("Out2" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "out2" + f::OperatorBase::GRAD_VAR_SUFFIX()); + + ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); + EXPECT_EQ(grad_test_op->Output("In1" + f::OperatorBase::GRAD_VAR_SUFFIX()), + "in1" + f::OperatorBase::GRAD_VAR_SUFFIX()); + EXPECT_EQ( + grad_test_op->Outputs("In2_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector({"in2_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in2_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); + EXPECT_EQ( + grad_test_op->Outputs("In3_mult" + f::OperatorBase::GRAD_VAR_SUFFIX()), + std::vector({"in3_1" + f::OperatorBase::GRAD_VAR_SUFFIX(), + "in3_2" + f::OperatorBase::GRAD_VAR_SUFFIX()})); +} diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f10c9297981a4c6aefc6c2072d0ac2b8e562a7a0..1af894612c526fd57b5b6f1d26d934aac27493a9 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -306,8 +306,7 @@ class OpRegistry { static std::shared_ptr CreateGradOp(const OperatorBase& op) { PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); - GradOpBuilder builder(op); - std::shared_ptr grad_op(builder.Build()); + std::shared_ptr grad_op(BuildGradOp(&op)); grad_op->Init(); return grad_op; } @@ -315,7 +314,7 @@ class OpRegistry { static std::unordered_map& protos() { static std::unordered_map protos_; return protos_; - }; + } static std::unordered_map& grad_ops() { static std::unordered_map grad_ops_; @@ -337,7 +336,7 @@ class OpRegistry { static std::unordered_map& op_checkers() { static std::unordered_map op_checkers_; return op_checkers_; - }; + } static void GenerateTempVariableName(OperatorBase* op) { static std::atomic gUniqId(0UL); @@ -354,7 +353,7 @@ class OpRegistry { template class OpRegisterHelper { public: - OpRegisterHelper(const char* op_type) { + explicit OpRegisterHelper(const char* op_type) { OpRegistry::RegisterOp(op_type); } }; @@ -400,6 +399,14 @@ class GradOpRegisterHelper { return 0; \ } +/** + * Macro to Forbid user register Gradient Operator. + */ +#define NO_GRADIENT(__op_type) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_gradient_op__##__op_type##__op_type##_grad, \ + "NO_GRADIENT must be in global namespace") + /** * Macro to Register OperatorKernel. */ diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 55435103489ace11868eed61c38018d8ba357e65..fbf9113e5677080e6573ba3ff1e1deb36a2889e9 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -55,6 +55,10 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + static std::string GRAD_VAR_NAME(const std::string& name) { + return name + GRAD_VAR_SUFFIX(); + } + /// Variables with this suffix are supposed to be filled up with zeros. static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } @@ -280,7 +284,7 @@ class OperatorWithKernel : public OperatorBase { platform::Place place_; OpKernelKey() = default; - OpKernelKey(const platform::DeviceContext& dev_ctx) { + explicit OpKernelKey(const platform::DeviceContext& dev_ctx) { place_ = dev_ctx.GetPlace(); } diff --git a/paddle/pybind/pybind.cc b/paddle/framework/pybind.cc similarity index 59% rename from paddle/pybind/pybind.cc rename to paddle/framework/pybind.cc index 40ff164497f627c0b562b6d33bfb4bec590e4c85..b9889e483e27e9dad3310a34b5306f073ad1887d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/framework/pybind.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -17,19 +17,19 @@ limitations under the License. */ #include #include "paddle/framework/backward.h" -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/framework/tensor_py.h" +#include "paddle/operators/net_op.h" +#include "paddle/operators/type_alias.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" -#include "paddle/pybind/tensor_bind.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" namespace py = pybind11; -namespace pd = paddle::framework; USE_OP(add_two); USE_OP(onehot_cross_entropy); @@ -41,17 +41,18 @@ USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); USE_OP_WITHOUT_KERNEL(recurrent_op); - +namespace paddle { +namespace framework { template -void ExposeOperator(ClassType& m) { +void ExposeOperator(ClassType &m) { m.def("infer_shape", &ClassType::type::InferShape) .def("run", &ClassType::type::Run) .def("type", - [](const typename ClassType::type& op) -> std::string { + [](const typename ClassType::type &op) -> std::string { return op.type_; }) .def("outputs", - [](const typename ClassType::type& op) -> std::vector { + [](const typename ClassType::type &op) -> std::vector { return op.outputs_; }) .def("__str__", &ClassType::type::DebugString); @@ -73,80 +74,81 @@ bool IsCompileGPU() { PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of PaddlePaddle"); - py::class_(m, "Tensor", py::buffer_protocol()) - .def_buffer([](pd::Tensor& self) -> py::buffer_info { - return paddle::pybind::CastToPyBuffer(self); - }) + py::class_(m, "Tensor", py::buffer_protocol()) + .def_buffer( + [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) .def("get_dims", - [](const pd::Tensor& self) { return pd::vectorize(self.dims()); }) + [](const Tensor &self) { return vectorize(self.dims()); }) .def("set_dims", - [](pd::Tensor& self, const std::vector& dim) { - self.Resize(pd::make_ddim(dim)); + [](Tensor &self, const std::vector &dim) { + self.Resize(make_ddim(dim)); }) .def("alloc_float", - [](pd::Tensor& self, paddle::platform::GPUPlace& place) { + [](Tensor &self, paddle::platform::GPUPlace &place) { self.mutable_data(place); }) .def("alloc_float", - [](pd::Tensor& self, paddle::platform::CPUPlace& place) { + [](Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); }) .def("alloc_int", - [](pd::Tensor& self, paddle::platform::CPUPlace& place) { + [](Tensor &self, paddle::platform::CPUPlace &place) { self.mutable_data(place); }) .def("alloc_int", - [](pd::Tensor& self, paddle::platform::GPUPlace& place) { + [](Tensor &self, paddle::platform::GPUPlace &place) { self.mutable_data(place); }) - .def("set", paddle::pybind::PyCPUTensorSetFromArray) - .def("set", paddle::pybind::PyCPUTensorSetFromArray) + .def("set", PyCPUTensorSetFromArray) + .def("set", PyCPUTensorSetFromArray) #ifndef PADDLE_ONLY_CPU - .def("set", paddle::pybind::PyCUDATensorSetFromArray) - .def("set", paddle::pybind::PyCUDATensorSetFromArray) + .def("set", PyCUDATensorSetFromArray) + .def("set", PyCUDATensorSetFromArray) #endif - .def("shape", - [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); + .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) + .def("set_float_element", + [](Tensor &self, size_t offset, float f) { + // TODO(yuyang18): Only support GPU now. + self.data()[offset] = f; + }) + .def("get_float_element", [](Tensor &self, size_t offset) -> float { + // TODO(yuyang18): Only support GPU now. + return self.data()[offset]; + }); - py::class_(m, "Variable", R"DOC(Variable Class. + py::class_(m, "Variable", R"DOC(Variable Class. All parameter, weight, gradient are variables in Paddle. )DOC") - .def("is_int", [](const pd::Variable& var) { return var.IsType(); }) + .def("is_int", [](const Variable &var) { return var.IsType(); }) .def("set_int", - [](pd::Variable& var, int val) -> void { - *var.GetMutable() = val; - }) - .def("get_int", - [](const pd::Variable& var) -> int { return var.Get(); }) + [](Variable &var, int val) -> void { *var.GetMutable() = val; }) + .def("get_int", [](const Variable &var) -> int { return var.Get(); }) .def("get_tensor", - [](pd::Variable& self) -> pd::Tensor* { - return self.GetMutable(); - }, + [](Variable &self) -> Tensor * { return self.GetMutable(); }, py::return_value_policy::reference) .def("get_net", - [](pd::Variable& self) -> pd::NetOp* { - return self.GetMutable(); + [](Variable &self) -> ops::NetOp * { + return self.GetMutable(); }, py::return_value_policy::reference); - py::class_(m, "Scope", "") + py::class_(m, "Scope", "") .def("new_var", - [](pd::Scope& self, const std::string& name) -> pd::Variable* { + [](Scope &self, const std::string &name) -> Variable * { return self.NewVar(name); }, py::return_value_policy::reference) - .def("find_var", &pd::Scope::FindVar, py::return_value_policy::reference) + .def("find_var", &Scope::FindVar, py::return_value_policy::reference) .def(py::init<>()) - .def("new_scope", - [](pd::Scope& self) -> pd::Scope* { return &self.NewScope(); }, + .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); }, py::return_value_policy::reference) - .def("drop_kids", &pd::Scope::DropKids); + .def("drop_kids", &Scope::DropKids); //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto& protos = pd::OpRegistry::protos(); + auto &protos = OpRegistry::protos(); std::vector ret_values; for (auto it = protos.begin(); it != protos.end(); ++it) { PADDLE_ENFORCE(it->second.IsInitialized(), @@ -161,8 +163,8 @@ All parameter, weight, gradient are variables in Paddle. m.def_submodule( "var_names", "The module will return special predefined variable name in Paddle") - .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) - .def("temp", pd::OperatorBase::TMP_VAR_NAME); + .def("empty", OperatorBase::EMPTY_VAR_NAME) + .def("temp", OperatorBase::TMP_VAR_NAME); // clang-format off py::class_(m, "DeviceContext") .def_static("create", @@ -185,43 +187,45 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "CPUPlace").def(py::init<>()); - py::class_> operator_base( + py::class_> operator_base( m, "Operator"); operator_base.def_static("create", [](py::bytes protobin) { - pd::OpDesc desc; + OpDesc desc; PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), "Cannot parse user input to OpDesc"); PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); + return OpRegistry::CreateOp(desc); }); operator_base.def("backward", - [](const pd::OperatorBase& forwardOp, - const std::unordered_set& no_grad_vars) { - return pd::Backward(forwardOp, no_grad_vars); + [](const OperatorBase &forwardOp, + const std::unordered_set &no_grad_vars) { + return Backward(forwardOp, no_grad_vars); }); ExposeOperator(operator_base); - py::class_> net(m, "Net"); + py::class_> net(m, "Net"); net.def_static("create", - []() -> std::shared_ptr { - auto retv = std::make_shared(); + []() -> std::shared_ptr { + auto retv = std::make_shared(); retv->type_ = "plain_net"; return retv; }) - .def("add_op", &pd::NetOp::AddOp) - .def("add_op", - [](pd::NetOp& self, const std::shared_ptr& net) -> void { - self.AddOp(std::static_pointer_cast(net)); - }) - .def("complete_add_op", &pd::NetOp::CompleteAddOp) + .def("add_op", &ops::NetOp::AddOp) + .def( + "add_op", + [](ops::NetOp &self, const std::shared_ptr &net) -> void { + self.AddOp(std::static_pointer_cast(net)); + }) + .def("complete_add_op", &ops::NetOp::CompleteAddOp) .def("complete_add_op", - [](std::shared_ptr& self) { self->CompleteAddOp(); }); + [](std::shared_ptr &self) { self->CompleteAddOp(); }); + ExposeOperator(net); m.def("unique_integer", UniqueIntegerGenerator); @@ -230,3 +234,5 @@ All parameter, weight, gradient are variables in Paddle. return m.ptr(); } +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 76070f636b0971f4a136042e056c59adb5dc2d40..4c3b14b83d841e88683a13634c93f51c012128b6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -26,19 +26,17 @@ limitations under the License. */ #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { -namespace pybind { -namespace details { // forward declare -template -struct CastToPyBufferImpl; -} // namespace details -} // namespace pybind namespace framework { +namespace details { +template +struct CastToPyBufferImpl; +} class Tensor { public: template - friend struct paddle::pybind::details::CastToPyBufferImpl; + friend struct details::CastToPyBufferImpl; template friend struct EigenTensor; @@ -167,4 +165,4 @@ class Tensor { } // namespace framework } // namespace paddle -#include "paddle/framework/detail/tensor-inl.h" +#include "paddle/framework/tensor_impl.h" diff --git a/paddle/framework/detail/tensor-inl.h b/paddle/framework/tensor_impl.h similarity index 100% rename from paddle/framework/detail/tensor-inl.h rename to paddle/framework/tensor_impl.h diff --git a/paddle/pybind/tensor_bind.h b/paddle/framework/tensor_py.h similarity index 92% rename from paddle/pybind/tensor_bind.h rename to paddle/framework/tensor_py.h index def37219ccefd5435f1212c4e4daac5a351d76f4..4e1ab77b157fe1adaeac55c271c056236f2d40de 100644 --- a/paddle/pybind/tensor_bind.h +++ b/paddle/framework/tensor_py.h @@ -23,7 +23,7 @@ namespace py = pybind11; namespace paddle { -namespace pybind { +namespace framework { namespace details { @@ -63,11 +63,8 @@ struct CastToPyBufferImpl { } return py::buffer_info( dst_tensor.mutable_data(dst_tensor.holder_->place()), - sizeof(CUR_TYPE), - py::format_descriptor::format(), - (size_t)framework::arity(dst_tensor.dims()), - dims_outside, - strides); + sizeof(CUR_TYPE), py::format_descriptor::format(), + (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); @@ -110,8 +107,8 @@ void PyCUDATensorSetFromArray( self.Resize(framework::make_ddim(dims)); auto *dst = self.mutable_data(place); - paddle::platform::GpuMemcpySync( - dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice); + paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(), + cudaMemcpyHostToDevice); } #endif diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index bb4f48364b9b454af7d37fe4d3c340666e53285c..baf78bc6c88d0d294f4457b81c52b22e425d9fdb 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -109,6 +109,13 @@ protected: return filter[filter.ndims() - 1]; } + // determine whether im2col needs to be performed + inline bool isNeedIm2col(const TensorShape& filter) const { + return !(getFilterHeight(filter) == 1 && getFilterWidth(filter) == 1 && + strideH() == 1 && strideW() == 1 && paddingH() == 0 && + paddingW() == 0); + } + std::vector strides_; std::vector paddings_; diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 9deb2739fcfff935a98a0b5b31b5d11819d81227..0ada4d70a0c7d13f9b5fb1a42eac07fc4c775a87 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -66,16 +66,23 @@ public: real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -86,15 +93,18 @@ public: for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int N = outputHeight * outputWidth; int K = inputChannels / groups_ * filterHeight * filterWidth; @@ -159,19 +169,27 @@ public: real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); real* inputGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Col2ImFunctor col2im; GemmFunctor gemm; + size_t inputOffset = imShape.getElements(); size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -182,6 +200,11 @@ public: int K = outputChannels / groups_; int N = outputHeight * outputWidth; int M = inputChannels / groups_ * filterHeight * filterWidth; + real scale = 0.0f; + if (!needIm2col) { + colData = inputGrad + g * inputOffset; + scale = 1.0f; + } gemm(CblasTrans, CblasNoTrans, M, @@ -192,17 +215,19 @@ public: M, outputGrad + g * outputOffset, N, - 0.0f, + scale, colData, N); - col2im(inputGrad + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); + if (needIm2col) { + col2im(inputGrad + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } } inputGrad += inputChannels * inputHeight * inputWidth; outputGrad += outputChannels * outputHeight * outputWidth; @@ -255,16 +280,23 @@ public: real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + TensorShape imShape = TensorShape({inputChannels / groups_, inputHeight, inputWidth}); - TensorShape colShape = TensorShape({inputChannels / groups_, - filterHeight, - filterWidth, - outputHeight, - outputWidth}); - resizeBuffer(colShape.getElements()); - real* colData = reinterpret_cast(memory_->getBuf()); + TensorShape colShape; + real* colData = NULL; + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + resizeBuffer(colShape.getElements()); + colData = reinterpret_cast(memory_->getBuf()); + } Im2ColFunctor im2col; GemmFunctor gemm; @@ -274,15 +306,18 @@ public: size_t filterOffset = filter.getElements() / groups_; for (size_t i = 0; i < batchSize; i++) { for (size_t g = 0; g < groups_; g++) { - im2col(inputData + g * inputOffset, - imShape, - colData, - colShape, - strideH(), - strideW(), - paddingH(), - paddingW()); - + if (needIm2col) { + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW()); + } else { + colData = inputData + g * inputOffset; + } int M = outputChannels / groups_; int K = outputHeight * outputWidth; int N = inputChannels / groups_ * filterHeight * filterWidth; diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index 9ddd449de7500f5682d59469328f06971c6e83bf..f98bf95064fa539b990309dfe0bff10c1e99d096 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -967,8 +967,9 @@ void RecurrentGradientMachine::generateSequence() { size_t numSequences = getGenBatchSize(); resizeBootFrame(numSequences); - // We create only two sub-network in generation for alternate use. - // Thus, we can reduce total memory of output_ in layer forward. + // We create only two sub-network in generation, one stores states of all + // layers in previous time step and the other storing the states at current + // time step. resizeOrCreateFrames(2); // outFrameLines_.size() > 1UL @@ -1001,10 +1002,9 @@ void RecurrentGradientMachine::generateSequence() { // init outArg size_t resultNum = generator_.config.num_results_per_sample(); - IVector::resizeOrCreate( - generator_.outArg.ids, - generator_.config.max_num_frames() * numSequences * resultNum, - false); + size_t maxGenWordCount = + generator_.config.max_num_frames() * numSequences * resultNum; + IVector::resizeOrCreate(generator_.outArg.ids, maxGenWordCount, false); if (resultNum > 1) { CHECK_LE(resultNum, static_cast(generator_.config.beam_size())); Matrix::resizeOrCreate(generator_.outArg.in, @@ -1012,6 +1012,11 @@ void RecurrentGradientMachine::generateSequence() { /* width */ resultNum, false, /* useGpu */ false); + Matrix::resizeOrCreate(generator_.outArg.value, + /* height */ maxGenWordCount, + /* width */ 1, + false, + /* useGpu */ false); } ICpuGpuVector::resizeOrCreate(generator_.outArg.sequenceStartPositions, numSequences + 1, @@ -1313,13 +1318,20 @@ void RecurrentGradientMachine::fillGenOutputs() { starts[0] = 0; if (numResults > 1) { real* probs = generator_.outArg.in->getData(); + real* idsProb = generator_.outArg.value->getData(); + size_t curPos = 0; for (size_t i = 0; i < finalPaths_.size(); ++i) { for (size_t j = 0; j < finalPaths_[i].size(); ++j) { Path& path = finalPaths_[i][j]; - generator_.ids.push_back(path.ids.size()); // sequence size + size_t genLen = path.ids.size(); + generator_.ids.push_back(genLen); // sequence size generator_.ids.insert( generator_.ids.end(), path.ids.begin(), path.ids.end()); generator_.ids.push_back(-1); // end of sequence + + memcpy(idsProb + curPos, path.idsProb.data(), sizeof(real) * genLen); + curPos += genLen; + idsProb[curPos++] = -1.0; probs[i * numResults + j] = path.logProb; if (!j && dataArgsSize_) { diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h index f245620cf668bb341df99cf498105cbd996a6b24..fb3fc5877ac96323e891f800db80af83b6809831 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h @@ -189,6 +189,11 @@ public: */ std::vector ids; + /** + * @brief idsProb, log probability of each generated words. + */ + std::vector idsProb; + /** * @brief logProb, current probability of path. */ @@ -228,11 +233,13 @@ public: */ Path(Path& old, int newId, real logProb, int machineId, int topIndex) : ids(old.ids), + idsProb(old.idsProb), logProb(old.logProb + logProb), machineId(machineId), topIndex(topIndex), seqId(old.seqId) { ids.push_back(newId); + idsProb.push_back(logProb); if (!old.probHistory.empty()) { this->probHistory = old.probHistory; // probHistory store current prob, not sum @@ -411,8 +418,9 @@ protected: struct Generator { GeneratorConfig config; - std::vector ids; // store generated sequences - Argument outArg; // final output argument + std::vector ids; // store generated sequences + std::vector idsProb; // log probability of each generated word + Argument outArg; // final output argument }; bool generating_; Generator generator_; diff --git a/paddle/gserver/layers/KmaxSeqScoreLayer.cpp b/paddle/gserver/layers/KmaxSeqScoreLayer.cpp index d747db9b4a72065acd7eadf3614e8ae0b2fa91d5..8ce591d4762466e1ed4b2970cb9cae9203bc0a2b 100644 --- a/paddle/gserver/layers/KmaxSeqScoreLayer.cpp +++ b/paddle/gserver/layers/KmaxSeqScoreLayer.cpp @@ -39,12 +39,13 @@ REGISTER_LAYER(kmax_seq_score, KmaxSeqScoreLayer); bool KmaxSeqScoreLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { bool ret = Layer::init(layerMap, parameterMap); - CHECK_EQ(1UL, inputLayers_.size()); + CHECK_EQ(1U, inputLayers_.size()); beamSize_ = config_.beam_size(); - CHECK_GE(beamSize_, 1LU); + CHECK_GE(beamSize_, 1U); setNeedSequenceInfo(false); + setNeedGradient(false); return ret; } @@ -96,18 +97,19 @@ void KmaxSeqScoreLayer::forward(PassType passType) { scores_ = inputScore; } - MatrixPtr outputValue = getOutputValue(); Matrix::resizeOrCreate( - outputValue, + output_.value, input.hasSubseq() ? input.getNumSubSequences() : input.getNumSequences(), - beamSize_); - outputValue->one(); - outputValue->mulScalar(-1.); + beamSize_, + false, + false); + output_.value->one(); + output_.value->mulScalar(-1.); kmaxScorePerSeq(scores_->getData(), output_.value->getData(), - input.hasSeq() ? input.subSequenceStartPositions - : input.sequenceStartPositions); + input.hasSubseq() ? input.subSequenceStartPositions + : input.sequenceStartPositions); } void KmaxSeqScoreLayer::backward(const UpdateCallback& callback) {} diff --git a/paddle/gserver/tests/LayerGradUtil.cpp b/paddle/gserver/tests/LayerGradUtil.cpp index 9eca58f1a1baa6fb1c404a91a345bc7f9d6b4acc..fd9cfa1dc7a9028cb2c5c98baca98ffb2a837bac 100644 --- a/paddle/gserver/tests/LayerGradUtil.cpp +++ b/paddle/gserver/tests/LayerGradUtil.cpp @@ -400,7 +400,6 @@ void initDataLayer(TestConfig testConf, const std::vector& labelSeqStartPositions = testConf.inputDefs[i].labelSeqStartPositions; if (labelSeqStartPositions.size() != 0) { - CHECK(!sequenceStartPositions); CHECK_GE(static_cast(labelSeqStartPositions.size()), 2); sequenceStartPositions = @@ -410,6 +409,19 @@ void initDataLayer(TestConfig testConf, useGpu); data.sequenceStartPositions = sequenceStartPositions; } + + const std::vector& labelSubSeqStartPositions = + testConf.inputDefs[i].labelSubSeqStartPositions; + if (labelSubSeqStartPositions.size() != 0) { + CHECK_GE(static_cast(labelSubSeqStartPositions.size()), 2); + + subSequenceStartPositions = + ICpuGpuVector::create(labelSubSeqStartPositions.size(), useGpu); + subSequenceStartPositions->copyFrom(labelSubSeqStartPositions.data(), + labelSubSeqStartPositions.size(), + useGpu); + data.subSequenceStartPositions = subSequenceStartPositions; + } break; } default: diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index d299b4dd09418589514d99a72f83e1103ace7de1..5debedf5ef6a3262578ca01b335e664f9a334d35 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -67,6 +67,7 @@ struct InputDef { bool isStatic; std::vector labelInitValue; std::vector labelSeqStartPositions; + std::vector labelSubSeqStartPositions; MatrixPtr selfDefinedData; InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { @@ -81,8 +82,10 @@ struct InputDef { InputDef(InputType type, string nameIn, MatrixPtr selfDefinedData, - std::vector selfDefinedSeqStartPos = {}) + std::vector selfDefinedSeqStartPos = {}, + std::vector selfDefinedSubSeqStartPos = {}) : labelSeqStartPositions(selfDefinedSeqStartPos), + labelSubSeqStartPositions(selfDefinedSubSeqStartPos), selfDefinedData(selfDefinedData) { inputType = type; name = nameIn; diff --git a/paddle/gserver/tests/test_KmaxSeqScore.cpp b/paddle/gserver/tests/test_KmaxSeqScore.cpp index e3530977c6fb99d62b5ec0f134a596e844cbd61c..f958b4974d45ef65f8f374148a31ad3a6ce7632f 100644 --- a/paddle/gserver/tests/test_KmaxSeqScore.cpp +++ b/paddle/gserver/tests/test_KmaxSeqScore.cpp @@ -32,7 +32,6 @@ DECLARE_int32(gpu_id); DECLARE_bool(thread_local_rand_use_global_seed); vector randSampling(int range, int n) { - srand(1); CHECK_GE(range, n); vector num(range); iota(begin(num), end(num), 0); @@ -45,7 +44,7 @@ vector randSampling(int range, int n) { void genRandomSeqInfo(vector& seqStartPosition, vector& subSeqStartPosition) { - const int maxSeqNum = 5; + const int maxSeqNum = 100; // generate random start position information int seqNum = 1 + (rand() % maxSeqNum); seqStartPosition.resize(seqNum + 1, 0); @@ -61,78 +60,93 @@ void genRandomSeqInfo(vector& seqStartPosition, void genRandomGroundTruth(real* values, vector>& groundTruth, - vector& seqStartPosition, - vector& subSeqStartPosition, - bool useSubseqInfo, + vector& startPos, size_t beamSize) { - auto genData = [&](real* values, vector& startPos, size_t beamSize) { - groundTruth.resize(startPos.size() - 1, vector(beamSize, -1)); - - for (size_t i = 0; i < startPos.size() - 1; ++i) { - int seqLen = startPos[i + 1] - startPos[i]; - vector pos = - randSampling(seqLen, min(static_cast(beamSize), seqLen)); - for (size_t j = 0; j < pos.size(); ++j) { - groundTruth[i][j] = pos[j]; - values[subSeqStartPosition[i] + pos[j]] = 1.; - } + groundTruth.resize(startPos.size() - 1, vector(beamSize, -1)); + for (size_t i = 0; i < startPos.size() - 1; ++i) { + int seqLen = startPos[i + 1] - startPos[i]; + vector pos = + randSampling(seqLen, min(static_cast(beamSize), seqLen)); + for (size_t j = 0; j < pos.size(); ++j) { + groundTruth[i][j] = pos[j]; + values[startPos[i] + pos[j]] = 1.; } - }; + } +} - if (useSubseqInfo) - genData(values, subSeqStartPosition, beamSize); - else - genData(values, seqStartPosition, beamSize); +void checkLayerOut(vector> groundTruth, + real* layerOut, + size_t beamSize) { + for (size_t i = 0; i < groundTruth.size(); ++i) { + int begPos = i * beamSize; + vector tmp(layerOut + begPos, layerOut + begPos + beamSize); + sort(begin(tmp), end(tmp)); + sort(begin(groundTruth[i]), end(groundTruth[i])); + for (size_t j = 0; j < beamSize; ++j) CHECK_EQ(tmp[j], groundTruth[i][j]); + } } -// Test that the batchNormLayer can be followed by a ConvLayer TEST(Layer, kmaxSeqScoreLayer) { - const size_t beamSize = 5; + const size_t maxBeamSize = 100; + int beamSize = 1 + (rand() % maxBeamSize); vector seqStartPosition; vector subSeqStartPosition; genRandomSeqInfo(seqStartPosition, subSeqStartPosition); MatrixPtr inValue = Matrix::create(subSeqStartPosition.back(), 1, false, false); - inValue->randomizeUniform(); for (auto hasSubseq : {false, true}) { vector> groundTruth; + inValue->randomizeUniform(); genRandomGroundTruth(inValue->getData(), groundTruth, - seqStartPosition, - subSeqStartPosition, - hasSubseq, + hasSubseq ? subSeqStartPosition : seqStartPosition, beamSize); for (auto useGpu : {false, true}) { TestConfig config; config.layerConfig.set_type("kmax_seq_score"); config.layerConfig.set_beam_size(beamSize); - config.inputDefs.push_back( - {hasSubseq ? INPUT_HASSUB_SEQUENCE_DATA : INPUT_SEQUENCE_DATA, - "layer_0", - 1, - 0}); + + if (hasSubseq) { + config.inputDefs.push_back({INPUT_SELF_DEFINE_DATA, + "scores", + inValue, + seqStartPosition, + subSeqStartPosition}); + } else { + config.inputDefs.push_back( + {INPUT_SELF_DEFINE_DATA, "scores", inValue, seqStartPosition}); + } config.layerConfig.add_inputs(); // data layer initialize std::vector dataLayers; LayerMap layerMap; vector datas; - initDataLayer(config, - &dataLayers, - &datas, - &layerMap, - "kmax_seq_score", - 100, - false, - useGpu); + initDataLayer( + config, + &dataLayers, + &datas, + &layerMap, + "kmax_seq_score", + 100 /* actually this parameter is unused in self-defined input*/, + false, + useGpu); // test layer initialize std::vector parameters; LayerPtr kmaxSeqScoreLayer; + FLAGS_use_gpu = useGpu; initTestLayer(config, &layerMap, ¶meters, &kmaxSeqScoreLayer); kmaxSeqScoreLayer->forward(PASS_TRAIN); + + const MatrixPtr outValue = kmaxSeqScoreLayer->getOutputValue(); + CHECK_EQ(outValue->getHeight(), + hasSubseq ? subSeqStartPosition.size() - 1 + : seqStartPosition.size() - 1); + CHECK_EQ(outValue->getWidth(), beamSize); + checkLayerOut(groundTruth, outValue->getData(), beamSize); } } } @@ -141,6 +155,6 @@ int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); FLAGS_thread_local_rand_use_global_seed = true; - srand(1); + srand((size_t)(time(NULL))); return RUN_ALL_TESTS(); } diff --git a/paddle/math/BaseMatrix.cu b/paddle/math/BaseMatrix.cu index 6db5965789b3750f46731f157167150583130d0a..ba2b47d6cc6961a380b7db2781b4d214dea829db 100644 --- a/paddle/math/BaseMatrix.cu +++ b/paddle/math/BaseMatrix.cu @@ -442,7 +442,8 @@ DEFINE_MATRIX_UNARY_PARAMETER_OP(Clip, TWO_PARAMETER, template void BaseMatrixT::clip(T p1, T p2) { applyUnary(unary::Clip(p1, p2)); } -DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, a = b < p1 ? 0 : (b > p2 ? 0 : 1)); +DEFINE_MATRIX_BINARY_PARAMETER_OP(ClipDerivative, TWO_PARAMETER, + a = b < p1 ? 0 : (b > p2 ? 0 : 1)); template void BaseMatrixT::clipDerivative(BaseMatrixT& b, T p1, T p2) { applyBinary(binary::ClipDerivative(p1, p2), b); diff --git a/paddle/memory/detail/buddy_allocator.h b/paddle/memory/detail/buddy_allocator.h index 4fa3fb0ee5f826d2b084c0ba184c505aee3acc48..9c41378483993101a098fc4ad1068c1ef908e566 100644 --- a/paddle/memory/detail/buddy_allocator.h +++ b/paddle/memory/detail/buddy_allocator.h @@ -39,7 +39,7 @@ class BuddyAllocator { public: void* Alloc(size_t unaligned_size); - void Free(void*); + void Free(void* ptr); size_t Used(); public: diff --git a/paddle/memory/detail/meta_cache.h b/paddle/memory/detail/meta_cache.h index ca0789779e273fb71c3d6282c0a921cda2d776cc..cf5815644284c23a1d2abc904f8c5053ce107a72 100644 --- a/paddle/memory/detail/meta_cache.h +++ b/paddle/memory/detail/meta_cache.h @@ -33,17 +33,17 @@ namespace detail { */ class MetadataCache { public: - MetadataCache(bool uses_gpu); + explicit MetadataCache(bool uses_gpu); public: /*! \brief Load the associated metadata for the specified memory block. */ - Metadata load(const MemoryBlock*); + Metadata load(const MemoryBlock* memory_block); /*! \brief Store the associated metadata for the specified memory block. */ - void store(MemoryBlock*, const Metadata&); + void store(MemoryBlock* memory_block, const Metadata& meta_data); /*! \brief Indicate that the specified metadata will no longer be used. */ - void invalidate(MemoryBlock*); + void invalidate(MemoryBlock* memory_block); public: MetadataCache(const MetadataCache&) = delete; diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 44f567caf9c19775f17988b5142b7693b41a126d..72351b9dfa63513713463bb47a3684f0dfd84ad3 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -68,7 +68,7 @@ class PODDeleter { static_assert(std::is_pod::value, "T must be POD"); public: - PODDeleter(Place place) : place_(place) {} + explicit PODDeleter(Place place) : place_(place) {} void operator()(T* ptr) { Free(place_, static_cast(ptr)); } private: diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index b910bee836ed488aeb34f28d0503b5efba396583..96c76e22e9814682008f2e6c7ae98e2599d391c2 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,6 +41,9 @@ function(op_library TARGET) endif() endfunction() +cc_library(net_op SRCS net_op.cc DEPS op_registry) +cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) + op_library(add_op SRCS add_op.cc add_op.cu) cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) @@ -59,11 +62,6 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(fc_op SRCS fc_op.cc - DEPS mul_op rowwise_add_op sigmoid_op softmax_op net) - -op_library(recurrent_network_op - SRCS recurrent_network_op.cc - DEPS op_desc tensor net) -cc_test(recurrent_network_op_test - SRCS recurrent_network_op_test.cc - DEPS recurrent_network_op mul_op add_op) + DEPS mul_op rowwise_add_op sigmoid_op softmax_op net_op) +op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net_op) +cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index f961b37565f400b5c26844b9e7a3cff5e682340b..9bd08634da96c5595d6dd702ad9afafb94632b03 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/add_op.h" diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index a4ee407cae1ec4974ae5addd8626686d908eaf32..54d2231425293f6cfb3adc9cb34d903a75fcdcd0 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -28,9 +28,13 @@ public: output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device(context.GetEigenDevice()) = - framework::EigenVector::Flatten(*input0) + - framework::EigenVector::Flatten(*input1); + auto X = EigenVector::Flatten(*input0); + auto Y = EigenVector::Flatten(*input1); + auto Z = EigenVector::Flatten(*output); + + auto place = context.GetEigenDevice(); + + Z.device(place) = X + Y; } }; diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 926a0c616b957d8e542c1f3dee227a718fb29f07..2f453f8379ca7ce0612fed757719acb2d2cf0ad8 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -1,5 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); \ No newline at end of file + ops::OnehotCrossEntropyOpKernel); diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu index 55ad58f4f17cd4a3e737c01b001675d2690d273e..ed1068219c8fee8c6e8809f450a9d38c8226f317 100644 --- a/paddle/operators/fill_zeros_like_op.cu +++ b/paddle/operators/fill_zeros_like_op.cu @@ -1,6 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #include "paddle/framework/op_registry.h" #include "paddle/operators/fill_zeros_like_op.h" REGISTER_OP_GPU_KERNEL( fill_zeros_like, - paddle::operators::FillZerosLikeKernel); \ No newline at end of file + paddle::operators::FillZerosLikeKernel); diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index fe34d6ad4015620cac520146850e10563d4c50e0..78131b26808b183ee107313374493ae870f1b641 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -33,13 +33,23 @@ public: MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op"); + AddOutput("Out", "The output of mean op").IgnoreGradient(); AddComment("Mean Operator"); } }; +class MeanGradOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + ctx.Output("X" + GRAD_VAR_SUFFIX()) + ->Resize(ctx.Input("X")->dims()); + } +}; + } // namespace operators } // namespace paddle REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); +REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); +REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mean_op.cu b/paddle/operators/mean_op.cu index 740157cbc57a64cafcf109186c630691620f542b..8b97b0154ccdc8c41a90f7580af829c5c8663b60 100644 --- a/paddle/operators/mean_op.cu +++ b/paddle/operators/mean_op.cu @@ -1,5 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/mean_op.h" REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel); +REGISTER_OP_GPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index 20f21105298a20b2083c35e3a5d8b8049b0752ca..e712dee6a785749e51be7b233e85dbf39c835218 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -27,8 +27,28 @@ public: output->mutable_data(context.GetPlace()); - EigenScalar::From(*output).device(context.GetEigenDevice()) = - EigenVector::Flatten(*input).mean(); + auto X = EigenVector::Flatten(*input); + auto y = EigenScalar::From(*output); + auto place = context.GetEigenDevice(); + + y.device(place) = X.mean(); + } +}; + +template +class MeanGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto OG = context.Input("Out" + OperatorBase::GRAD_VAR_SUFFIX()); + PADDLE_ENFORCE(framework::product(OG->dims()) == 1, + "Mean Gradient should be scalar"); + auto IG = context.Output("X" + OperatorBase::GRAD_VAR_SUFFIX()); + IG->mutable_data(context.GetPlace()); + + T ig_size = (T)framework::product(IG->dims()); + + EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = + EigenScalar::From(*OG) / ig_size; } }; diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index dc9236701627dc9335b844d2a82e18eb1f7dfd42..1dc04c4297daed7a7861a09cf6b99446c296ffa5 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -15,4 +15,4 @@ #define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" -REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); \ No newline at end of file +REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 1d0617ab8bd44297ee112c58542cf6487118e801..c7b78ad39045d25d73bfc2c930063c255a514864 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -26,13 +26,18 @@ public: Eigen::array, 1> dim_pair = { {Eigen::IndexPair(1, 0)}}; + auto input0 = context.Input("X"); + auto input1 = context.Input("Y"); auto output = context.Output(0); + output->mutable_data(context.GetPlace()); - EigenMatrix::From(*output).device(context.GetEigenDevice()) = - EigenMatrix::From(*context.Input("X")) - .contract(EigenMatrix::From(*context.Input("Y")), - dim_pair); + auto X = EigenMatrix::From(*input0); + auto Y = EigenMatrix::From(*input1); + auto Z = EigenMatrix::From(*output); + auto place = context.GetEigenDevice(); + + Z.device(place) = X.contract(Y, dim_pair); } }; } // namespace operators diff --git a/paddle/framework/net.cc b/paddle/operators/net_op.cc similarity index 96% rename from paddle/framework/net.cc rename to paddle/operators/net_op.cc index 2cd378c6b21303d1a24206ba3010b0d035aaa766..fbc98e09923bda7f3baee04e02df9076247bff0b 100644 --- a/paddle/framework/net.cc +++ b/paddle/operators/net_op.cc @@ -14,11 +14,11 @@ limitations under the License. */ -#include "paddle/framework/net.h" +#include "paddle/operators/net_op.h" #include "paddle/framework/op_registry.h" namespace paddle { -namespace framework { +namespace operators { void NetOp::CompleteAddOp(bool calc) { add_op_done_ = true; @@ -74,5 +74,5 @@ std::string NetOp::DebugString() const { bool NetOp::IsNetOp() const { return true; } -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/operators/net_op.h similarity index 89% rename from paddle/framework/net.h rename to paddle/operators/net_op.h index acf1a69da9fd8adce1bd89367c882eade052e725..13611e1ee83170db43e17d6088e4b04588ce6255 100644 --- a/paddle/framework/net.h +++ b/paddle/operators/net_op.h @@ -14,15 +14,17 @@ limitations under the License. */ #pragma once -#include -#include +#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" #include "paddle/framework/scope.h" +#include "paddle/operators/type_alias.h" #include "paddle/platform/device_context.h" namespace paddle { -namespace framework { +namespace operators { + /** * @brief Network is also a type of Operator * @@ -37,13 +39,13 @@ namespace framework { * This is the base class of network, all the networks should implement the APIs * it defines. */ -class NetOp : public OperatorBase { - public: +class NetOp : public framework::OperatorBase { +public: /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ - void InferShape(const Scope& scope) const override { + void InferShape(const framework::Scope& scope) const override { for (auto& op : ops_) { op->InferShape(scope); } @@ -56,7 +58,7 @@ class NetOp : public OperatorBase { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - void Run(const Scope& scope, + void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const override { for (auto& op : ops_) { op->Run(scope, dev_ctx); @@ -88,7 +90,7 @@ class NetOp : public OperatorBase { std::vector> ops_; - private: +private: bool add_op_done_{false}; template @@ -97,5 +99,5 @@ class NetOp : public OperatorBase { } }; -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/framework/net_design.md b/paddle/operators/net_op_design.md similarity index 100% rename from paddle/framework/net_design.md rename to paddle/operators/net_op_design.md diff --git a/paddle/framework/net_op_test.cc b/paddle/operators/net_op_test.cc similarity index 91% rename from paddle/framework/net_op_test.cc rename to paddle/operators/net_op_test.cc index f32e456e5d142bf8203f9ec03e8059772c4f5c99..18c5c60eb43250c23e2819a3c79ab8a96fec103e 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -1,16 +1,18 @@ +#include "paddle/operators/net_op.h" + #include -#include -#include -#include + +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" namespace paddle { -namespace framework { +namespace operators { static int infer_shape_cnt = 0; static int run_cnt = 0; class TestOp : public OperatorBase { - public: +public: void InferShape(const framework::Scope& scope) const override { ++infer_shape_cnt; } @@ -21,7 +23,7 @@ class TestOp : public OperatorBase { }; class EmptyOp : public OperatorBase { - public: +public: void InferShape(const Scope& scope) const override {} void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const override {} @@ -73,7 +75,7 @@ TEST(OpKernel, all) { ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); } -TEST(Net, insert_op) { +TEST(NetOp, insert_op) { NetOp net; auto op1 = std::make_shared(); op1->inputs_ = {"x", "w1", "b1"}; @@ -85,5 +87,5 @@ TEST(Net, insert_op) { ASSERT_EQ(3UL, net.ops_.size()); } -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/operators/recurrent_network_op.cc b/paddle/operators/recurrent_op.cc similarity index 67% rename from paddle/operators/recurrent_network_op.cc rename to paddle/operators/recurrent_op.cc index 60d065fc4789f76370840328870165579aa73b67..aeb95569b728f53b288a0c9a28220be8b5f7aaa4 100644 --- a/paddle/operators/recurrent_network_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -12,14 +12,14 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/recurrent_network_op.h" +#include "paddle/operators/recurrent_op.h" #include #include #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" #include "paddle/platform/enforce.h" namespace paddle { @@ -29,11 +29,15 @@ namespace rnn { void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, - const size_t seq_len) { + const size_t seq_len, + bool infer_shape_mode) { PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided."); for (size_t i = 0; i < inlinks.size(); ++i) { - Tensor* input = - step_scopes[0]->FindVar(inlinks[i].external)->GetMutable(); + auto input_var = step_scopes[0]->FindVar(inlinks[i].external); + PADDLE_ENFORCE(input_var != nullptr, + "input link [%s] is not in scope.", + inlinks[i].external); + Tensor* input = input_var->GetMutable(); DDim dims = input->dims(); PADDLE_ENFORCE(static_cast(dims[0]) == seq_len, "all the inlinks must have same length"); @@ -41,7 +45,9 @@ void SegmentInputs(const std::vector& step_scopes, for (size_t j = 0; j < seq_len; j++) { Tensor* step_input = step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable(); - *step_input = input->Slice(j, j + 1); + if (!infer_shape_mode) { + *step_input = input->Slice(j, j + 1); + } step_input->Resize(step_dims); } } @@ -49,36 +55,41 @@ void SegmentInputs(const std::vector& step_scopes, void ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, - const size_t seq_len) { + const size_t seq_len, + bool infer_shape_mode) { for (size_t i = 0; i < outlinks.size(); i++) { - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); - - // TODO(qingiqng) remove following code after adding - // InferShape in RecurrentGradientOp - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable() - ->dims(); - std::vector dims_vec = vectorize(step_dims); - dims_vec.insert(dims_vec.begin(), seq_len); - output->mutable_data(make_ddim(dims_vec), platform::CPUPlace()); - - for (size_t j = 0; j < seq_len; j++) { - Tensor* step_output = - step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable(); - // TODO(luotao02) data type and platform::DeviceContext() should set - // correctly - (output->Slice(j, j + 1)) - .CopyFrom(*step_output, platform::CPUPlace()); + auto output_var = step_scopes[0]->FindVar(outlinks[i].external); + PADDLE_ENFORCE(output_var != nullptr, + "output link [%s] is not in scope.", + outlinks[i].external); + Tensor* output = output_var->GetMutable(); + if (infer_shape_mode) { + DDim step_dims = step_scopes[0] + ->FindVar(outlinks[i].internal) + ->GetMutable() + ->dims(); + std::vector dims_vec = vectorize(step_dims); + dims_vec.insert(dims_vec.begin(), seq_len); + output->Resize(make_ddim(dims_vec)); + } else { + output->mutable_data(platform::CPUPlace()); + for (size_t j = 0; j < seq_len; j++) { + Tensor* step_output = + step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable(); + // TODO(luotao02) data type and platform::DeviceContext() should set + // correctly + (output->Slice(j, j + 1)) + .CopyFrom(*step_output, platform::CPUPlace()); + } } } } void LinkMemories(const std::vector& scopes, const std::vector& memories, - size_t step_id, - int offset) { + const size_t step_id, + const int offset, + bool infer_shape_mode) { PADDLE_ENFORCE(step_id < scopes.size(), "step [%d] is out of range of step scopes' size [%d]", step_id, @@ -95,18 +106,13 @@ void LinkMemories(const std::vector& scopes, auto scope = scopes[step_id]; auto linked_scope = scopes[step_id + offset]; for (auto& attr : memories) { - auto mem = scope->NewVar(attr.pre_var)->GetMutable(); - // maybe share variable is better? + auto mem = scope->FindVar(attr.pre_var)->GetMutable(); auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable(); - mem->ShareDataWith(*linked_mem); - - // TODO(qingqing) remove following code - // the memory of current step should be allocated in step net - auto m = scope->NewVar(attr.var)->GetMutable(); - // for unit test, as addOp and mulOp are null currently, if not - // mutable_data, mem.data() in output will be error. We will - // remove this line after merge the correct addOp and mulOp. - m->mutable_data(mem->dims(), platform::CPUPlace()); + if (infer_shape_mode) { + mem->Resize(linked_mem->dims()); + } else { + mem->ShareDataWith(*linked_mem); + } } } @@ -175,60 +181,39 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { ->dims()[0]; CreateScopes(scope); auto step_scopes = GetStepScopes(scope); - - // SegmentInputs is called in InferShape. The input must hold memory in - // SegmentInputs. But the other op only set dimension for the output in - // InferShape. That's a problem. Wether the RNN op needs InferShape or not? - // Wether the following functions (SegmentInputs, InitMemories, ...) need - // to rewrite for RNN op? - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - - InitMemories(step_scopes[0]); - - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "stepnet [%s] is not in scope.", - arg_->step_net); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); + InitMemories(step_scopes[0], true /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); - // If the InferShape is called in OperatorBase's run function, - // the rnn op only needs to do InferShape for the first time step for (size_t i = 0; i < seq_len_; i++) { if (i > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, i, -1); + rnn::LinkMemories( + step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } net->GetMutable()->InferShape(*step_scopes[i]); } - - auto outlinks = arg_->outlinks; - for (size_t i = 0; i < outlinks.size(); i++) { - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable() - ->dims(); - std::vector dims_vec = vectorize(step_dims); - // now only support fixed length - dims_vec.insert(dims_vec.begin(), seq_len_); - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); - output->Resize(make_ddim(dims_vec)); - } + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); } void RecurrentAlgorithm::Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); - + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); + InitMemories(step_scopes[0], false /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); + for (size_t step_id = 0; step_id < seq_len_; step_id++) { - // the link memory is done in InferShape - // maybe remove following code after testing if (step_id > 0) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); } net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); } - - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); } void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { @@ -244,18 +229,19 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { // Now all variables in scope must be created outside of op. auto net_op = scope.FindVar(arg_->step_net)->GetMutable(); for (auto& input : net_op->inputs_) { + // the weight are located in parent scope if (!step_scope.FindVar(input)) step_scope.NewVar(input); } for (auto& output : net_op->outputs_) { step_scope.NewVar(output); } - step_scopes->emplace_back(&step_scope); } } } -void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { +void RecurrentAlgorithm::InitMemories(Scope* step_scope, + bool infer_shape_mode) const { for (auto& attr : arg_->memories) { Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable(); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, @@ -263,13 +249,11 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope) const { attr.var, attr.boot_var); Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable(); - pre_mem->ShareDataWith(*boot_mem); - - // TODO(qingqing) remove following code - // the memory of current step should be allocated in step net - // here for unit test - auto cur_step_mem = step_scope->NewVar(attr.var)->GetMutable(); - cur_step_mem->mutable_data(boot_mem->dims(), platform::CPUPlace()); + if (infer_shape_mode) { + pre_mem->Resize(boot_mem->dims()); + } else { + pre_mem->ShareDataWith(*boot_mem); + } } } @@ -307,13 +291,14 @@ public: : OpProtoAndCheckerMaker(proto, op_checker) { const auto& name = RecurrentOp::kArgName; // inputs and outputs stored in proto - AddInput(name.inlinks, "the input that need to be segmented for each step.") + AddInput(name.inlinks, + "the inputs that need to be segmented for each step.") .SetMultiple(); AddInput(name.boot_memories, "variables to initialize memories.") .SetMultiple(); AddInput(name.step_net, "network shared by all steps."); - AddOutput(name.outlinks, "the output that need to concated for all steps.") + AddOutput(name.outlinks, "the outputs that need to concated for all steps.") .SetMultiple(); AddOutput(name.step_scopes, "step scopes"); @@ -331,34 +316,39 @@ public: void RecurrentGradientAlgorithm::Run( const Scope& scope, const platform::DeviceContext& dev_ctx) const { auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "step net is not in scope."); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); } net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); } - LinkBootMemoryGradients(step_scopes[0]); - rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_); + LinkBootMemoryGradients(step_scopes[0], false); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); } void RecurrentGradientAlgorithm::LinkBootMemoryGradients( - Scope* step_scope) const { + Scope* step_scope, bool infer_shape_mode) const { for (auto& attr : arg_->memories) { - Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); - PADDLE_ENFORCE(mem_grad != nullptr, - "boot_tensor should be retrieved before"); + PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr, + "memory variable [%s] does not exists", + attr.var); PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr, - "memory [%s]'s boot variable [%s] not exists", - attr.var, + "boot variable [%s] does not exists", attr.boot_var); + Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable(); Tensor* boot_mem_grad = step_scope->NewVar(attr.boot_var)->GetMutable(); - boot_mem_grad->ShareDataWith(*mem_grad); + if (infer_shape_mode) { + boot_mem_grad->Resize(mem_grad->dims()); + } else { + boot_mem_grad->ShareDataWith(*mem_grad); + } } } @@ -367,34 +357,20 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { ->GetMutable() ->dims()[0]; auto step_scopes = GetStepScopes(scope); - rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_); - - PADDLE_ENFORCE(scope.FindVar(arg_->step_net) != nullptr, - "step net is not in scope."); + rnn::SegmentInputs( + step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); Variable* net = scope.FindVar(arg_->step_net); PADDLE_ENFORCE(net != nullptr, "failed to get step net"); - for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { - rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1); + rnn::LinkMemories( + step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); } net->GetMutable()->InferShape(*step_scopes[step_id]); } - - auto outlinks = arg_->outlinks; - for (size_t i = 0; i < outlinks.size(); i++) { - DDim step_dims = step_scopes[0] - ->FindVar(outlinks[i].internal) - ->GetMutable() - ->dims(); - std::vector dims_vec = vectorize(step_dims); - // now only support fixed length - dims_vec.insert(dims_vec.begin(), seq_len_); - Tensor* output = - step_scopes[0]->FindVar(outlinks[i].external)->GetMutable(); - output->Resize(make_ddim(dims_vec)); - } - LinkBootMemoryGradients(step_scopes[0]); + rnn::ConcatOutputs( + step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); + LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/); } void RecurrentGradientOp::Init() { diff --git a/paddle/operators/recurrent_network_op.h b/paddle/operators/recurrent_op.h similarity index 85% rename from paddle/operators/recurrent_network_op.h rename to paddle/operators/recurrent_op.h index d57a1a2e51cbed22549ab6ebce79223e2d4e3bcf..35e6d9d50dd04048da7ffb384014d5909cd659a4 100644 --- a/paddle/operators/recurrent_network_op.h +++ b/paddle/operators/recurrent_op.h @@ -19,7 +19,7 @@ namespace paddle { namespace operators { -using namespace paddle::framework; +using namespace paddle::framework; // NOLINT namespace rnn { @@ -72,26 +72,29 @@ struct ArgumentName { */ void SegmentInputs(const std::vector& step_scopes, const std::vector& inlinks, - const size_t seq_len); + const size_t seq_len, + bool infer_shape_mode); /** * Process outputs of step nets and merge to variables. */ void ConcatOutputs(const std::vector& step_scopes, const std::vector& outlinks, - const size_t seq_len); + const size_t seq_len, + bool infer_shape_mode); void LinkMemories(const std::vector& step_scopes, const std::vector& memories, - size_t step_id, - int offset); + const size_t step_id, + const int offset, + bool infer_shape_mode); void InitArgument(const ArgumentName& name, Argument* arg); }; // namespace rnn // The sequence format in RecurrentOp is Tensor now. -// TODO: +// TODO(Yan Chunwei): // 1. No-padding computing for sequences with indifinite length in one batch. // 2. Hierarchical RNN for sequence with sub-sequence. // 3. Internal Memory. @@ -122,7 +125,7 @@ protected: return *scope.FindVar(arg_->step_scopes)->GetMutable>(); } - void InitMemories(Scope* step_scopes) const; + void InitMemories(Scope* step_scopes, bool infer_shape_mode) const; private: std::unique_ptr arg_; @@ -145,7 +148,7 @@ public: void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const; - void LinkBootMemoryGradients(Scope* step_scopes) const; + void LinkBootMemoryGradients(Scope* step_scopes, bool infer_shape_mode) const; /** * InferShape must be called before Run. @@ -169,12 +172,10 @@ public: /** * InferShape must be called before Run. */ - virtual void InferShape(const Scope& scope) const override { - alg_.InferShape(scope); - } + void InferShape(const Scope& scope) const override { alg_.InferShape(scope); } - virtual void Run(const Scope& scope, - const platform::DeviceContext& dev_ctx) const override { + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override { alg_.Run(scope, dev_ctx); } @@ -191,12 +192,10 @@ public: /** * InferShape must be called before Run. */ - virtual void InferShape(const Scope& scope) const override { - alg_.InferShape(scope); - } + void InferShape(const Scope& scope) const override { alg_.InferShape(scope); } - virtual void Run(const Scope& scope, - const platform::DeviceContext& dev_ctx) const override { + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override { alg_.Run(scope, dev_ctx); } diff --git a/paddle/operators/recurrent_network_op_test.cc b/paddle/operators/recurrent_op_test.cc similarity index 90% rename from paddle/operators/recurrent_network_op_test.cc rename to paddle/operators/recurrent_op_test.cc index b0e61fbee611744adb85b498b1c3540f059afc8c..08a6d9fe5681fdea180de2e9361734ade8564775 100644 --- a/paddle/operators/recurrent_network_op_test.cc +++ b/paddle/operators/recurrent_op_test.cc @@ -11,14 +11,15 @@ limitations under the License. */ +#include "paddle/operators/recurrent_op.h" + #include #include -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" #include "paddle/framework/tensor.h" -#include "paddle/operators/recurrent_network_op.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace operators { @@ -55,7 +56,7 @@ protected: w->GetMutable()->mutable_data( make_ddim(std::vector{30, 30}), platform::CPUPlace()); - for (auto boot : std::vector{"x_boot", "h_boot"}) { + for (auto boot : std::vector{"h_boot"}) { LOG(INFO) << "create global variable " << boot; Variable* h_boot = scope_.NewVar(boot); h_boot->GetMutable()->mutable_data( @@ -79,7 +80,6 @@ protected: op_desc.add_inputs("x0"); op_desc.add_inputs("x1"); // boot_memories 3 - op_desc.add_inputs("x_boot"); op_desc.add_inputs("h_boot"); // step net 5 op_desc.add_inputs("step_net"); @@ -91,7 +91,7 @@ protected: auto _input_format = std::vector{ 0, // in_link 3, // memories - 5 // step_net + 4 // step_net }; auto input_format = op_desc.add_attrs(); input_format->set_name("input_format"); @@ -129,12 +129,11 @@ protected: inlink_alias->add_strings(item); } // pre memories - for (const auto& item : - std::vector{"rnn/x@pre", "rnn/h@pre"}) { + for (const auto& item : std::vector{"rnn/h@pre"}) { pre_memories->add_strings(item); } // memories - for (const auto& item : std::vector{"rnn/x", "rnn/h"}) { + for (const auto& item : std::vector{"rnn/h"}) { memories->add_strings(item); } // output alias @@ -151,14 +150,11 @@ protected: LOG(INFO) << "create variable step_net"; Variable* var = scope_.NewVar("step_net"); auto net = var->GetMutable(); - // rnn/s is net's input or output? - net->inputs_ = {"rnn/h@pre", "rnn/w", "rnn/x"}; - net->inputs_ = {"rnn/s", "rnn/h"}; net->AddOp( OpRegistry::CreateOp("mul", {"rnn/h@pre", "rnn/w"}, {"rnn/s"}, {})); net->AddOp( - OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {})); + OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {})); net->CompleteAddOp(); } @@ -297,7 +293,10 @@ protected: inlink.internal = "rnn/x"; auto step_scopes = scope_.FindVar("step_scopes")->GetMutable>(); - rnn::SegmentInputs(*step_scopes, std::vector{inlink}, 10); + rnn::SegmentInputs(*step_scopes, + std::vector{inlink}, + 10, + true /*infer_shape_mode*/); } void LinkeMemories() { @@ -311,7 +310,8 @@ protected: auto step_scopes = scope_.FindVar("step_scopes")->GetMutable>(); for (int i = 1; i < 10; ++i) { - rnn::LinkMemories(*step_scopes, memories, i, -1); + rnn::LinkMemories( + *step_scopes, memories, i, -1, true /*infer_shape_mode*/); } } @@ -333,14 +333,14 @@ TEST(RecurrentOp, LinkMemories) { using namespace paddle::operators; // create and init step scopes - int len = 10; + size_t len = 10; std::vector step_scopes; - for (int i = 0; i < len; ++i) { + for (size_t i = 0; i < len; ++i) { auto scope = new Scope(); scope->NewVar("pre_h"); auto tensor = scope->NewVar("h")->GetMutable(); float* data = tensor->mutable_data({15, 20}, CPUPlace()); - for (int j = 0; j < 15 * 20; ++j) { + for (size_t j = 0; j < 15 * 20; ++j) { data[j] = rand() * (1. / (double)RAND_MAX); } step_scopes.push_back(scope); @@ -354,24 +354,24 @@ TEST(RecurrentOp, LinkMemories) { std::vector memories; memories.push_back(mem_attr); - for (int i = 1; i < len; ++i) { - rnn::LinkMemories(step_scopes, memories, i, -1); + for (size_t i = 1; i < len; ++i) { + rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/); } // check - for (int i = 0; i < len - 1; ++i) { + for (size_t i = 0; i < len - 1; ++i) { const float* a = step_scopes[i]->FindVar("h")->GetMutable()->data(); const float* b = step_scopes[i + 1] ->FindVar("pre_h") ->GetMutable() ->data(); - for (size_t i = 0; i < 15 * 20; ++i) { - ASSERT_FLOAT_EQ(a[i], b[i]); + for (size_t j = 0; j < 15 * 20; ++j) { + ASSERT_FLOAT_EQ(a[j], b[j]); } } for (int i = len - 2; i >= 0; --i) { - rnn::LinkMemories(step_scopes, memories, i, 1); + rnn::LinkMemories(step_scopes, memories, i, 1, false /*infer_shape_mode*/); } // check for (int i = len - 2; i >= 0; --i) { @@ -379,8 +379,8 @@ TEST(RecurrentOp, LinkMemories) { step_scopes[i]->FindVar("pre_h")->GetMutable()->data(); const float* b = step_scopes[i + 1]->FindVar("h")->GetMutable()->data(); - for (size_t i = 0; i < 15 * 20; ++i) { - ASSERT_FLOAT_EQ(a[i], b[i]); + for (size_t j = 0; j < 15 * 20; ++j) { + ASSERT_FLOAT_EQ(a[j], b[j]); } } @@ -391,9 +391,3 @@ TEST(RecurrentOp, LinkMemories) { USE_OP(add_two); USE_OP(mul); - -// int main() { -// //! TODO(yuyang18): Temporary disable this unit-test because implementation -// //! error. -// return 0; -//} \ No newline at end of file diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 82338ceccc06653791b26472e18d804f62735649..f76faa0a3a93a1ac277a1d1aa83c3fa6c3944648 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/rowwise_add_op.h" diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu index d79258cbf13c699cfb2afaee229cf96a3e377b5e..72629ccfbb8bc8ec53045289bd985c721c62fa10 100644 --- a/paddle/operators/sgd_op.cu +++ b/paddle/operators/sgd_op.cu @@ -1,4 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/sgd_op.h" -REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel); \ No newline at end of file +REGISTER_OP_GPU_KERNEL(sgd, ops::SGDOpKernel); diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index d8ddbac573ffae8bd5d4b53041ed4dd35bb5ae0c..0c3a240f9a4a5fc7bc4898e82786810cee2f7010 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -29,8 +29,12 @@ public: param_out->mutable_data(ctx.GetPlace()); - EigenVector::Flatten(*param_out).device(ctx.GetEigenDevice()) = - EigenVector::Flatten(*param) - lr * EigenVector::Flatten(*grad); + auto p = EigenVector::Flatten(*param); + auto g = EigenVector::Flatten(*grad); + auto o = EigenVector::Flatten(*param_out); + auto place = ctx.GetEigenDevice(); + + o.device(place) = p - lr * g; } }; diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu index c9d11a2e1f9dcc563765c9e8cc1bae6beff57f18..2123b17e4b5e90c22c2d6e9177f2a8956f8a4ac9 100644 --- a/paddle/operators/sigmoid_op.cu +++ b/paddle/operators/sigmoid_op.cu @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/operators/sigmoid_op.h" diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h index f518ddcf3b2381f0b8ceec3eee260f033a847263..1412e4398440c8e946d3ab434a50e978079637ab 100644 --- a/paddle/operators/sigmoid_op.h +++ b/paddle/operators/sigmoid_op.h @@ -27,8 +27,11 @@ public: auto output = context.Output(0); output->mutable_data(context.GetPlace()); - EigenVector::Flatten(*output).device(context.GetEigenDevice()) = - 1.0 / (1.0 + (-1.0 * EigenVector::Flatten(*input)).exp()); + auto X = EigenVector::Flatten(*input); + auto Y = EigenVector::Flatten(*output); + auto place = context.GetEigenDevice(); + + Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp()); } }; } // namespace operators diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 5b59fad7d5f9729b0862f8cd78cb32f94f87f513..5cbb96ab754467ea6ddab9380ca25987c9376980 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -1,16 +1,17 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ #include "paddle/operators/softmax_op.h" namespace paddle { @@ -19,12 +20,13 @@ namespace operators { class SoftmaxOp : public OperatorWithKernel { protected: void InferShape(const InferShapeContext &ctx) const override { - PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax"); - PADDLE_ENFORCE(ctx.Input(0)->dims().size() == 2, + PADDLE_ENFORCE(ctx.InputSize() == 1UL, + "Only one input is need for softmax"); + PADDLE_ENFORCE(ctx.Input("X")->dims().size() == 2UL, "The input of softmax op must be matrix"); - PADDLE_ENFORCE(ctx.OutputSize() == 1, + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, "Only one output is need for softmax"); - ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output("Y")->Resize(ctx.Input("X")->dims()); } }; @@ -40,10 +42,19 @@ public: class SoftmaxOpGrad : public OperatorWithKernel { protected: - void InferShape(const InferShapeContext &ctx) const override {} - std::string DebugString() const override { - LOG(INFO) << "SoftmaxOpGrad"; - return ""; + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 3UL, + "Input of SoftmaxOpGrad should be 3, X, Y, YG"); + PADDLE_ENFORCE(ctx.OutputSize() == 1UL, + "Output of SoftmaxOpGrad should be 1"); + PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx.InputVar(GRAD_VAR_NAME("Y")) != nullptr, + "Input(Y@GRAD) should not be null"); + PADDLE_ENFORCE(ctx.Input("Y")->dims() == + ctx.Input(GRAD_VAR_NAME("Y"))->dims(), + "the shape of Input(0) and Input(1) should be the same"); + ctx.Output(GRAD_VAR_NAME("X")) + ->Resize(ctx.Input("Y")->dims()); } }; @@ -51,5 +62,7 @@ protected: } // namespace paddle REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); -REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); +REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); +REGISTER_OP_CPU_KERNEL(softmax_grad, + ops::SoftmaxGradKernel); diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu index ddf8f6e913ccf450185f377f531bf978f69ed1fc..b79228580a7ea0f70b62eb2dc7a61cf85bc0b5fb 100644 --- a/paddle/operators/softmax_op.cu +++ b/paddle/operators/softmax_op.cu @@ -1,5 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #define EIGEN_USE_GPU #include "paddle/framework/op_registry.h" #include "paddle/operators/softmax_op.h" REGISTER_OP_GPU_KERNEL(softmax, ops::SoftmaxKernel); +REGISTER_OP_GPU_KERNEL(softmax_grad, + ops::SoftmaxGradKernel); diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 75c5197697dada58e09f4cda41cea13af56e79a3..13e74a79077982e9fba5d90f40986e699c1ed897 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -1,19 +1,22 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once +#include "paddle/framework/ddim.h" +#include "paddle/framework/operator.h" +#include "paddle/framework/tensor.h" #include "paddle/operators/type_alias.h" namespace paddle { @@ -23,8 +26,8 @@ template class SoftmaxKernel : public OpKernel { public: void Compute(const ExecutionContext& context) const override { - auto input = context.Input(0); - auto output = context.Output(0); + auto input = context.Input("X"); + auto output = context.Output("Y"); output->mutable_data(context.GetPlace()); auto logits = EigenMatrix::From(*input); @@ -57,5 +60,38 @@ public: .broadcast(one_by_class)); } }; + +template +class SoftmaxGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + std::shared_ptr scale_ = std::make_shared(); + + auto Y = context.Input("Y"); + auto dY = context.Input(OperatorBase::GRAD_VAR_NAME("Y")); + auto dX = context.Output(OperatorBase::GRAD_VAR_NAME("X")); + dX->mutable_data(context.GetPlace()); + + const int batch_size = Y->dims()[0]; + const int class_num = Y->dims()[1]; + + Eigen::DSizes along_class(1); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, class_num); + + auto Y_eigen = EigenMatrix::From(*Y); + auto dY_eigen = EigenMatrix::From(*dY); + auto dX_eigen = EigenMatrix::From(*dX); + auto place = context.GetEigenDevice(); + + auto dot = (Y_eigen * dY_eigen) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/operators/type_alias.h b/paddle/operators/type_alias.h index 93b62cddc819e0d1fd48323e474a294ff0d327e1..931740e150946a939b8656be5a30185c6ee1cb8f 100644 --- a/paddle/operators/type_alias.h +++ b/paddle/operators/type_alias.h @@ -15,13 +15,14 @@ #pragma once #include "paddle/framework/eigen.h" -#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/net_op.h" namespace paddle { namespace operators { using OpKernel = framework::OpKernel; +using OperatorBase = framework::OperatorBase; using InferShapeContext = framework::InferShapeContext; using ExecutionContext = framework::ExecutionContext; using Variable = framework::Variable; @@ -43,14 +44,16 @@ template using EigenTensor = framework::EigenTensor; using Tensor = framework::Tensor; +using Scope = framework::Scope; using OperatorWithKernel = framework::OperatorWithKernel; +using OperatorBase = framework::OperatorBase; using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker; using OpProto = framework::OpProto; using OpAttrChecker = framework::OpAttrChecker; using CPUPlace = platform::CPUPlace; using GPUPlace = platform::GPUPlace; -using NetOp = framework::NetOp; using OpRegistry = framework::OpRegistry; + } // namespace operators } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 2038fafe2e15ec2631726643695ac6cbc317fed9..48b9f5dcb5cc578f9e70ed7abe076b66b68dc719 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -40,7 +40,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext { public: CPUDeviceContext(); - CPUDeviceContext(CPUPlace); + explicit CPUDeviceContext(CPUPlace); virtual ~CPUDeviceContext() {} Eigen::DefaultDevice* eigen_device() const; @@ -55,7 +55,7 @@ class CPUDeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext { public: - explicit CUDADeviceContext(GPUPlace); + CUDADeviceContext(GPUPlace); // NOLINT virtual ~CUDADeviceContext(); /*! \brief Wait for all operations completion in the stream. */ @@ -69,10 +69,10 @@ class CUDADeviceContext : public DeviceContext { // clang-format off /*! \brief Return cublas handle in the device context. */ - cublasHandle_t cublas_handle (); + cublasHandle_t cublas_handle(); /*! \brief Return cudnn handle in the device context. */ - cudnnHandle_t cudnn_handle (); + cudnnHandle_t cudnn_handle(); /*! \brief Return curand handle in the device context. */ curandGenerator_t curand_generator(); diff --git a/paddle/platform/dynload/cublas.cc b/paddle/platform/dynload/cublas.cc index 4e3dfdaefb2348346e8f917b1f6c758bf6d91a1a..9cd2a1f565526f8dc45932ba6168f4e25c6ad238 100644 --- a/paddle/platform/dynload/cublas.cc +++ b/paddle/platform/dynload/cublas.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include namespace paddle { diff --git a/paddle/platform/dynload/cudnn.cc b/paddle/platform/dynload/cudnn.cc index 8b5e15b5efcdae6a1eed09f002eb2f4f2163035f..d3e4cb567d71b987724366b6a0896f5df0eb6055 100644 --- a/paddle/platform/dynload/cudnn.cc +++ b/paddle/platform/dynload/cudnn.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include namespace paddle { @@ -25,4 +39,4 @@ CUDNN_DNN_ROUTINE_EACH_R5(DEFINE_WRAP); } // namespace dynload } // namespace platform -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/platform/dynload/curand.cc b/paddle/platform/dynload/curand.cc index 5c1fab992c98569d4a95b6e699d97d428511e48e..d05dd88126bfee7278e553710a717b8f2eb02ae0 100644 --- a/paddle/platform/dynload/curand.cc +++ b/paddle/platform/dynload/curand.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include namespace paddle { @@ -10,6 +24,7 @@ void *curand_dso_handle; #define DEFINE_WRAP(__name) DynLoad__##__name __name CURAND_RAND_ROUTINE_EACH(DEFINE_WRAP); -} -} -} \ No newline at end of file + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/place.h b/paddle/platform/place.h index 7cead183884bc9379355cd931921b40d6c11ce90..a37ad38a8fb030192fa4c871106c6eb54816768a 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -32,7 +32,7 @@ struct CPUPlace { struct GPUPlace { GPUPlace() : GPUPlace(0) {} - GPUPlace(int d) : device(d) {} + GPUPlace(int d) : device(d) {} // NOLINT // needed for variant equality comparison inline bool operator==(const GPUPlace &o) const { return device == o.device; } diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index ac12b504b5cbee778c7c0a74a84a7729f210e01e..29dd0ded0ac75893da7e244d92725cd5e285efce 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -6,4 +6,4 @@ cc_library(paddle_pybind SHARED add_op mean_op cross_entropy_op - recurrent_network_op) + recurrent_op) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 69ae0ea2d72c199a8e17c0595693e5e0b2f79ee1..8de0e608c1f482e4553c07ff7ffd572d65a772aa 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -69,7 +69,7 @@ cat <