diff --git a/CMakeLists.txt b/CMakeLists.txt index 2a6b0a20e441676c85c9ed8f8ad1a6e7abdf1ea8..c7d743e193e7d32dbc0b56f3bcb05b6c61f85f1d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,8 @@ include(simd) ################################ Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) +option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) +option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) @@ -74,6 +76,10 @@ if(ANDROID) "Disable PYTHON when cross-compiling for Android" FORCE) set(WITH_RDMA OFF CACHE STRING "Disable RDMA when cross-compiling for Android" FORCE) + set(WITH_MKLDNN OFF CACHE STRING + "Disable MKLDNN when cross-compiling for Android" FORCE) + set(WITH_MKLML OFF CACHE STRING + "Disable MKLML package when cross-compiling for Android" FORCE) endif(ANDROID) set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING @@ -87,6 +93,7 @@ endif() ######################################################################################## +include(external/mklml) # download mklml package include(external/zlib) # download, build, install zlib include(external/gflags) # download, build, install gflags include(external/glog) # download, build, install glog @@ -94,6 +101,7 @@ include(external/gtest) # download, build, install gtest include(external/protobuf) # download, build, install protobuf include(external/python) # download, build, install python include(external/openblas) # download, build, install openblas +include(external/mkldnn) # download, build, install mkldnn include(external/swig) # download, build, install swig include(external/warpctc) # download, build, install warpctc include(external/any) # download libn::any @@ -135,6 +143,10 @@ if(WITH_GPU) endif(NOT WITH_DSO) endif(WITH_GPU) +if(WITH_MKLDNN) + list(APPEND EXTERNAL_LIBS ${MKLDNN_LIBRARY} ${MKLDNN_IOMP_LIB}) +endif() + if(USE_NNPACK) include(external/nnpack) list(APPEND EXTERNAL_LIBS ${NNPACK_LIBS}) diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 913f711afff3b8f9f77b8da978a3b9e7165d0077..854066fd1d205c337fbdbe08997d88251095c799 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -15,23 +15,44 @@ set(CBLAS_FOUND OFF) -## Find MKL First. -set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs") -set(MKL_ROOT ${INTEL_ROOT}/mkl CACHE PATH "Folder contains MKL") +## Find MKLML First. +if(WITH_MKLML AND MKLML_INC_DIR AND MKLML_LIB) + set(CBLAS_FOUND ON) + set(CBLAS_PROVIDER MKLML) + set(CBLAS_INC_DIR ${MKLML_INC_DIR}) + set(CBLAS_LIBRARIES ${MKLML_LIB}) + + add_definitions(-DPADDLE_USE_MKLML) + add_definitions(-DLAPACK_FOUND) + + message(STATUS "Found cblas and lapack in MKLML " + "(include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") + return() +endif() + +## Then find MKL. 
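+## (When WITH_MKLML is enabled and the package is already available, it is
+## preferred over a system MKL. MKL_ROOT below may also be seeded from the
+## MKL_ROOT environment variable, e.g. `MKL_ROOT=/opt/intel/mkl cmake ..`,
+## an illustrative invocation, not part of this patch.)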
+set(INTEL_MKL_ROOT "/opt/intel/mkl" CACHE PATH "Folder that contains the Intel MKL libs") +set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder taken from the MKL_ROOT environment variable") + +set(MKL_INCLUDE_SEARCH_PATHS + ${MKL_ROOT}/include + ${INTEL_MKL_ROOT}/include) +set(MKL_LIB_SEARCH_PATHS + ${MKL_ROOT}/lib + ${MKL_ROOT}/lib/intel64 + ${INTEL_MKL_ROOT}/lib + ${INTEL_MKL_ROOT}/lib/intel64) find_path(MKL_INC_DIR mkl.h PATHS - ${MKL_ROOT}/include) + ${MKL_INCLUDE_SEARCH_PATHS}) find_path(MKL_LAPACK_INC_DIR mkl_lapacke.h PATHS - ${MKL_ROOT}/include) + ${MKL_INCLUDE_SEARCH_PATHS}) find_library(MKL_CORE_LIB NAMES mkl_core PATHS - ${MKL_ROOT}/lib - ${MKL_ROOT}/lib/intel64) + ${MKL_LIB_SEARCH_PATHS}) find_library(MKL_SEQUENTIAL_LIB NAMES mkl_sequential PATHS - ${MKL_ROOT}/lib - ${MKL_ROOT}/lib/intel64) + ${MKL_LIB_SEARCH_PATHS}) find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS - ${MKL_ROOT}/lib - ${MKL_ROOT}/lib/intel64) + ${MKL_LIB_SEARCH_PATHS}) if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) set(CBLAS_FOUND ON) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 7afab5d5344b704a9329e313a81379032ba0cc97..69220e03fe8e337205f31cb1f45e3e19ae4f5d1e 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -67,6 +67,30 @@ else() include_directories(${CUDA_TOOLKIT_INCLUDE}) endif(NOT WITH_GPU) +if(WITH_MKLDNN) + add_definitions(-DPADDLE_USE_MKLDNN) + if (WITH_MKLML AND MKLDNN_IOMP_DIR) + message(STATUS "Enable Intel OpenMP at ${MKLDNN_IOMP_DIR}") + set(OPENMP_FLAGS "-fopenmp") + set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS}) + set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS}) + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}") + else() + find_package(OpenMP) + if(OPENMP_FOUND) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + else() + message(WARNING "Cannot find OpenMP. " + "Some performance features of MKLDNN may not be available.") + endif() + endif() + +endif(WITH_MKLDNN) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake index 77e06e983e9f8bfaf6320e3c67b85b692ed877fc..e3970073a1a0b946fa1db6642799719d7a9fcf4f 100644 --- a/cmake/external/gtest.cmake +++ b/cmake/external/gtest.cmake @@ -34,9 +34,15 @@ IF(WITH_TESTING) "${GTEST_INSTALL_DIR}/lib/libgtest_main.a" CACHE FILEPATH "gtest main libraries." FORCE) ENDIF(WIN32) + IF(WITH_MKLML) + # wait until the mklml download has completed + SET(GTEST_DEPENDS ${MKLML_PROJECT}) + ENDIF() + ExternalProject_Add( extern_gtest ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS ${GTEST_DEPENDS} GIT_REPOSITORY "https://github.com/google/googletest.git" GIT_TAG "release-1.8.0" PREFIX ${GTEST_SOURCES_DIR} diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake new file mode 100644 index 0000000000000000000000000000000000000000..eff15de73f23db6dea3a7b79006bfec90d712ae5 --- /dev/null +++ b/cmake/external/mkldnn.cmake @@ -0,0 +1,72 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +IF(NOT ${WITH_MKLDNN}) + return() +ENDIF(NOT ${WITH_MKLDNN}) + +INCLUDE(ExternalProject) + +SET(MKLDNN_PROJECT "extern_mkldnn") +SET(MKLDNN_SOURCES_DIR ${THIRD_PARTY_PATH}/mkldnn) +SET(MKLDNN_INSTALL_ROOT ${CMAKE_INSTALL_PREFIX}) +IF(NOT "$ENV{HOME}" STREQUAL "/root") + SET(MKLDNN_INSTALL_ROOT "$ENV{HOME}") +ENDIF() + +SET(MKLDNN_INSTALL_DIR "${MKLDNN_INSTALL_ROOT}/opt/paddle/third_party/mkldnn") +SET(MKLDNN_INCLUDE_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE) + +IF(WIN32) + MESSAGE(WARNING "Compiling PaddlePaddle with MKL-DNN on Windows is not supported yet. " + "Forcing WITH_MKLDNN=OFF.") + SET(WITH_MKLDNN OFF) + return() +ELSE(WIN32) + SET(MKLDNN_LIBRARY "${MKLDNN_INSTALL_DIR}/lib/libmkldnn.so" CACHE FILEPATH "mkldnn library." FORCE) + MESSAGE(STATUS "Set ${MKLDNN_INSTALL_DIR}/lib to runtime path") + SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) + #SET(CMAKE_MACOSX_RPATH 1) # hold for MacOS + SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLDNN_INSTALL_DIR}/lib") +ENDIF(WIN32) + +INCLUDE_DIRECTORIES(${MKLDNN_INCLUDE_DIR}) + +IF(${CBLAS_PROVIDER} STREQUAL "MKLML") + SET(MKLDNN_DEPENDS ${MKLML_PROJECT}) + SET(MKLDNN_MKLROOT ${MKLML_ROOT}) + SET(MKLDNN_IOMP_LIB ${MKLML_IOMP_LIB}) + SET(MKLDNN_IOMP_DIR ${MKLML_LIB_DIR}) +ENDIF() + +ExternalProject_Add( + ${MKLDNN_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + DEPENDS ${MKLDNN_DEPENDS} + GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" + GIT_TAG "v0.9" + PREFIX ${MKLDNN_SOURCES_DIR} + CONFIGURE_COMMAND mkdir -p <SOURCE_DIR>/build + BUILD_COMMAND cd <SOURCE_DIR>/build + && cmake .. -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR} -DMKLROOT=${MKLDNN_MKLROOT} + && $(MAKE) + INSTALL_COMMAND cd <SOURCE_DIR>/build && $(MAKE) install + UPDATE_COMMAND "" +) + +ADD_LIBRARY(mkldnn SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET mkldnn PROPERTY IMPORTED_LOCATION ${MKLDNN_LIBRARY}) +ADD_DEPENDENCIES(mkldnn ${MKLDNN_PROJECT}) +MESSAGE(STATUS "MKLDNN library: ${MKLDNN_LIBRARY}") +LIST(APPEND external_project_dependencies mkldnn) diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake new file mode 100644 index 0000000000000000000000000000000000000000..3f940756a4abb79aba7d3561db19db8532a0b673 --- /dev/null +++ b/cmake/external/mklml.cmake @@ -0,0 +1,64 @@ +# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
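+# A hypothetical configure line that enables both packages (illustrative
+# only; the flags are the options defined in the top-level CMakeLists.txt):
+#   cmake .. -DWITH_MKLML=ON -DWITH_MKLDNN=ON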
+ +IF(NOT ${WITH_MKLML}) + return() +ENDIF(NOT ${WITH_MKLML}) + +INCLUDE(ExternalProject) + +SET(MKLML_PROJECT "extern_mklml") +SET(MKLML_VER "mklml_lnx_2018.0.20170425") +SET(MKLML_URL "https://github.com/01org/mkl-dnn/releases/download/v0.9/${MKLML_VER}.tgz") +SET(MKLML_SOURCE_DIR "${THIRD_PARTY_PATH}/mklml") +SET(MKLML_DOWNLOAD_DIR "${MKLML_SOURCE_DIR}/src/${MKLML_PROJECT}") +SET(MKLML_DST_DIR "opt/paddle/third_party/mklml") +SET(MKLML_INSTALL_ROOT "${CMAKE_INSTALL_PREFIX}") +IF(NOT "$ENV{HOME}" STREQUAL "/root") + SET(MKLML_INSTALL_ROOT "$ENV{HOME}") +ENDIF() + +SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) +SET(MKLML_ROOT ${MKLML_INSTALL_DIR}/${MKLML_VER}) +SET(MKLML_INC_DIR ${MKLML_ROOT}/include) +SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib) +SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) +SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") + +INCLUDE_DIRECTORIES(${MKLML_INC_DIR}) + +SET(mklml_cmakefile ${MKLML_DOWNLOAD_DIR}/CMakeLists.txt) +FILE(WRITE ${mklml_cmakefile} "PROJECT(MKLML)\n" + "cmake_minimum_required(VERSION 3.0)\n" + "install(DIRECTORY ${MKLML_VER}\n" + " DESTINATION ${MKLML_DST_DIR})\n") + +ExternalProject_Add( + ${MKLML_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + PREFIX ${MKLML_SOURCE_DIR} + DOWNLOAD_DIR ${MKLML_DOWNLOAD_DIR} + DOWNLOAD_COMMAND wget --no-check-certificate -O ${MKLML_DOWNLOAD_DIR}/${MKLML_VER}.tgz ${MKLML_URL} + && tar -xzf ${MKLML_DOWNLOAD_DIR}/${MKLML_VER}.tgz + DOWNLOAD_NO_PROGRESS 1 + UPDATE_COMMAND "" + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLML_INSTALL_ROOT} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${MKLML_INSTALL_ROOT} +) + +ADD_LIBRARY(mklml SHARED IMPORTED GLOBAL) +SET_PROPERTY(TARGET mklml PROPERTY IMPORTED_LOCATION ${MKLML_LIB}) +ADD_DEPENDENCIES(mklml ${MKLML_PROJECT}) +LIST(APPEND external_project_dependencies mklml) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index c31e62fc08b531a38a851b71a033e14277eff015..34fd348893058980964d723490d9cc220a157b5a 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -124,6 +124,7 @@ set(GPU_COMMON_FLAGS -Wno-error=literal-suffix -Wno-error=unused-local-typedefs -Wno-error=unused-function # Warnings in Numpy Header. + -Wno-error=array-bounds # Warnings in Eigen::array ) if (APPLE) diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 2599b2950836acd44102265dff8bb903f5c8b371..5f3358c69b3fbbbfcd97a96ab50fde3d8b9efad0 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -61,25 +61,24 @@ struct EigenTensor { } }; +template <typename T> +struct EigenMatrix : public EigenTensor<T, 2> {}; + template <typename T> struct EigenVector : public EigenTensor<T, 1> { - // Flatten is to reshape a Tensor into a one dimension EigenVector - using Parent = EigenTensor<T, 1>; - static typename Parent::Type Flatten(Tensor& tensor) { - return Parent::From(tensor, - make_ddim({static_cast<int>(product(tensor.dims_))})); + // Flatten reshapes a Tensor into a one-dimensional EigenVector.
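+ // For example (informal): a Tensor with dims {2, 3, 4} is flattened to a
+ // one-dimensional EigenVector of length product({2, 3, 4}) = 24.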
+ static typename EigenVector::Type Flatten(Tensor& tensor) { + return EigenVector::From( + tensor, make_ddim({static_cast<int>(product(tensor.dims_))})); } - static typename Parent::ConstType Flatten(const Tensor& tensor) { - return Parent::From(tensor, - make_ddim({static_cast<int>(product(tensor.dims_))})); + static typename EigenVector::ConstType Flatten(const Tensor& tensor) { + return EigenVector::From( + tensor, make_ddim({static_cast<int>(product(tensor.dims_))})); } }; -template <typename T> -using EigenMatrix = EigenTensor<T, 2>; - } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 501536657d76cc50b1cc4104007edd4b47758aea..139425b356989f20f035d27ed4b678126d9417d6 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -39,19 +39,22 @@ void PlainNet::CompleteAddOp(bool calc) { output_set.insert(opt); } } + inputs_.reserve(input_set.size()); std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_)); + std::sort(inputs_.begin(), inputs_.end()); outputs_.reserve(output_set.size()); + std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs_)); + std::sort(outputs_.begin(), outputs_.end()); + std::vector<int> tmp_index; tmp_index.reserve(temp_output.size()); - int idx = 0; - for (auto& opt : output_set) { - if (Contains(temp_output, opt)) { - tmp_index.push_back(idx); + int output_len = static_cast<int>(outputs_.size()); + for (int i = 0; i < output_len; ++i) { + if (Contains(temp_output, outputs_[i])) { + tmp_index.push_back(i); } - outputs_.push_back(opt); - ++idx; } attrs_["temporary_index"] = tmp_index; @@ -59,9 +62,12 @@ void PlainNet::CompleteAddOp(bool calc) { std::string PlainNet::DebugString() const { std::ostringstream os; - os << this->type_ << ":" << std::endl; + os << OperatorBase::DebugString() << std::endl; for (auto& op : ops_) { - os << "\t" << op->DebugString() << std::endl; + std::istringstream is(op->DebugString()); + for (std::string line; std::getline(is, line);) { + os << " " << line << std::endl; + } } return os.str(); } diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 93c6fad5d3d9f3de100d30161e6e438eb43816a2..a36f375d2e42ee3c46ddef42954335cba7eb88f2 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -48,25 +48,27 @@ class Tensor { template <typename T> const T* data() const { - CheckDims<T>(); + EnforceSufficientMemory<T>(); return reinterpret_cast<const T*>( reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); } template <typename T> T* data() { - CheckDims<T>(); + EnforceSufficientMemory<T>(); return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); } - template <typename T> + template <typename T, typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> T* mutable_data(DDim dims, platform::Place place) { - set_dims(dims); + Resize(dims); return mutable_data<T>(place); } - template <typename T> + template <typename T, typename std::enable_if<std::is_pod<T>::value>::type* = nullptr> T* mutable_data(platform::Place place) { PADDLE_ENFORCE(product(dims_) > 0, "Tensor's numel must be larger than zero to call " @@ -95,11 +97,9 @@ class Tensor { } template <typename T> - void ShareDataFrom(const Tensor& src) { - src.CheckDims<T>(); - holder_ = src.holder_; - set_dims(src.dims()); - offset_ = src.offset_; + void ShareDataWith(const Tensor& src) { + src.EnforceSufficientMemory<T>(); + *this = src; } template <typename T> void CopyFrom(const Tensor& src, platform::Place dst_place) { PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && platform::is_cpu_place(dst_place), "Tensor::CopyFrom only supports CPU now."); - src.CheckDims<T>(); + src.EnforceSufficientMemory<T>(); size_t size = product(src.dims_) * sizeof(T); - set_dims(src.dims()); + Resize(src.dims()); const void*
src_ptr = static_cast<const void*>(src.data<T>()); void* dst_ptr = static_cast<void*>(mutable_data<T>(dst_place)); memcpy(dst_ptr, src_ptr, size); } template <typename T> Tensor Slice(const int& begin_idx, const int& end_idx) const { - CheckDims<T>(); - PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], - "Slice index is less than zero or out of bound."); + EnforceSufficientMemory<T>(); + PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero."); + PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound."); PADDLE_ENFORCE(begin_idx < end_idx, "Begin index must be less than end index."); PADDLE_ENFORCE(dims_[0] != 1, "Cannot slice a tensor with dims_[0] = 1."); - std::vector<int> d = vectorize(dims_); - int base = 1; - for (size_t i = 1; i < d.size(); ++i) { - base *= d[i]; - } + int base = product(dims_) / dims_[0]; Tensor dst; dst.holder_ = holder_; DDim dst_dims = dims_; dst_dims[0] = end_idx - begin_idx; - dst.set_dims(dst_dims); + dst.Resize(dst_dims); dst.offset_ = offset_ + begin_idx * base * sizeof(T); return dst; } - void set_dims(const DDim& dims) { - if (dims == dims_) { - return; - } - dims_ = dims; - } + void Resize(const DDim& dims) { dims_ = dims; } - DDim dims() const { return dims_; } + const DDim& dims() const { return dims_; } private: // Placeholder hides type T, so it doesn't appear as a template @@ -159,21 +150,9 @@ class Tensor { template <typename T, typename PlaceType> struct PlaceholderImpl : public Placeholder { - private: - template <typename PType> - class Deleter { - public: - Deleter(PType place) : place_(place) {} - void operator()(T* ptr) { memory::Free(place_, static_cast<void*>(ptr)); } - - private: - PType place_; - }; - - public: PlaceholderImpl(PlaceType place, size_t size) : ptr_(static_cast<T*>(memory::Alloc(place, size)), - Deleter<PlaceType>(place)), + memory::PODDeleter<T, PlaceType>(place)), place_(place), size_(size) {} virtual paddle::platform::Place place() const { return place_; } virtual std::type_index type() const { return std::type_index(typeid(T)); } - std::unique_ptr<T, Deleter<PlaceType>> ptr_; + std::unique_ptr<T, memory::PODDeleter<T, PlaceType>> ptr_; platform::Place place_; // record the place of ptr_. size_t size_; // size of the memory block. }; template <typename T> - inline void CheckDims() const { + inline void EnforceSufficientMemory() const { PADDLE_ENFORCE(holder_ != nullptr, "Tensor holds no memory. Call Tensor::mutable_data first."); PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_, @@ -198,7 +177,11 @@ class Tensor { std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated. DDim dims_; - size_t offset_; // marks the begin of tensor data area. + // A PlaceHolder may be shared by more than one tensor. Some of them may be + // slices of the others. So the offset_ is introduced here to indicate the + // byte offset between PlaceHolder::ptr_ and where the tensor's data really + // begins.
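+ // For example (informal, following the arithmetic in Slice() above):
+ // Slice<float>(1, 3) on a {4, 5} tensor shares the same holder_ and sets
+ // offset_ = 1 * (20 / 4) * sizeof(float) = 20 bytes.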
+ size_t offset_; }; } // namespace framework diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 8a7cbbd0de6fd6aaafa8649abb8628e971bc49c1..089844dc0164dae8067846a8e6846d47fb1b0833 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -19,7 +19,7 @@ TEST(Tensor, Dims) { using namespace paddle::framework; using namespace paddle::platform; Tensor tt; - tt.set_dims(make_ddim({2, 3, 4})); + tt.Resize(make_ddim({2, 3, 4})); DDim dims = tt.dims(); ASSERT_EQ(arity(dims), 3); for (int i = 0; i < 3; ++i) { @@ -97,7 +97,7 @@ TEST(Tensor, MutableData) { #endif } -TEST(Tensor, ShareDataFrom) { +TEST(Tensor, ShareDataWith) { using namespace paddle::framework; using namespace paddle::platform; { @@ -106,7 +106,7 @@ TEST(Tensor, ShareDataFrom) { // Try to share data form uninitialized tensor bool caught = false; try { - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataWith(src_tensor); } catch (std::runtime_error& err) { caught = true; std::string msg = @@ -119,7 +119,7 @@ TEST(Tensor, ShareDataFrom) { ASSERT_TRUE(caught); src_tensor.mutable_data(make_ddim({2, 3, 4}), CPUPlace()); - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } @@ -128,7 +128,7 @@ TEST(Tensor, ShareDataFrom) { Tensor src_tensor; Tensor dst_tensor; src_tensor.mutable_data(make_ddim({2, 3, 4}), GPUPlace()); - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataWith(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } #endif diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index dfa2f784610b0dd60340e0ebc6a066437f3715eb..7f32c734791853a8cd0287a80a7955dbd1bd7571 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -31,13 +31,22 @@ public: ConvolutionTest(const std::string& conv1, const std::string& conv2, TestType type, + bool useGroups = true, std::string algo = "auto") { for (size_t batchSize : {1, 32}) { for (size_t inputSize : {7, 14, 54}) { for (size_t filterSize : {1, 3, 5}) { for (size_t inputChannels : {3, 64}) { - for (size_t outputChannels : {3, 64, 128}) { - if (inputChannels < outputChannels) break; + for (size_t outputChannels : {3, 64}) { + if (inputChannels > outputChannels) break; + size_t groups; + if (!useGroups) { + groups = 1; + } else { + if (outputChannels % inputChannels != 0) continue; + groups = inputChannels; + } + for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { if (padding >= filterSize) break; @@ -62,13 +71,24 @@ public: FuncConfig() .set("paddings", paddings) .set("strides", strides) - .set("groups", (size_t)1) + .set("groups", groups) .set("algo", algo)); TensorShape input{ batchSize, inputChannels, inputSize, inputSize}; - TensorShape filter{ - outputChannels, inputChannels, filterSize, filterSize}; + + TensorShape filter; + if (groups > 1) + filter = TensorShape({groups, + outputChannels / groups, + inputChannels / groups, + filterSize, + filterSize}); + else + filter = TensorShape({outputChannels, + inputChannels, + filterSize, + filterSize}); TensorShape output{ batchSize, outputChannels, outputSize, outputSize}; @@ -85,7 +105,8 @@ public: } else if (type == kBackwardFilterTest) { test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); - test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter), + ADD_TO); test.run(); } } @@ -106,6 +127,7 @@ public: 
ConvolutionTest2(const std::string& conv1, const std::string& conv2, TestType type, + bool useGroups = true, std::string algo = "auto") { for (size_t batchSize : {16}) { for (size_t inputHeight : {7, 31}) { @@ -113,7 +135,15 @@ public: for (size_t filterHeight : {1, 5}) { for (size_t filterWidth : {3, 7}) { for (size_t inputChannels : {7}) { - for (size_t outputChannels : {32}) { + for (size_t outputChannels : {7}) { + size_t groups; + if (!useGroups) { + groups = 1; + } else { + if (outputChannels % inputChannels != 0) continue; + groups = inputChannels; + } + size_t stride = 1; size_t padding = 0; size_t outputHeight = @@ -141,13 +171,24 @@ public: FuncConfig() .set("paddings", paddings) .set("strides", strides) - .set("groups", (size_t)1) + .set("groups", groups) .set("algo", algo)); TensorShape input{ batchSize, inputChannels, inputHeight, inputWidth}; - TensorShape filter{ - outputChannels, inputChannels, filterHeight, filterWidth}; + + TensorShape filter; + if (groups > 1) + filter = TensorShape({groups, + outputChannels / groups, + inputChannels / groups, + filterHeight, + filterWidth}); + else + filter = TensorShape({outputChannels, + inputChannels, + filterHeight, + filterWidth}); TensorShape output{ batchSize, outputChannels, outputHeight, outputWidth}; @@ -164,7 +205,8 @@ public: } else if (type == kBackwardFilterTest) { test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); - test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter)); + test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter), + ADD_TO); test.run(); } } @@ -177,34 +219,88 @@ public: } }; +// ======Start Convolution TEST====== + TEST(Forward, GEMM) { ConvolutionTest test( - "NaiveConv-CPU", "GemmConv-CPU", kForwardTest); + "NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false); ConvolutionTest2 test2( - "NaiveConv-CPU", "GemmConv-CPU", kForwardTest); + "NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false); } #ifndef PADDLE_ONLY_CPU TEST(Forward, GEMM2) { ConvolutionTest test( - "GemmConv-CPU", "GemmConv-GPU", kForwardTest); + "GemmConv-CPU", "GemmConv-GPU", kForwardTest, false); ConvolutionTest2 test2( - "GemmConv-CPU", "GemmConv-GPU", kForwardTest); + "GemmConv-CPU", "GemmConv-GPU", kForwardTest, false); } TEST(BackwardInput, GEMM) { ConvolutionTest test( - "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest); + "GemmConvGradInput-CPU", + "GemmConvGradInput-GPU", + kBackwardInputTest, + false); ConvolutionTest2 test2( - "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest); + "GemmConvGradInput-CPU", + "GemmConvGradInput-GPU", + kBackwardInputTest, + false); } TEST(BackwardFilter, GEMM) { ConvolutionTest test( - "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest); + "GemmConvGradFilter-CPU", + "GemmConvGradFilter-GPU", + kBackwardFilterTest, + false); ConvolutionTest2 test2( - "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest); + "GemmConvGradFilter-CPU", + "GemmConvGradFilter-GPU", + kBackwardFilterTest, + false); } #endif +// ======End Convolution TEST====== + +// ======Start DepthwiseConvolution TEST====== + +// TODO(zhaolong) The depthwise convolution cpu test will be added when the cpu +// version of depthwiseConv is implemented. 
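+// (Informal note: the depthwise tests below leave useGroups at its default of
+// true, so groups == inputChannels on both sides and the depthwise GPU
+// kernels are checked against the equivalent grouped GEMM convolution.)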
+ +#ifndef PADDLE_ONLY_CPU + +TEST(DepthwiseConvForward, GEMM2) { + ConvolutionTest test( + "GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest); + ConvolutionTest2 test2( + "GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest); +} + +TEST(DepthwiseConvBackwardInput, GEMM) { + ConvolutionTest test( + "GemmConvGradInput-CPU", + "DepthwiseConvGradInput-GPU", + kBackwardInputTest); + ConvolutionTest2 test2( + "GemmConvGradInput-CPU", + "DepthwiseConvGradInput-GPU", + kBackwardInputTest); +} + +TEST(DepthwiseConvBackwardFilter, GEMM) { + ConvolutionTest test( + "GemmConvGradFilter-CPU", + "DepthwiseConvGradFilter-GPU", + kBackwardFilterTest); + ConvolutionTest2 test2( + "GemmConvGradFilter-CPU", + "DepthwiseConvGradFilter-GPU", + kBackwardFilterTest); +} + +#endif +// ======End DepthwiseConvolution TEST====== } // namespace paddle diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..490e8d546cbd460217abe95f6291b13fa207faa9 --- /dev/null +++ b/paddle/function/DepthwiseConvOp.cpp @@ -0,0 +1,306 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DepthwiseConvOp.h" +#include "ConvOp.h" +#include "GemmFunctor.h" + +namespace paddle { + +template +class DepthwiseConvFunctor { +public: + void operator()(const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData) { + // TODO(zhaolong) : cpu implementation of depthwise convolution + } +}; + +template +class DepthwiseConvGradInputFunctor { +public: + void operator()(const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad) {} + // TODO(zhaolong) : cpu implementation of depthwise convolution +}; + +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* filterGrad) {} + // TODO(zhaolong) : cpu implementation of depthwise convolution +}; + +/* + * \brief Forward calculation of depthwise convolution. 
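+ * (The CPU functors defined above are placeholder stubs, see the TODOs; the
+ * GPU implementations live in DepthwiseConvOpGpu.cu.)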
+ */ +template +class DepthwiseConvFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + + DepthwiseConvFunctor depthwiseConv; + depthwiseConv(inputData, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + outputData); + } +}; + +/* + * \brief Backward input calculation of depthwise convolution. + */ +template +class DepthwiseConvGradInputFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + check(inputs, outputs); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + real* outputGrad = inputs[0].data(); + real* filterData = inputs[1].data(); + real* inputGrad = outputs[0].data(); + + DepthwiseConvGradInputFunctor depthwiseConvGradInput; + depthwiseConvGradInput(outputGrad, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + inputGrad); + } +}; + +/* + * \brief Backward filter calculation of depthwise convolution. 
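+ * (The filter gradient is accumulated, so the output BufferArg must carry
+ * ADD_TO; a column buffer of outputChannels * filterHeight * filterWidth *
+ * outputHeight * outputWidth elements is resized in calc() below.)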
+ */ +template +class DepthwiseConvGradFilterFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + check(inputs, outputs); + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; + CHECK_EQ(inputChannels, groups_); + + real* outputGrad = inputs[0].data(); + real* inputData = inputs[1].data(); + real* filterGrad = outputs[0].data(); + + int size = outputChannels * filterHeight * filterWidth * outputHeight * + outputWidth; + resizeBuffer(size); + real* colData = reinterpret_cast(memory_->getBuf()); + + DepthwiseConvGradFilterFunctor depthwiseConvGradFilter; + + depthwiseConvGradFilter(outputGrad, + inputData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH(), + strideW(), + paddingH(), + paddingW(), + colData, + filterGrad); + } +}; + +REGISTER_TYPED_FUNC(DepthwiseConv, CPU, DepthwiseConvFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradInput, + CPU, + DepthwiseConvGradInputFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, + CPU, + DepthwiseConvGradFilterFunction); +#ifndef PADDLE_ONLY_CPU +REGISTER_TYPED_FUNC(DepthwiseConv, GPU, DepthwiseConvFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradInput, + GPU, + DepthwiseConvGradInputFunction); +REGISTER_TYPED_FUNC(DepthwiseConvGradFilter, + GPU, + DepthwiseConvGradFilterFunction); +#endif + +} // namespace paddle diff --git a/paddle/function/DepthwiseConvOp.h b/paddle/function/DepthwiseConvOp.h new file mode 100644 index 0000000000000000000000000000000000000000..1bf70e52f34626405b49571e023ac60926713eef --- /dev/null +++ b/paddle/function/DepthwiseConvOp.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "TensorType.h" + +namespace paddle { + +/** + *\brief Depthwise convolution forward. The outputData + * of depthwise convolution is same with ExpandConvLayer + * when groups equals inputChannels in ExpandConvLayer. 
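+ * For instance (informal): with inputChannels == groups == 16 and
+ * filterMultiplier == 2, outputChannels is 32 and every input channel is
+ * convolved with its own two filters.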
+ * + * \param[in] inputData input data. + * \param[in] filterData the parameters of the depthwise conv layer. + * \param[in] batchSize batch size of input data. + * \param[in] outputChannels channels of outputData. + * \param[in] outputHeight height of outputData. + * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of inputData. + * \param[in] inputHeight height of inputData. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. + * \param[in] filterHeight height of filter. + * \param[in] filterWidth width of filter. + * \param[in] strideH stride size in height direction. + * \param[in] strideW stride size in width direction. + * \param[in] paddingH padding size in height direction. + * \param[in] paddingW padding size in width direction. + * \param[out] outputData output data. + * + */ +template <DeviceType Device, class T> +class DepthwiseConvFunctor { +public: + void operator()(const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData); +}; + +/** + *\brief Functor to compute the depthwise convolution backprop w.r.t input. + * + * + * \param[in] outputGrad the grad data of output. + * \param[in] filterData the parameters of the depthwise conv layer. + * \param[in] batchSize batch size of input data. + * \param[in] outputChannels channels of outputData. + * \param[in] outputHeight height of outputData. + * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of input data. + * \param[in] inputHeight height of inputData. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. + * \param[in] filterHeight height of filter. + * \param[in] filterWidth width of filter. + * \param[in] strideH stride size in height direction. + * \param[in] strideW stride size in width direction. + * \param[in] paddingH padding size in height direction. + * \param[in] paddingW padding size in width direction. + * \param[out] inputGrad the grad data of input. + * + */ +template <DeviceType Device, class T> +class DepthwiseConvGradInputFunctor { +public: + void operator()(const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad); +}; + +/** + *\brief Functor to compute the depthwise convolution backprop w.r.t filter. + * + * \param[in] outputGrad the grad data of output. + * \param[in] inputData input data. + * \param[in] batchSize batch size of input data. + * \param[in] outputChannels channels of outputData. + * \param[in] outputHeight height of outputData. + * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of input data. + * \param[in] inputHeight height of inputData. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. + * \param[in] filterHeight height of filter. + * \param[in] filterWidth width of filter. + * \param[in] strideH stride size in height direction. + * \param[in] strideW stride size in width direction. + * \param[in] paddingH padding size in height direction.
+ * \param[in] paddingW padding size in width direction. + * \param[in] colData auxiliary data used when calculating filterGrad. + * \param[out] filterGrad the grad data of filter. + * + */ +template <DeviceType Device, class T> +class DepthwiseConvGradFilterFunctor { +public: + void operator()(const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* filterGrad); +}; + +} // namespace paddle diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..ede0d27aa82e7d71ff5bc33df110fec260e06463 --- /dev/null +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -0,0 +1,342 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "DepthwiseConvOp.h" +#include "GemmFunctor.h" +#include "paddle/math/BaseMatrix.h" + +namespace paddle { + +// CUDA kernel to compute the depthwise convolution forward pass +template <class T> +__global__ +void ConvolutionDepthwiseForward(const int nthreads, + const T* const inputData, const T* const filterData, + const int batchSize, const int outputChannels, const int outputHeight, + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const outputData) { + + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + + if (index < nthreads) { + const int batch = index / outputChannels / outputHeight / outputWidth; + const int c_out = (index / outputHeight / outputWidth) % outputChannels; + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; + + const int c_in = c_out / filterMultiplier; + const T* weight = filterData + c_out * filterHeight * filterWidth; + T value = 0; + const int h_in_start = -paddingH + h_out * strideH; + const int w_in_start = -paddingW + w_out * strideW; + const int h_in_end = -paddingH + h_out * strideH + filterHeight - 1; + const int w_in_end = -paddingW + w_out * strideW + filterWidth - 1; + if ((h_in_start >= 0) && (h_in_end < inputHeight) + && (w_in_start >= 0) && (w_in_end < inputWidth)) { + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + const int offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; + value += (*weight) * inputData[offset]; + ++weight; + } + } + } else { + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { +
const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + if ((h_in >= 0) && (h_in < inputHeight) + && (w_in >= 0) && (w_in < inputWidth)) { + const int offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; + value += (*weight) * inputData[offset]; + } + ++weight; + } + } + } + outputData[index] = value; + } +} + +// CUDA kernel to compute the depthwise convolution backprop w.r.t input. +template +__global__ +void ConvolutionDepthwiseInputBackward(const int nthreads, + const T* const top_diff, const T* const weight_data, + const int num, const int outputChannels, const int outputHeight, + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const bottom_diff) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int batch = index / inputChannels / inputHeight / inputWidth; + const int c_in = (index / inputHeight / inputWidth) % inputChannels; + const int h_in = (index / inputWidth) % inputHeight; + const int w_in = index % inputWidth; + + const int c_out_start = c_in * filterMultiplier; + + int h_out_start = (h_in - filterHeight + paddingH + strideH)/strideH; + h_out_start = 0 > h_out_start ? 0 : h_out_start; + int h_out_end = (h_in + paddingH)/strideH; + h_out_end = outputHeight - 1 < h_out_end? outputHeight - 1 : h_out_end; + int w_out_start = (w_in - filterWidth + paddingW + strideW)/strideW; + w_out_start = 0 > w_out_start ? 0 : w_out_start; + int w_out_end = (w_in + paddingW)/strideW; + w_out_end = outputWidth - 1 < w_out_end? outputWidth - 1 : w_out_end; + + T value = 0; + + for (int c_out = c_out_start; + c_out < c_out_start + filterMultiplier; c_out ++) { + for (int h_out = h_out_start; h_out <= h_out_end; ++h_out) { + const int filter_h = h_in + paddingH - h_out * strideH; + for (int w_out = w_out_start; w_out <= w_out_end; ++w_out) { + const int filter_w = w_in + paddingW - w_out * strideW; + const int filter_offset = c_out * filterHeight * filterWidth + + filter_h * filterWidth + filter_w; + const int top_diff_offset = ((batch * outputChannels + c_out) * + outputHeight + h_out)* outputWidth + w_out; + value += top_diff[top_diff_offset] * weight_data[filter_offset]; + } + } + } + bottom_diff[index] += value; + } +} + +// CUDA kernel to compute the depthwise convolution backprop w.r.t filter. 
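+// One thread per (c_out, kh, kw, h_out, w_out) element: each thread stores a
+// single top_diff * inputData product into buffer_data, and the per-batch
+// buffer is then reduced with BaseMatrix::sumRows in
+// DepthwiseConvGradFilterFunctor below.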
+template +__global__ +void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, + const T* const top_diff, const T* const inputData, + const int num, const int outputChannels, const int outputHeight, + const int outputWidth, const int inputChannels, const int inputHeight, + const int inputWidth, const int filterMultiplier, const int filterHeight, + const int filterWidth, const int strideH, const int strideW, + const int paddingH, const int paddingW, T* const buffer_data) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < nthreads) { + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; + const int kh = (index / filterWidth / outputHeight / outputWidth) + % filterHeight; + const int kw = (index / outputHeight / outputWidth) % filterWidth; + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + if ((h_in >= 0) && (h_in < inputHeight) + && (w_in >= 0) && (w_in < inputWidth)) { + const int c_out = index / + (filterHeight * filterWidth * outputHeight * outputWidth); + const int c_in = c_out / filterMultiplier; + const int batch = num_i; + const int top_offset = ((batch * outputChannels + c_out) * + outputHeight + h_out) * outputWidth + w_out; + const int bottom_offset = ((batch * inputChannels + c_in) + * inputHeight + h_in) * inputWidth + w_in; + buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; + } else { + buffer_data[index] = 0; + } + } +} + +template +class DepthwiseConvFunctor{ +public: + void operator()(const T* inputData, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* outputData){ + int outputSize = batchSize * outputChannels * outputHeight * outputWidth; + + size_t blocks = (outputSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + ConvolutionDepthwiseForward + <<< grid, threads, 0, STREAM_DEFAULT >>>( + outputSize, + inputData, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + outputData); + } +}; + +template +class DepthwiseConvGradInputFunctor{ +public: + void operator()(const T* outputGrad, + const T* filterData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* inputGrad){ + int inputSize = batchSize * inputChannels * inputHeight * inputWidth; + + size_t blocks = (inputSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + + + ConvolutionDepthwiseInputBackward + // NOLINT_NEXT_LINE(whitespace/operators) + <<< grid, threads, 0, STREAM_DEFAULT >>>( + inputSize, + outputGrad, + filterData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + inputGrad); + } +}; 
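+// (Informal note on the launch configuration used by these functors: blocks
+// of 1024 threads on a (512, ceil(numBlocks / 512)) grid, with each kernel
+// recovering a flat element index from
+// (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
+// and bounds-checking it against the element count.)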
+ +template +class DepthwiseConvGradFilterFunctor { +public: + void operator()(const T* outputGrad, + const T* inputData, + int batchSize, + int outputChannels, + int outputHeight, + int outputWidth, + int inputChannels, + int inputHeight, + int inputWidth, + int filterMultiplier, + int filterHeight, + int filterWidth, + int strideH, + int strideW, + int paddingH, + int paddingW, + T* colData, + T* filterGrad){ + int colDataSize = outputChannels * filterHeight * filterWidth + * outputHeight * outputWidth; + + size_t blocks = (colDataSize + 1024 -1) / 1024; + size_t blockX = 512; + size_t blockY = (blocks+512-1)/512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth, + 1, filterGrad, false, true); + + for (int i = 0; i < batchSize; i++) { + ConvolutionDepthwiseFilterBackward + <<< grid, threads, 0, STREAM_DEFAULT >>>( + i, + colDataSize, + outputGrad, + inputData, + batchSize, + outputChannels, + outputHeight, + outputWidth, + inputChannels, + inputHeight, + inputWidth, + filterMultiplier, + filterHeight, + filterWidth, + strideH, + strideW, + paddingH, + paddingW, + colData); + int K = outputHeight * outputWidth; + int M = colDataSize / K; + + BaseMatrix colMatrix(M, K, colData, false, true); + filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); + } + } +}; + +#ifdef PADDLE_TYPE_DOUBLE +template class DepthwiseConvGradInputFunctor; +template class DepthwiseConvFunctor; +template class DepthwiseConvGradFilterFunctor; +#else +template class DepthwiseConvGradInputFunctor; +template class DepthwiseConvFunctor; +template class DepthwiseConvGradFilterFunctor; +#endif + +} // namespace paddle diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index af79e65a7c09e5a1b55febf1df1e8f5bb61bdcb8..783e02e47cb91e28eb88b079f1e94439d34fa775 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -38,10 +38,25 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, inputShape_.resize(numInputs); filterShape_.resize(numInputs); outputShape_.resize(numInputs); + + std::string convType; + std::string convGradInputType; + std::string convGradFilterType; + for (int i = 0; i < config_.inputs_size(); i++) { std::vector paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector strides = {(size_t)strideY_[i], (size_t)stride_[i]}; + if (useGpu_ && (size_t)groups_[i] == (size_t)channels_[i] && !isDeconv_) { + convType = "DepthwiseConv"; + convGradInputType = "DepthwiseConvGradInput"; + convGradFilterType = "DepthwiseConvGradFilter"; + } else { + convType = "GemmConv"; + convGradInputType = "GemmConvGradInput"; + convGradFilterType = "GemmConvGradFilter"; + } + if (FLAGS_use_nnpack) { CHECK_EQ(isDeconv_, false); createFunction(forward_, @@ -53,21 +68,21 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, .set("algo", std::string("auto"))); } else { createFunction(forward_, - !isDeconv_ ? "GemmConv" : "GemmConvGradInput", + !isDeconv_ ? convType : convGradInputType, FuncConfig() .set("paddings", paddings) .set("strides", strides) .set("groups", (size_t)groups_[i])); createFunction(backward_, - !isDeconv_ ? "GemmConvGradInput" : "GemmConv", + !isDeconv_ ? 
convGradInputType : convType, FuncConfig() .set("paddings", paddings) .set("strides", strides) .set("groups", (size_t)groups_[i])); createFunction(backward_, - "GemmConvGradFilter", + convGradFilterType, FuncConfig() .set("paddings", paddings) .set("strides", strides) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 9af083468c0f01218117211f9e4931ca0669e96a..0975c3bc9573c6ccb8f0ac98c41586d322d2465e 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -347,6 +347,55 @@ TEST(Layer, CosSimVecMatLayer) { } } +void testDepthwiseConvLayer(const string& type, bool useGpu) { + TestConfig config; + config.biasSize = 32; + config.layerConfig.set_type(type); + config.layerConfig.set_num_filters(32); + config.layerConfig.set_partial_sum(1); + config.layerConfig.set_shared_biases(true); + + config.inputDefs.push_back({INPUT_DATA, "layer_0", 2048, 192}); + LayerInputConfig* input = config.layerConfig.add_inputs(); + ConvConfig* conv = input->mutable_conv_conf(); + conv->set_filter_size(2); + conv->set_filter_size_y(3); + conv->set_channels(16); + conv->set_padding(0); + conv->set_padding_y(1); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_groups(16); + conv->set_filter_channels(conv->channels() / conv->groups()); + conv->set_img_size(16); + conv->set_img_size_y(8); + conv->set_output_x(outputSize(conv->img_size(), + conv->filter_size(), + conv->padding(), + conv->stride(), + /* caffeMode */ true)); + conv->set_output_y(outputSize(conv->img_size_y(), + conv->filter_size_y(), + conv->padding_y(), + conv->stride_y(), + /* caffeMode */ true)); + config.layerConfig.set_size(conv->output_x() * conv->output_y() * + config.layerConfig.num_filters()); + + testLayerGrad(config, "depthwise_conv", 100, false, useGpu); + // Use small batch_size and useWeight=true to test biasGrad + testLayerGrad(config, "depthwise_conv", 2, false, useGpu, true, 0.02); +} + +TEST(Layer, depthwiseConvLayer) { + // 'depthwise_conv' is a special case of 'exconv' whose + // groups size equals the input channels size.
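+ // (With channels == groups == 16 and num_filters == 32 in the config above,
+ // each input channel produces two output channels, i.e. filterMultiplier is
+ // 2 in the underlying depthwise functors.)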
+ testDepthwiseConvLayer("exconv", /* useGpu= */ false); +#ifndef PADDLE_ONLY_CPU + testDepthwiseConvLayer("exconv", /* useGpu= */ true); +#endif +} + void testConvLayer(const string& type, bool trans, bool useGpu) { TestConfig config; config.biasSize = 16; diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp index 7045562dd44f8f3e0be9181b32954c04f0865fa4..c8ba1074a1555bbddde7e5f0fb2a046138b27c09 100644 --- a/paddle/math/MathFunctions.cpp +++ b/paddle/math/MathFunctions.cpp @@ -202,7 +202,7 @@ double dotProduct(const int n, const double* x, const double* y) { return cblas_ddot(n, x, 1, y, 1); } -#ifdef PADDLE_USE_MKL +#if defined(PADDLE_USE_MKL) || defined(PADDLE_USE_MKLML) template <> void vExp(const int n, const float* a, float* r) { @@ -243,7 +243,55 @@ template <> void vAdd(const int n, const double* a, const double* b, double* r) { vdAdd(n, a, b, r); } +#else + +DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a)); +template +void vExp(const int n, const T* a, T* r) { + hl_cpu_apply_binary_op, 0, 0>( + binary::vExp(), const_cast(a), r, 1, n, n, n); +} + +DEFINE_MATRIX_BINARY_OP(vLog, b = std::log(a)); +template +void vLog(const int n, const T* a, T* r) { + hl_cpu_apply_binary_op, 0, 0>( + binary::vLog(), const_cast(a), r, 1, n, n, n); +} + +DEFINE_MATRIX_BINARY_PARAMETER_OP(vPow, ONE_PARAMETER, b = std::pow(a, p)); +template +void vPow(const int n, const T* a, const T b, T* r) { + hl_cpu_apply_binary_op, 0, 0>( + binary::vPow(b), const_cast(a), r, 1, n, n, n); +} + +DEFINE_MATRIX_TERNARY_OP(vAdd, c = a + b); +template +void vAdd(const int n, const T* a, const T* b, T* r) { + hl_cpu_apply_ternary_op, 0, 0>(ternary::vAdd(), + const_cast(a), + const_cast(b), + r, + 1, + n, + n, + n, + n); +} + +template void vExp(const int n, const float* a, float* r); +template void vExp(const int n, const double* a, double* r); +template void vLog(const int n, const float* a, float* r); +template void vLog(const int n, const double* a, double* r); +template void vPow(const int n, const float* a, const float b, float* r); +template void vPow(const int n, const double* a, const double b, double* r); +template void vAdd(const int n, const float* a, const float* b, float* r); +template void vAdd(const int n, const double* a, const double* b, double* r); +#endif + +#ifdef PADDLE_USE_MKL template <> void vInvSqrt(const int n, const float* a, float* r) { vsInvSqrt(n, a, r); @@ -275,20 +323,6 @@ void vTanh(const int n, const double* a, double* r) { } #else -DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a)); -template -void vExp(const int n, const T* a, T* r) { - hl_cpu_apply_binary_op, 0, 0>( - binary::vExp(), const_cast(a), r, 1, n, n, n); -} - -DEFINE_MATRIX_BINARY_OP(vLog, b = std::log(a)); -template -void vLog(const int n, const T* a, T* r) { - hl_cpu_apply_binary_op, 0, 0>( - binary::vLog(), const_cast(a), r, 1, n, n, n); -} - DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a)); template void vInvSqrt(const int n, const T* a, T* r) { @@ -312,41 +346,12 @@ void vTanh(const int n, const T* a, T* r) { binary::vTanh(), const_cast(a), r, 1, n, n, n); } -DEFINE_MATRIX_BINARY_PARAMETER_OP(vPow, ONE_PARAMETER, b = std::pow(a, p)); -template -void vPow(const int n, const T* a, const T b, T* r) { - hl_cpu_apply_binary_op, 0, 0>( - binary::vPow(b), const_cast(a), r, 1, n, n, n); -} - -DEFINE_MATRIX_TERNARY_OP(vAdd, c = a + b); -template -void vAdd(const int n, const T* a, const T* b, T* r) { - hl_cpu_apply_ternary_op, 0, 0>(ternary::vAdd(), - const_cast(a), - const_cast(b), - r, - 1, - n, - n, - 
diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h
index 8ada0d34c6733d13a45505492909124010c85a91..637643838ff433753e0cbb9154ee069c2f7c6d15 100644
--- a/paddle/math/MathFunctions.h
+++ b/paddle/math/MathFunctions.h
@@ -15,6 +15,12 @@ limitations under the License. */
 #ifndef MATHFUNCTIONS_H_
 #define MATHFUNCTIONS_H_
 
+#ifdef PADDLE_USE_MKLML
+#include <mkl_cblas.h>
+#include <mkl_lapacke.h>
+#include <mkl_vml_functions.h>
+#endif
+
 #ifdef PADDLE_USE_MKL
 #include <mkl.h>
 #include <mkl_lapacke.h>
diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt
index fac442cca56b81f56a750bd3b1c2c0911e79e468..8035d93bfec75b20a54c5af0521ab724cafba8ca 100644
--- a/paddle/memory/CMakeLists.txt
+++ b/paddle/memory/CMakeLists.txt
@@ -1,11 +1,16 @@
 add_subdirectory(detail)
 
 cc_library(memory SRCS memory.cc)
+cc_library(memcpy SRCS memcpy.cc DEPS device_context)
 
 cc_library(paddle_memory
   DEPS
-  memory meta_data
-  meta_cache memory_block
-  buddy_allocator system_allocator)
+  memory
+  memcpy
+  meta_data
+  meta_cache
+  memory_block
+  buddy_allocator
+  system_allocator)
 
 cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory)
diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc
new file mode 100644
index 0000000000000000000000000000000000000000..098931c887479ce6f1afc8b90e4003758d88c018
--- /dev/null
+++ b/paddle/memory/memcpy.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/memory/memcpy.h"
+
+#include <cstring>  // for memcpy
+
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace memory {
+
+template <>
+void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
+                                                  platform::CPUPlace,
+                                                  const void* src, size_t num) {
+  std::memcpy(dst, src, num);
+}
+
+#ifndef PADDLE_ONLY_CPU
+template <>
+void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::GPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  cudaStream_t stream) {
+  platform::GPUPlaceGuard g(src_place.device);
+  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+}
+
+template <>
+void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::CPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  cudaStream_t stream) {
+  platform::GPUPlaceGuard g(dst_place.device);
+  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+}
+
+template <>
+void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
+                                                  void* dst,
+                                                  platform::GPUPlace src_place,
+                                                  const void* src, size_t num,
+                                                  cudaStream_t stream) {
+  if (dst_place == src_place) {
+    platform::GPUPlaceGuard g(src_place.device);
+    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
+  } else {
+    platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
+                            stream);
+  }
+}
+
+#endif  // PADDLE_ONLY_CPU
+
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h
new file mode 100644
index 0000000000000000000000000000000000000000..99b1c2e1c3e5ae4facaeb4fd0b773a7531448f03
--- /dev/null
+++ b/paddle/memory/memcpy.h
@@ -0,0 +1,33 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/platform/gpu_info.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace memory {
+
+template <typename DstPlace, typename SrcPlace>
+void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
+
+#ifndef PADDLE_ONLY_CPU
+template <typename DstPlace, typename SrcPlace>
+void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
+          cudaStream_t stream);
+#endif  // PADDLE_ONLY_CPU
+
+}  // namespace memory
+}  // namespace paddle
diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc
index df3d57d629184d28fd42130df9b020a7b52ade72..c2e046926fafd8f4cfc4cd81d8f32e3882ff02ec 100644
--- a/paddle/memory/memory.cc
+++ b/paddle/memory/memory.cc
@@ -15,7 +15,8 @@ limitations under the License. */
 #include "paddle/memory/memory.h"
 #include "paddle/memory/detail/buddy_allocator.h"
 #include "paddle/memory/detail/system_allocator.h"
-#include "paddle/platform/assert.h"
+
+#include <cstring>  // for memcpy
 
 namespace paddle {
 namespace memory {
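Note: memcpy.h/memcpy.cc above introduce a single memory::Copy entry point; the Place template arguments select a specialization that is std::memcpy on the host or one of the cudaMemcpy* wrappers across devices. A hedged usage sketch for the CPU-to-CPU case (buffer names are illustrative):

#include "paddle/memory/memcpy.h"
#include "paddle/platform/place.h"

void copy_example() {
  paddle::platform::CPUPlace cpu;
  float src[4] = {1.f, 2.f, 3.f, 4.f};
  float dst[4];
  // Deduces Copy<CPUPlace, CPUPlace>, which forwards to std::memcpy.
  paddle::memory::Copy(cpu, dst, cpu, src, sizeof(src));
}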
diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h
index 2d6f4fd2a08ee0039647d276476263d0f8d00329..5e0d64707299acb22aacff0fad237c135f614d9c 100644
--- a/paddle/memory/memory.h
+++ b/paddle/memory/memory.h
@@ -14,19 +14,32 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/platform/gpu_info.h"
 #include "paddle/platform/place.h"
 
 namespace paddle {
 namespace memory {
 
-template <class Place>
+template <typename Place>
 void* Alloc(Place, size_t);
 
-template <class Place>
+template <typename Place>
 void Free(Place, void*);
 
-template <class Place>
+template <typename Place>
 size_t Used(Place);
 
+template <typename T,
+          typename Place,
+          typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
+class PODDeleter {
+ public:
+  PODDeleter(Place place) : place_(place) {}
+  void operator()(T* ptr) { Free(place_, static_cast<void*>(ptr)); }
+
+ private:
+  Place place_;
+};
+
 }  // namespace memory
 }  // namespace paddle
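Note: PODDeleter lets a smart pointer own Place-allocated memory and return it through memory::Free; the enable_if guard restricts it to POD element types. A minimal sketch assuming a CPU build (the size and names are illustrative):

#include <memory>

#include "paddle/memory/memory.h"
#include "paddle/platform/place.h"

void pod_deleter_example() {
  using paddle::platform::CPUPlace;
  using Deleter = paddle::memory::PODDeleter<float, CPUPlace>;
  CPUPlace cpu;
  // Allocate 256 floats; unique_ptr frees them via memory::Free(cpu, ptr).
  float* raw =
      static_cast<float*>(paddle::memory::Alloc(cpu, 256 * sizeof(float)));
  std::unique_ptr<float, Deleter> buf(raw, Deleter(cpu));
}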
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index 41d044cdb72b5fb2a7f8654e8ad103778e0857d1..ebe9ceebe488437866fd6097531623eeb547f67a 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -31,7 +31,7 @@ protected:
                    "Inputs/Outputs of AddOp must all be set");
     PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                    "Two input of Add Op's dimension must be same.");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
 
@@ -53,6 +53,5 @@ The equation is: Out = X + Y
 }  // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
-typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
-    AddKernel_CPU_float;
-REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
+REGISTER_OP_CPU_KERNEL(
+    add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu
index 0edf142ee4e5f359ea14be02dbf3f7f8855f6db1..2e5a755f92e4d1fa487152ed453fe3b2823062ed 100644
--- a/paddle/operators/add_op.cu
+++ b/paddle/operators/add_op.cu
@@ -1,6 +1,5 @@
 #include "paddle/operators/add_op.h"
 #include "paddle/framework/op_registry.h"
 
-typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float;
 REGISTER_OP_GPU_KERNEL(add_two,
-                       AddKernel_GPU_float);
\ No newline at end of file
+                       paddle::operators::AddKernel<paddle::platform::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index fe669b03ca498e253bd6c21a4d312f885dee5588..7d7bb09f3d63bef49913c3c7501082c509c45653 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -35,7 +35,7 @@ protected:
     PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2.");
     PADDLE_ENFORCE(outputs[0]->dims().size() == 1,
                    "label's dimension must be 1.");
-    outputs[0]->set_dims(framework::make_ddim({inputs[0]->dims()[0]}));
+    outputs[0]->Resize(framework::make_ddim({inputs[0]->dims()[0]}));
   }
 };
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 713b2a5dc83d8dd5a3d944101591d75cb19fe04f..079a5800804345762b0b4bc7b8bc9ca042856ccc 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -12,9 +12,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <paddle/operators/mul_op.h>
-#include <paddle/framework/op_registry.h>
-#include <paddle/framework/tensor.h>
+#include "paddle/operators/mul_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/tensor.h"
 
 namespace paddle {
 namespace operators {
@@ -33,7 +33,7 @@ protected:
         dim0[1] == dim1[0],
         "First matrix's width must be equal with second matrix's height.");
     PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output");
-    outputs[0]->set_dims({dim0[0], dim1[1]});
+    outputs[0]->Resize({dim0[0], dim1[1]});
   }
 };
 
@@ -57,4 +57,4 @@ The equation is: Out = X * Y
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    mul, paddle::operators::MulKernel<paddle::platform::CPUPlace>);
+    mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu
index 201723df247993c5cc1650edbe4f74441e3217d4..3ee581dc77dc08e6e47b240588811fbc7c6ea303 100644
--- a/paddle/operators/mul_op.cu
+++ b/paddle/operators/mul_op.cu
@@ -12,9 +12,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <paddle/operators/mul_op.h>
-#include <paddle/framework/op_registry.h>
+#include "paddle/operators/mul_op.h"
+#include "paddle/framework/op_registry.h"
 
-REGISTER_OP_GPU_KERNEL(mul, paddle::operators::MulKernel<paddle::platform::GPUPlace>);
\ No newline at end of file
+REGISTER_OP_GPU_KERNEL(mul, paddle::operators::MulKernel<paddle::platform
+                                                         ::GPUPlace, float>);
\ No newline at end of file
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index ce8a0169e0cbaafb7e90d2227c9597fff463883d..e6bad7fb9da2d489666aa67f032552e48a86c6cb 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -14,17 +14,30 @@
 
 #pragma once
 
-#include <glog/logging.h>
-#include <paddle/framework/operator.h>
+#include "glog/logging.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename Place>
+template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::KernelContext &context) const override {
-    LOG(INFO) << "Mul kernel in " << typeid(Place).name();
+  void Compute(const framework::KernelContext& context) const override {
+    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
+        {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
+
+    auto input0 = context.Input(0)->Get<framework::Tensor>();
+    auto input1 = context.Input(1)->Get<framework::Tensor>();
+    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+
+    output->mutable_data<T>(context.GetPlace());
+
+    framework::EigenMatrix<T>::From(*output).device(
+        *(context.GetEigenDevice<Place>())) =
+        framework::EigenMatrix<T>::From(input0).contract(
+            framework::EigenMatrix<T>::From(input1), dim_pair);
   }
 };
 }  // namespace operators
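Note: MulKernel above implements matmul as an Eigen tensor contraction; pairing dimension 1 of the left operand with dimension 0 of the right (the IndexPair(1, 0)) is exactly the matrix product. A standalone sketch against plain Eigen, with illustrative shapes (not Paddle code):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(2, 3), b(3, 4);
  a.setRandom();
  b.setRandom();
  // Contract a's dim 1 with b's dim 0: c(i, j) = sum_k a(i, k) * b(k, j).
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  Eigen::Tensor<float, 2> c = a.contract(b, dims);  // shape (2, 4)
  return 0;
}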
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 414bafd0468033813d50d4d6723e68ee9347eaac..e04d69fa72a2f54cc1cc0829d12e0da1609b3383 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <paddle/operators/rowwise_add_op.h>
-#include <paddle/framework/op_registry.h>
+#include "paddle/operators/rowwise_add_op.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -30,7 +30,7 @@ protected:
     PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
     PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
     PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
 
@@ -58,4 +58,4 @@ REGISTER_OP(rowwise_add,
             paddle::operators::RowWiseAddOpMaker);
 REGISTER_OP_CPU_KERNEL(
     rowwise_add,
-    paddle::operators::RowWiseAddKernel<paddle::platform::CPUPlace>);
+    paddle::operators::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu
index 2c4bfbf93a1064a47a19c991fa6655b5d67e83cb..5dfac4fd2cf9b7da24dcfa5e7583b9ece12bad1e 100644
--- a/paddle/operators/rowwise_add_op.cu
+++ b/paddle/operators/rowwise_add_op.cu
@@ -1,6 +1,6 @@
-#include <paddle/framework/op_registry.h>
-#include <paddle/operators/rowwise_add_op.h>
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/rowwise_add_op.h"
 
 REGISTER_OP_GPU_KERNEL(
     rowwise_add,
-    paddle::operators::RowWiseAddKernel<paddle::platform::GPUPlace>);
+    paddle::operators::RowWiseAddKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 35f43e6376be6239021e7a9bacb849b93d5226b5..dc47fe7c847bd0c8c179ac0a5f44b8cc541b47cb 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -13,17 +13,32 @@ limitations under the License. */
 
 #pragma once
 
-#include <glog/logging.h>
-#include <paddle/framework/operator.h>
+#include "glog/logging.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename Place>
+template <typename Place, typename T>
 class RowWiseAddKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::KernelContext &context) const override {
-    LOG(INFO) << "RowWiseAdd kernel in " << typeid(Place).name();
+  void Compute(const framework::KernelContext& context) const override {
+    auto in0 = context.Input(0)->Get<framework::Tensor>();
+    auto in1 = context.Input(1)->Get<framework::Tensor>();
+    auto* out = context.Output(0)->GetMutable<framework::Tensor>();
+    out->mutable_data<T>(context.GetPlace());
+
+    auto input = framework::EigenMatrix<T>::From(in0);
+    auto bias = framework::EigenVector<T>::From(in1);
+    auto output = framework::EigenMatrix<T>::From(*out);
+
+    const int bias_size = bias.dimension(0);
+    const int rest_size = input.size() / bias_size;
+    Eigen::DSizes<int, 1> one_d(input.size());
+    Eigen::DSizes<int, 1> bcast(rest_size);
+    output.reshape(one_d).device(*(context.GetEigenDevice<Place>())) =
+        input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
   }
 };
diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc
index 04df87a3add2af7daa127a072f7b690f6cf94327..66ab1e001142bfb005d3c2e2ea29e01a32dce507 100644
--- a/paddle/operators/sgd_op.cc
+++ b/paddle/operators/sgd_op.cc
@@ -31,7 +31,7 @@ protected:
     PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set");
     PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                    "Two input of SGD Op's dimension must be same.");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
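Note: RowWiseAddKernel above avoids an explicit per-row loop by flattening the matrix to 1-D and tiling the bias once per row, which lines up only because Paddle's EigenMatrix uses row-major layout. A standalone sketch against plain Eigen with illustrative shapes:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const int rows = 2, cols = 3;
  Eigen::Tensor<float, 2, Eigen::RowMajor> input(rows, cols);
  Eigen::Tensor<float, 1, Eigen::RowMajor> bias(cols);
  input.setRandom();
  bias.setRandom();

  Eigen::DSizes<int, 1> one_d(rows * cols);
  Eigen::DSizes<int, 1> bcast(rows);  // repeat the bias once per row
  // Row-major flattening stores rows contiguously, so the tiled bias
  // [b0, b1, b2, b0, b1, b2] matches element for element.
  Eigen::Tensor<float, 1, Eigen::RowMajor> out =
      input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
  return 0;
}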
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 45ae277c538ca90716febaf2f3d92b560149d147..91f7d86aebae2e67b2fc18bf2c558fbe2e03de92 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -12,8 +12,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <paddle/operators/sigmoid_op.h>
-#include <paddle/framework/op_registry.h>
+#include "paddle/operators/sigmoid_op.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -24,7 +24,7 @@ protected:
               const std::vector<framework::Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input");
     PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
 
@@ -34,7 +34,7 @@ public:
                   framework::OpAttrChecker *op_checker)
       : framework::OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "sigmoid input");
-    AddInput("Y", "sigmoid output");
+    AddOutput("Y", "sigmoid output");
     AddComment("Sigmoid function");
   }
 };
 
@@ -46,4 +46,5 @@ REGISTER_OP(sigmoid,
             paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    sigmoid, paddle::operators::SigmoidKernel<paddle::platform::CPUPlace>);
+    sigmoid,
+    paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/sigmoid_op.cu
index 79d5222348f610b1b016a2df06e8b1e0a4fac66c..ed344b2bfd4a9eeef2ce79746bec608469503c9c 100644
--- a/paddle/operators/sigmoid_op.cu
+++ b/paddle/operators/sigmoid_op.cu
@@ -1,5 +1,5 @@
-#include <paddle/operators/sigmoid_op.h>
-#include <paddle/framework/op_registry.h>
+#include "paddle/operators/sigmoid_op.h"
+#include "paddle/framework/op_registry.h"
 
 REGISTER_OP_GPU_KERNEL(
-    sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace>);
+    sigmoid, paddle::operators::SigmoidKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h
index 42173343f3e364729ecd190fc554b8c45ecfca8d..2b9356246c471853b53af1d73f8b2a3c206db7ad 100644
--- a/paddle/operators/sigmoid_op.h
+++ b/paddle/operators/sigmoid_op.h
@@ -14,17 +14,25 @@
 
 #pragma once
 
-#include <glog/logging.h>
-#include <paddle/framework/operator.h>
+#include "glog/logging.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename Place>
+template <typename Place, typename T>
 class SigmoidKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::KernelContext &context) const override {
-    LOG(INFO) << "Sigmoid kernel in " << typeid(Place).name();
+  void Compute(const framework::KernelContext& context) const override {
+    auto input = context.Input(0)->Get<framework::Tensor>();
+    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+
+    output->mutable_data<T>(context.GetPlace());
+
+    framework::EigenVector<T>::Flatten(*output).device(
+        *(context.GetEigenDevice<Place>())) =
+        1.0 / (1.0 + (-1.0 * framework::EigenVector<T>::Flatten(input)).exp());
   }
 };
 }  // namespace operators
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 4ca7be359e210d7a31aef94e498f37a1ad4879a2..cf5e273de6be71e727f27d5e87d13d9235e31d0c 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <paddle/operators/softmax_op.h>
-#include <paddle/framework/op_registry.h>
+#include "paddle/operators/softmax_op.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -23,9 +23,11 @@ protected:
       const std::vector<const framework::Tensor *> &inputs,
       const std::vector<framework::Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax");
+    PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
+                   "The input of softmax op must be matrix");
     PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax");
-    outputs[0]->set_dims(inputs[0]->dims());
+    outputs[0]->Resize(inputs[0]->dims());
   }
 };
 
@@ -46,4 +48,5 @@ public:
 namespace ops = paddle::operators;
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
-REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<paddle::platform::CPUPlace>);
+REGISTER_OP_CPU_KERNEL(softmax,
+                       ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/softmax_op.cu b/paddle/operators/softmax_op.cu
index 903eef1b62231d65e2f9ec7a1f57fca0f4c4605c..60676191eb9460868a266d0e4f70357fa78bec2c 100644
--- a/paddle/operators/softmax_op.cu
+++ b/paddle/operators/softmax_op.cu
@@ -1,5 +1,5 @@
-#include <paddle/framework/op_registry.h>
-#include <paddle/operators/softmax_op.h>
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/softmax_op.h"
 
 REGISTER_OP_GPU_KERNEL(
-    softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace>);
+    softmax, paddle::operators::SoftmaxKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 74e9e2786b11b9a87cd9700d8458d4e611a8d4bb..500c188dbfcf28ae52c2d5b06466539e115acc4a 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -14,17 +14,49 @@
 
 #pragma once
 
-#include <glog/logging.h>
-#include <paddle/framework/operator.h>
+#include "glog/logging.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename Place>
+template <typename Place, typename T>
 class SoftmaxKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::KernelContext &context) const override {
-    LOG(INFO) << "Softmax kernel in " << typeid(Place).name();
+  void Compute(const framework::KernelContext& context) const override {
+    auto input = context.Input(0)->Get<framework::Tensor>();
+    auto* output = context.Output(0)->GetMutable<framework::Tensor>();
+    output->mutable_data<T>(context.GetPlace());
+
+    auto logits = framework::EigenMatrix<T>::From(input);
+    auto softmax = framework::EigenMatrix<T>::From(*output);
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class));
+
+    softmax.device(*(context.GetEigenDevice<Place>())) = shifted_logits.exp();
+
+    softmax.device(*(context.GetEigenDevice<Place>())) =
+        (softmax *
+         softmax.sum(along_class)
+             .inverse()
+             .eval()
+             .reshape(batch_by_one)
+             .broadcast(one_by_class));
   }
 };
 }  // namespace operators
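Note: the max-subtraction in SoftmaxKernel above (and in stable_softmax in the new Python test below) relies on the shift invariance of softmax: for any constant m,

    softmax(x)_i = exp(x_i) / sum_j exp(x_j) = exp(x_i - m) / sum_j exp(x_j - m),

so choosing m = max_j x_j keeps every exponent non-positive. exp() then cannot overflow for large logits, and the denominator is at least 1, so the division is safe as well.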
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
index 5d440dec48e7a4cba404bc297eca5a451a144d93..b06ab8a2f184e7bb7dd9cb39f377b087c5258dc4 100644
--- a/paddle/platform/enforce.h
+++ b/paddle/platform/enforce.h
@@ -43,10 +43,26 @@ namespace platform {
 // For more details, please check https://stackoverflow.com/a/43870188/724872.
 #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
 
+template <typename T>
+inline void throw_on_error(T e) {
+  throw_on_error(e, "");
+}
+
+template <typename... Args>
+inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
+    int stat, const Args&... args) {
+  if (UNLIKELY(!(stat))) {
+    throw std::runtime_error(
+        string::Sprintf(args...) +
+        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
+  }
+}
+
 #ifndef PADDLE_ONLY_CPU
 
 template <typename... Args>
-inline void throw_on_error(cudaError_t e, const Args&... args) {
+inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
+    cudaError_t e, const Args&... args) {
   if (UNLIKELY(e)) {
     // clang-format off
     throw thrust::system_error(
@@ -58,7 +74,8 @@ inline void throw_on_error(cudaError_t e, const Args&... args) {
 }
 
 template <typename... Args>
-inline void throw_on_error(curandStatus_t stat, const Args&... args) {
+inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
+    curandStatus_t stat, const Args&... args) {
   if (stat != CURAND_STATUS_SUCCESS) {
     // clang-format off
     throw thrust::system_error(
@@ -70,7 +87,8 @@ inline void throw_on_error(curandStatus_t stat, const Args&... args) {
 }
 
 template <typename... Args>
-inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
+inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
+    cudnnStatus_t stat, const Args&... args) {
   if (stat == CUDNN_STATUS_SUCCESS) {
     return;
   } else {
@@ -84,7 +102,8 @@ inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
 }
 
 template <typename... Args>
-inline void throw_on_error(cublasStatus_t stat, const Args&... args) {
+inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
+    cublasStatus_t stat, const Args&... args) {
   std::string err;
   if (stat == CUBLAS_STATUS_SUCCESS) {
     return;
@@ -113,15 +132,6 @@ inline void throw_on_error(cublasStatus_t stat, const Args&... args) {
 
 #endif  // PADDLE_ONLY_CPU
 
-template <typename... Args>
-inline void throw_on_error(int stat, const Args&... args) {
-  if (UNLIKELY(!(stat))) {
-    throw std::runtime_error(
-        string::Sprintf(args...) +
-        string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
-  }
-}
-
 #define PADDLE_THROW(...)                                     \
   do {                                                        \
     throw std::runtime_error(                                 \
@@ -129,12 +139,9 @@ inline void throw_on_error(int stat, const Args&... args) {
         string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
   } while (0)
 
-/**
- * @brief Enforce a condition, otherwise throw an EnforceNotMet
- */
-#define PADDLE_ENFORCE(condition, ...)                          \
-  do {                                                          \
-    ::paddle::platform::throw_on_error(condition, __VA_ARGS__); \
+#define PADDLE_ENFORCE(...)                                \
+  do {                                                     \
+    ::paddle::platform::throw_on_error(__VA_ARGS__);       \
   } while (0)
 
 }  // namespace platform
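Note: with the int overload of throw_on_error split into a message-less form and an enable_if-guarded variadic form, PADDLE_ENFORCE now accepts a bare condition as well as a condition plus a formatted message. A hedged usage sketch (assuming the enforce.h above):

#include "paddle/platform/enforce.h"

void enforce_example(int n) {
  // Condition plus message: routes to the variadic integral overload.
  PADDLE_ENFORCE(n > 0, "n must be positive, got %d", n);
  // Bare condition: routes to throw_on_error(T e), which forwards "".
  PADDLE_ENFORCE(n < 100);
}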
diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc
index cf9921e870d47fe77c0cca80828dbf2bb36ccda8..edeb3ecd7bf8b87333813eee5b40f71030f6609f 100644
--- a/paddle/platform/gpu_info.cc
+++ b/paddle/platform/gpu_info.cc
@@ -44,7 +44,7 @@ void SetDeviceId(int id) {
                  "cudaSetDevice failed in paddle::platform::SetDeviceId");
 }
 
-void GpuMemoryUsage(size_t& available, size_t& total) {
+void GpuMemoryUsage(size_t &available, size_t &total) {
   PADDLE_ENFORCE(cudaMemGetInfo(&available, &total),
                  "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage");
 }
@@ -82,5 +82,28 @@ size_t GpuMaxChunkSize() {
   return usable;
 }
 
+void GpuMemcpyAsync(void *dst, const void *src, size_t count,
+                    enum cudaMemcpyKind kind, cudaStream_t stream) {
+  PADDLE_ENFORCE(cudaMemcpyAsync(dst, src, count, kind, stream),
+                 "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
+}
+
+void GpuMemcpySync(void *dst, const void *src, size_t count,
+                   enum cudaMemcpyKind kind) {
+  PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind),
+                 "cudaMemcpy failed in paddle::platform::GpuMemcpySync");
+  // note: cudaMemcpy may actually be asynchronous with respect to the caller,
+  //       block on stream 0 to make sure the copy has completed
+  PADDLE_ENFORCE(
+      cudaStreamSynchronize(0),
+      "cudaStreamSynchronize failed in paddle::platform::GpuMemcpySync");
+}
+
+void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
+                   size_t count, cudaStream_t stream) {
+  PADDLE_ENFORCE(
+      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
+      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer");
+}
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h
index 79e71956bd32e8c253ac4192a04e5903bed1c94a..d3a5f5f13fdd3dd59eb43465da4a64b0d8d95e5b 100644
--- a/paddle/platform/gpu_info.h
+++ b/paddle/platform/gpu_info.h
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #ifndef PADDLE_ONLY_CPU
 
+#include <cuda_runtime.h>
 #include <stddef.h>
 
 namespace paddle {
@@ -31,7 +32,7 @@ int GetCurrentDeviceId();
 void SetDeviceId(int device_id);
 
 //!Get the memory usage of current GPU device.
-void GpuMemoryUsage(size_t& available, size_t& total);
+void GpuMemoryUsage(size_t &available, size_t &total);
 
 //! Get the maximum allocation size of current GPU device.
 size_t GpuMaxAllocSize();
@@ -42,6 +43,18 @@ size_t GpuMinChunkSize();
 //! Get the maximum chunk size for GPU buddy allocator.
 size_t GpuMaxChunkSize();
 
+//! Copy memory from address src to dst asynchronously.
+void GpuMemcpyAsync(void *dst, const void *src, size_t count,
+                    enum cudaMemcpyKind kind, cudaStream_t stream);
+
+//! Copy memory from address src to dst synchronously.
+void GpuMemcpySync(void *dst, const void *src, size_t count,
+                   enum cudaMemcpyKind kind);
+
+//! Copy memory from one device to another device.
+void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
+                   size_t count, cudaStream_t stream);
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 4db9cc74465629a6b086c3b1f38d7b99038c7361..2c843839ce36d3c2c60f98957b53548f4b9b96b5 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -13,16 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <Python.h>
-#include <paddle/framework/op_registry.h>
-#include <paddle/framework/operator.h>
-#include <paddle/framework/scope.h>
-#include <paddle/pybind/tensor_bind.h>
-#include <pybind11/numpy.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
 #include <fstream>
 #include <vector>
 
+#include "paddle/framework/net.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/scope.h"
+#include "paddle/pybind/tensor_bind.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
 namespace py = pybind11;
 namespace pd = paddle::framework;
 
@@ -30,9 +32,24 @@ USE_OP(add_two);
 USE_OP(onehot_cross_entropy);
 USE_OP_WITHOUT_KERNEL(fc);
 USE_OP(sgd);
+USE_OP(mul);
+USE_OP(sigmoid);
+USE_OP(softmax);
+USE_OP(rowwise_add);
+
+template <typename ClassType>
+void ExposeOperator(ClassType& m) {
+  m.def("infer_shape", &ClassType::type::InferShape)
+      .def("run", &ClassType::type::Run)
+      .def("outputs",
+           [](const typename ClassType::type& op) -> std::vector<std::string> {
+             return op.outputs_;
+           })
+      .def("__str__", &ClassType::type::DebugString);
+}
 
 PYBIND11_PLUGIN(core) {
-  py::module m("core", "C++ core of Paddle Paddle");
+  py::module m("core", "C++ core of PaddlePaddle");
 
   py::class_<pd::Tensor>(m, "Tensor", py::buffer_protocol())
       .def_buffer([](pd::Tensor& self) -> py::buffer_info {
@@ -42,7 +59,7 @@
            [](const pd::Tensor& self) { return pd::vectorize(self.dims()); })
       .def("set_dims",
            [](pd::Tensor& self, const std::vector<int>& dim) {
-             self.set_dims(pd::make_ddim(dim));
+             self.Resize(pd::make_ddim(dim));
            })
       .def("alloc_float",
            [](pd::Tensor& self) {
@@ -109,21 +126,37 @@ All parameter, weight, gradient are variables in Paddle.
         return new paddle::platform::CPUDeviceContext();
       });
 
-  py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
-      .def("__str__", &pd::OperatorBase::DebugString)
+  py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator");
+
+  operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr {
+    pd::OpDesc desc;
+    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                   "Cannot parse user input to OpDesc");
+    PADDLE_ENFORCE(desc.IsInitialized(),
+                   "User OpDesc is not initialized, reason %s",
+                   desc.InitializationErrorString());
+    return pd::OpRegistry::CreateOp(desc);
+  });
+  ExposeOperator(operator_base);
+
+  using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
+  py::class_<pd::PlainNet, PlainNetPtr> plain_net(m, "PlainNet");
+
+  plain_net
       .def_static("create",
-                  [](py::bytes protobin) {
-                    pd::OpDesc desc;
-                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                                   "Cannot parse user input to OpDesc");
-                    PADDLE_ENFORCE(desc.IsInitialized(),
-                                   "User OpDesc is not initialized, reason %s",
-                                   desc.InitializationErrorString());
-                    return pd::OpRegistry::CreateOp(desc);
+                  []() -> std::shared_ptr<pd::PlainNet> {
+                    auto retv = std::make_shared<pd::PlainNet>();
+                    retv->type_ = "plain_net";
+                    return retv;
                   })
-      .def("infer_shape", &pd::OperatorBase::InferShape)
-      .def("run", &pd::OperatorBase::Run)
-      .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
+      .def("add_op", &pd::PlainNet::AddOp)
+      .def("add_op",
+           [](PlainNetPtr& self, const PlainNetPtr& plain_net) -> void {
+             self->AddOp(std::static_pointer_cast<pd::OperatorBase>(plain_net));
+           })
+      .def("complete_add_op", &pd::PlainNet::CompleteAddOp)
+      .def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
+  ExposeOperator(plain_net);
 
   return m.ptr();
 }
diff --git a/paddle/pybind/tensor_bind.h b/paddle/pybind/tensor_bind.h
index b96516643ab55b9615ccafdc41d3290590987d95..995e102bf9d342e1604f5ae704288d6cf68d97a4 100644
--- a/paddle/pybind/tensor_bind.h
+++ b/paddle/pybind/tensor_bind.h
@@ -86,7 +86,7 @@ void PyTensorSetFromArray(
     dims.push_back((int)array.shape()[i]);
   }
 
-  self.set_dims(framework::make_ddim(dims));
+  self.Resize(framework::make_ddim(dims));
   auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index ab81e67579e39a34e3ace18d14434eb86b66fa5b..fc112f1327f5ad5f1bdd04873394b1fa0e761e29 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3219,6 +3219,10 @@ def ParameterHook(type, **kwargs):
         if sparsity_ratio is not None:
             hook.sparsity_ratio = sparsity_ratio
         return hook
+    elif type == 'dpruning':
+        hook = ParameterUpdaterHookConfig()
+        hook.type = type
+        return hook
     else:
         return None
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 01838b40bd123f7e95bb961e4c8ea344a399bad4..b3eb2ef8a8966318fe33ca8b3032a4120c73909f 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -1,3 +1,15 @@
-add_python_test(test_framework test_protobuf.py test_scope.py
-    test_default_scope_funcs.py test_op_creation_methods.py
-    test_tensor.py test_fc_op.py test_add_two_op.py test_sgd_op.py test_cross_entropy_op.py)
+add_python_test(test_framework
+    test_protobuf.py
+    test_scope.py
+    test_default_scope_funcs.py
+    test_op_creation_methods.py
+    test_plain_net.py
+    test_tensor.py
+    test_fc_op.py
+    test_add_two_op.py
+    test_sgd_op.py
+    test_cross_entropy_op.py
+    test_mul_op.py
+    test_sigmoid_op.py
+    test_softmax_op.py
+    test_rowwise_add_op.py)
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index b1fa12cc89fa724994ea482ab0a3d78c03a9cdf0..7b62313f8aca5e9f515d1a9e6df3bb6f51b974fb 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -56,7 +56,10 @@ class OpTestMeta(type):
             for out_name in func.all_output_args:
                 actual = numpy.array(scope.get_var(out_name).get_tensor())
                 expect = getattr(self, out_name)
-                numpy.testing.assert_almost_equal(actual, expect)
+                # TODO(qijun) The default decimal is 7, but numpy.dot and
+                # eigen.mul differ slightly and cannot pass the unittest at
+                # that precision, so decimal is set to 3 here for now.
+                numpy.testing.assert_almost_equal(actual, expect, decimal=3)
 
         obj.test_all = test_all
         return obj
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a87e66cd03af1bf84be8ffe111e4a8c3a24d6dc
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -0,0 +1,17 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy as np
+
+
+class TestMulOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "mul"
+        self.X = np.random.random((32, 784)).astype("float32")
+        self.Y = np.random.random((784, 100)).astype("float32")
+        self.Out = np.dot(self.X, self.Y)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_plain_net.py b/python/paddle/v2/framework/tests/test_plain_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b919aca28902706f8aa285213d6bb1fa2cd3e14
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_plain_net.py
@@ -0,0 +1,30 @@
+import paddle.v2.framework.core as core
+from paddle.v2.framework.create_op_creation_methods import op_creations
+import unittest
+
+
+class TestNet(unittest.TestCase):
+    def test_net_all(self):
+        net = core.PlainNet.create()
+        op1 = op_creations.add_two(X="X", Y="Y", Out="Out")
+        net.add_op(op1)
+
+        net2 = core.PlainNet.create()
+        net2.add_op(op_creations.fc(X="X", W="w", Y="fc.out"))
+        net2.complete_add_op(True)
+        net.add_op(net2)
+        net.complete_add_op(True)
+
+        expected = '''
+Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
+    Op(add_two), inputs:(X, Y), outputs:(Out).
+    Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
+        Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0).
+            Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
+            Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
+'''
+        self.assertEqual(expected, "\n" + str(net))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1514983c03f822f84b85437d1cfe653b6a1a2e
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -0,0 +1,17 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy as np
+
+
+class TestRowwiseAddOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "rowwise_add"
+        self.X = np.random.random((32, 784)).astype("float32")
+        self.b = np.random.random(784).astype("float32")
+        self.Out = np.add(self.X, self.b)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..50044a122f1d66dd54a24f6cce76074a60ee2262
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -0,0 +1,16 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy as np
+
+
+class TestSigmoidOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "sigmoid"
+        self.X = np.random.random((32, 100)).astype("float32")
+        self.Y = 1 / (1 + np.exp(-self.X))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..191b698c1cdec9b86b4ded6b1f743586867ca62f
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_softmax_op.py
@@ -0,0 +1,23 @@
+import unittest
+from op_test_util import OpTestMeta
+import numpy as np
+
+
+def stable_softmax(x):
+    """Compute the softmax of vector x in a numerically stable way."""
+    shiftx = x - np.max(x)
+    exps = np.exp(shiftx)
+    return exps / np.sum(exps)
+
+
+class TestSoftmaxOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "softmax"
+        self.X = np.random.random((32, 100)).astype("float32")
+        self.Y = np.apply_along_axis(stable_softmax, 1, self.X)
+
+
+if __name__ == '__main__':
+    unittest.main()