diff --git a/.travis.yml b/.travis.yml index eecf5e81f0c952cb4cf7bd215496350d14ed7f85..0705baa1aca8b480b2a774076bd91fb9df401a53 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,9 @@ language: cpp -cache: ccache +cache: + directories: + - $HOME/third_party + - $HOME/.ccache + - $HOME/.cache/pip sudo: required dist: trusty os: @@ -35,6 +39,7 @@ addons: - clang-format-3.8 - automake - libtool + - ccache before_install: - | if [ ${JOB} == "BUILD_AND_TEST" ]; then diff --git a/CMakeLists.txt b/CMakeLists.txt index 59182d299be1ccc5f57e22f325b7f684fdf97866..15e310a6ae1155796687f18f7797ae48c8a5ecbf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,16 @@ option(WITH_DOC "Compile PaddlePaddle with documentation" OFF) option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF) + +# CMAKE_BUILD_TYPE +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING + "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" + FORCE) +endif() + +set(THIRD_PARTY_PATH "${PROJ_ROOT}/third_party" CACHE STRING + "A path setting third party libraries download & build directories.") ######################################################################################## include(external/zlib) # download, build, install zlib diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index d38b7d1ba2a74d5bb46d0c07e3abe6832d4c8af3..2a49d76eb30f592a28746f5897b14b7dd319d784 100644 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -14,8 +14,8 @@ INCLUDE(ExternalProject) -SET(GFLAGS_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/gflags) -SET(GFLAGS_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/gflags) +SET(GFLAGS_SOURCES_DIR ${THIRD_PARTY_PATH}/gflags) +SET(GFLAGS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gflags) SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE) IF(WIN32) set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE) diff --git a/cmake/external/glog.cmake b/cmake/external/glog.cmake index bec69f3ddf093b62f084f9080fa1fe4398c93e9a..71e20c85276b014c2e33735c3199c3772526c6c7 100644 --- a/cmake/external/glog.cmake +++ b/cmake/external/glog.cmake @@ -14,8 +14,8 @@ INCLUDE(ExternalProject) -SET(GLOG_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/glog) -SET(GLOG_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/glog) +SET(GLOG_SOURCES_DIR ${THIRD_PARTY_PATH}/glog) +SET(GLOG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/glog) SET(GLOG_INCLUDE_DIR "${GLOG_INSTALL_DIR}/include" CACHE PATH "glog include directory." FORCE) IF(WIN32) diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake index 2fcb7893fa30e7fcd84b9e860217f82cf01bf89e..11d829a9e2f239848803130505c9862695b25029 100644 --- a/cmake/external/gtest.cmake +++ b/cmake/external/gtest.cmake @@ -16,8 +16,8 @@ IF(WITH_TESTING) ENABLE_TESTING() INCLUDE(ExternalProject) - SET(GTEST_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/gtest) - SET(GTEST_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/gtest) + SET(GTEST_SOURCES_DIR ${THIRD_PARTY_PATH}/gtest) + SET(GTEST_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gtest) SET(GTEST_INCLUDE_DIR "${GTEST_INSTALL_DIR}/include" CACHE PATH "gtest include directory." 
FORCE) INCLUDE_DIRECTORIES(${GTEST_INCLUDE_DIR}) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 66a72cd243e09ccf32b61d419f6d0ad9ec3fe9c8..0e8c29c831c823f701d8eecd954d3b120085e495 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -18,8 +18,8 @@ IF(NOT ${CBLAS_FOUND}) MESSAGE(FATAL_ERROR "Please install OpenBlas, MKL or ATLAS.") INCLUDE(ExternalProject) - SET(CBLAS_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas) - SET(CBLAS_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/openblas) + SET(CBLAS_SOURCES_DIR ${THIRD_PARTY_PATH}/openblas) + SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) IF(WIN32) diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake index 2f2769b4c628d8570c335d344cbf608bda84206f..c0cf2719f9a7b3ae6be5cefffa3dbd2c3f712e82 100644 --- a/cmake/external/protobuf.cmake +++ b/cmake/external/protobuf.cmake @@ -14,8 +14,8 @@ INCLUDE(ExternalProject) -SET(PROTOBUF_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/protobuf) -SET(PROTOBUF_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/protobuf) +SET(PROTOBUF_SOURCES_DIR ${THIRD_PARTY_PATH}/protobuf) +SET(PROTOBUF_INSTALL_DIR ${THIRD_PARTY_PATH}/install/protobuf) SET(PROTOBUF_INCLUDE_DIR "${PROTOBUF_INSTALL_DIR}/include" CACHE PATH "protobuf include directory." FORCE) INCLUDE_DIRECTORIES(${PROTOBUF_INCLUDE_DIR}) diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake index 2f86ab3901d4cfc40a294309662c986b818e64f7..48dec42a3ca7d0116ca07ed22dab713833cae2c2 100644 --- a/cmake/external/python.cmake +++ b/cmake/external/python.cmake @@ -28,8 +28,8 @@ IF(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) FIND_PACKAGE(NumPy REQUIRED) ELSE(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) ##################################### PYTHON ######################################## - SET(PYTHON_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/python) - SET(PYTHON_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/python) + SET(PYTHON_SOURCES_DIR ${THIRD_PARTY_PATH}/python) + SET(PYTHON_INSTALL_DIR ${THIRD_PARTY_PATH}/install/python) SET(_python_DIR ${PYTHON_INSTALL_DIR}) IF(UNIX) diff --git a/cmake/external/swig.cmake b/cmake/external/swig.cmake index 40088c65ef7166ddef52956a1a7470ccab8087c9..63e8bd25462e50e2f78908899938468c989b3ac3 100644 --- a/cmake/external/swig.cmake +++ b/cmake/external/swig.cmake @@ -18,8 +18,8 @@ IF(NOT SWIG_FOUND) # build swig as an external project INCLUDE(ExternalProject) - SET(SWIG_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/swig) - SET(SWIG_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/swig) + SET(SWIG_SOURCES_DIR ${THIRD_PARTY_PATH}/swig) + SET(SWIG_INSTALL_DIR ${THIRD_PARTY_PATH}/install/swig) SET(SWIG_TARGET_VERSION "3.0.2") SET(SWIG_DOWNLOAD_SRC_MD5 "62f9b0d010cef36a13a010dc530d0d41") SET(SWIG_DOWNLOAD_WIN_MD5 "3f18de4fc09ab9abb0d3be37c11fbc8f") diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake index 7386d935b8931670d4fd7aa305f74b21471a5562..f5e4b3e1eb39acbe8dbcd0023956ca7e52c1ecd8 100644 --- a/cmake/external/warpctc.cmake +++ b/cmake/external/warpctc.cmake @@ -14,8 +14,8 @@ INCLUDE(ExternalProject) -SET(WARPCTC_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/warpctc) -SET(WARPCTC_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/warpctc) +SET(WARPCTC_SOURCES_DIR ${THIRD_PARTY_PATH}/warpctc) +SET(WARPCTC_INSTALL_DIR 
${THIRD_PARTY_PATH}/install/warpctc) SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" CACHE PATH "Warp-ctc Directory" FORCE) INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) diff --git a/cmake/external/zlib.cmake b/cmake/external/zlib.cmake index 916f6816aae9938aad95ac527cf07ffbe38f7479..47fa8817fb64fb8fd718e2892ad5bae7bbe956eb 100644 --- a/cmake/external/zlib.cmake +++ b/cmake/external/zlib.cmake @@ -14,8 +14,8 @@ INCLUDE(ExternalProject) -SET(ZLIB_SOURCES_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/zlib) -SET(ZLIB_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/install/zlib) +SET(ZLIB_SOURCES_DIR ${THIRD_PARTY_PATH}/zlib) +SET(ZLIB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/zlib) SET(ZLIB_ROOT ${ZLIB_INSTALL_DIR} CACHE FILEPATH "zlib root directory." FORCE) SET(ZLIB_INCLUDE_DIR "${ZLIB_INSTALL_DIR}/include" CACHE PATH "zlib include directory." FORCE) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 0983d83b73a32d0615170155759d45001cc6ff54..0d1ef5cd8449bd31b4cfa4619f27bce7c1f55ebb 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -3,12 +3,6 @@ include(CheckCXXCompilerFlag) include(CheckCCompilerFlag) include(CheckCXXSymbolExists) -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING - "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" - FORCE) -endif() - function(CheckCompilerCXX11Flag) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index 0cafbd896e2d88aee4406bd0305878ce489bc18d..41beed38a87601cb57072c8966cd0fd2ea156524 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) { a.cpuSequenceDims = m->cast(vec->getSharedPtr()); } +float Arguments::sumCosts() const { + return paddle::Argument::sumCosts(m->outputs); +} + int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) { auto& a = m->getArg(idx); return a.getBatchSize(); diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 364d19f9414430709108824dce75a1007332d824..f5af8b0035b44d97832dd90ca2eeba079503715c 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -450,6 +450,8 @@ public: IVector* vec) throw(RangeError); void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError); + float sumCosts() const; + private: static Arguments* createByPaddleArgumentVector(void* ptr); void* getInternalArgumentsPtr() const; @@ -546,6 +548,10 @@ public: ParameterConfig* getConfig(); void setValueUpdated(); + bool save(const std::string& filename) const; + + bool load(const std::string& filename) const; + size_t getSize() const; private: diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index ddc00d8d1af4c58d7e2233423bea916408bee92b..19f7a898d6b8d3d02c5654559dcb86728266731e 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -57,4 +57,12 @@ size_t Parameter::getID() const { return m->getPtr()->getID(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } +bool Parameter::save(const std::string& filename) const { + return m->getPtr()->save(filename); +} + +bool Parameter::load(const std::string& filename) const { + return m->getPtr()->load(filename); +} + size_t Parameter::getSize() const { return m->getPtr()->getSize(); } diff --git a/paddle/api/test/.gitignore b/paddle/api/test/.gitignore new file mode 100644 index 
0000000000000000000000000000000000000000..b7948824a1eab119140dd9bea20276c303fe4af1 --- /dev/null +++ b/paddle/api/test/.gitignore @@ -0,0 +1,2 @@ +*.w0 +*.wbias diff --git a/paddle/api/test/testArguments.py b/paddle/api/test/testArguments.py index 8cabecd242fb4eb98c0fe468687ef179245e4535..a04a805d7a64ef906c8388f1241b9ef823e4d9e0 100644 --- a/paddle/api/test/testArguments.py +++ b/paddle/api/test/testArguments.py @@ -22,6 +22,8 @@ class TestArguments(unittest.TestCase): args = swig_paddle.Arguments.createArguments(1) args.setSlotValue(0, m) + self.assertAlmostEqual(27.0, args.sumCosts()) + mat = args.getSlotValue(0) assert isinstance(mat, swig_paddle.Matrix) np_mat = mat.toNumpyMatInplace() diff --git a/paddle/api/test/testGradientMachine.py b/paddle/api/test/testGradientMachine.py index b81eafa9673ca34f1b7e06401098d55bdb1b35a5..4b705f66eccd267f326fe0662a17b33a09fda982 100644 --- a/paddle/api/test/testGradientMachine.py +++ b/paddle/api/test/testGradientMachine.py @@ -45,6 +45,7 @@ class TestGradientMachine(unittest.TestCase): assert isinstance(val, swig_paddle.Vector) arr = numpy.full((len(val), ), 0.1, dtype="float32") val.copyFromNumpyArray(arr) + self.assertTrue(param.save(param.getName())) param_config = param.getConfig().toProto() assert isinstance(param_config, paddle.proto.ParameterConfig_pb2.ParameterConfig) @@ -92,6 +93,9 @@ class TestGradientMachine(unittest.TestCase): self.assertTrue(self.isCalled) + for param in machine.getParameters(): + self.assertTrue(param.load(param.getName())) + def test_train_one_pass(self): conf_file_path = './testTrainConfig.py' trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile( diff --git a/paddle/function/BufferArg.cpp b/paddle/function/BufferArg.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65c6f303041d830812fb2d99503b2b2166145f4a --- /dev/null +++ b/paddle/function/BufferArg.cpp @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "BufferArg.h" + +namespace paddle { + +const SequenceArg& BufferArg::sequence() const { + // CHECK_EQ(bufferType_, TENSOR_SEQUENCE_DATA); + return dynamic_cast(*this); +} + +const SparseMatrixArg& BufferArg::sparse() const { + // CHECK_EQ(bufferType_, TENSOR_SPARSE); + return dynamic_cast(*this); +} + +} // namespace paddle diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h new file mode 100644 index 0000000000000000000000000000000000000000..9649913fa8d9bf82b67fc2ac97ae9f30e7029528 --- /dev/null +++ b/paddle/function/BufferArg.h @@ -0,0 +1,281 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "TensorShape.h" +#include "TensorType.h" +#include "paddle/math/CpuSparseMatrix.h" +#include "paddle/math/Matrix.h" +#include "paddle/math/SparseMatrix.h" + +namespace paddle { + +enum BufferType { + TENSOR_NORMAL = 0, + TENSOR_SEQUENCE_ID = 1, + TENSOR_SEQUENCE_DATA = 2, + TENSOR_SPARSE = 3 +}; + +enum SparseDataType { + SPARSE_NO_VALUE = 0, // do not need value pointer, all values are 1 + SPARSE_FLOAT_VALUE = 1 +}; + +enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 }; + +class BufferArg; +class SequenceArg; +class SparseMatrixArg; +typedef std::shared_ptr BufferArgPtr; + +/** + * \brief BufferArg used as the argument type of Function. + * + * The arguments of the Paddle Function have four Buffer types. + * 1. BufferArg for a dense Buffer of any dimension. + * 2. SequenceIdArg for a Buffer of sequence start positions. + * 3. SequenceArg for a Buffer of sequence data. + * 4. SparseMatrixArg for a Buffer of sparse matrix. + * + * There is an ArgType property for the BufferArg used as Function Output. + * Whether the result of the Function calculation is assigned to the + * output Buffer or added to the output Buffer is determined by the + * argType_ property of the output BufferArg. + */ + +// ArgType is only used by output BufferArg. +// For input argument, argType_ is ignored. +// For output argument, need to set the argType_ of the BufferArg. 
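// Illustrative sketch, not part of this patch: how a Function kernel is expected to
// honor the ArgType of an output BufferArg, per the comment above and the enum declared
// just below. applyElementwise and its element-wise copy are hypothetical and assume
// both arguments hold float data; data<T>(), shape() and getArgType() are the BufferArg
// members introduced in this file.
void applyElementwise(const BufferArg& in, BufferArg& out) {
  const float* src = in.data<float>();
  float* dst = out.data<float>();
  size_t n = out.shape().getElements();
  if (out.getArgType() == ASSIGN_TO) {
    for (size_t i = 0; i < n; ++i) dst[i] = src[i];   // output = f(inputs)
  } else {  // ADD_TO
    for (size_t i = 0; i < n; ++i) dst[i] += src[i];  // output += f(inputs)
  }
}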
+enum ArgType { + UNSPECIFIED = 0, + ASSIGN_TO = 1, + ADD_TO = 2, +}; +class BufferArg { +public: + void setArgType(ArgType argType) { argType_ = argType; } + + ArgType getArgType() const { return argType_; } + +public: + BufferArg(void* buf, + ValueType valueType, + const TensorShape& shape, + ArgType argType = UNSPECIFIED) + : buf_(buf), valueType_(valueType), shape_(shape), argType_(argType) {} + + BufferArg(void* buf, ValueType valueType) + : buf_(buf), valueType_(valueType) {} + + BufferArg(const Matrix& matrix, ArgType argType = UNSPECIFIED) + : buf_( + const_cast(reinterpret_cast(matrix.getData()))), + valueType_(DataType::value), + shape_(2), + argType_(argType) { + shape_.setDim(0, matrix.getHeight()); + shape_.setDim(1, matrix.getWidth()); + } + + BufferArg(const Matrix& matrix, + const TensorShape& shape, + ArgType argType = UNSPECIFIED) + : buf_( + const_cast(reinterpret_cast(matrix.getData()))), + valueType_(DataType::value), + shape_(shape), + argType_(argType) { + CHECK_EQ(matrix.getElementCnt(), shape.getElements()); + } + + BufferArg(const Vector& vector, ArgType argType = UNSPECIFIED) + : buf_( + const_cast(reinterpret_cast(vector.getData()))), + valueType_(DataType::value), + shape_(1), + argType_(argType) { + shape_.setDim(0, vector.getSize()); + } + + BufferArg(const IVector& vector, ArgType argType = UNSPECIFIED) + : buf_( + const_cast(reinterpret_cast(vector.getData()))), + valueType_(VALUE_TYPE_INT32), + shape_(1), + argType_(argType) { + shape_.setDim(0, vector.getSize()); + } + + template + typename Tensor::Matrix matrix() const { + CHECK(buf_); + CHECK(valueType_ == DataType::value); + // CHECK(deviceType_ == DType); + CHECK_EQ((size_t)2, shape_.ndims()); + return typename Tensor::Matrix( + reinterpret_cast(buf_), shape_[0], shape_[1]); + } + + template + typename Tensor::Vector vector() const { + CHECK(buf_); + CHECK(valueType_ == DataType::value); + // CHECK(deviceType_ == DType); + CHECK_EQ((size_t)1, shape_.ndims()); + return typename Tensor::Vector( + shape_[0], reinterpret_cast(buf_)); + } + + virtual ~BufferArg() {} + + template + T* data() const { + return reinterpret_cast(buf_); + } + + void* data() const { return buf_; } + ValueType valueType() const { return valueType_; } + BufferType bufferType() const { return bufferType_; } + const TensorShape& shape() const { return shape_; } + + const SequenceArg& sequence() const; + const SparseMatrixArg& sparse() const; + +protected: + void* buf_; + ValueType valueType_; + TensorShape shape_; + BufferType bufferType_; + ArgType argType_ = UNSPECIFIED; + // leading dimensions. 
The size is dims_.size() + // Dims lds_; +}; + +// sequence start positions in a mini-batch of sequences +// shape_.ndims() == 1 +// valueType_ = int32 +// if a < b then value_.buf_[a] < value_.buf_[b] +class SequenceIdArg : public BufferArg { +public: + SequenceIdArg(void* buf, + const TensorShape& shape, + ArgType argType = UNSPECIFIED) + : BufferArg(buf, VALUE_TYPE_INT32, shape, argType) { + CHECK_EQ(shape_.ndims(), (size_t)1); + numSeqs_ = shape_[0] - 1; + } + + SequenceIdArg(const IVector& vector) : BufferArg(vector) { + numSeqs_ = shape_[0] - 1; + } + + ~SequenceIdArg() {} + + size_t numSeqs() const { return numSeqs_; } + +private: + size_t numSeqs_; +}; + +// sequence data +class SequenceArg : public BufferArg { +public: + SequenceArg(void* buf, + ValueType valueType, + const TensorShape& shape, + const SequenceIdArg& startPositions, + ArgType argType = UNSPECIFIED) + : BufferArg(buf, valueType, shape, argType), + startPositions_(startPositions) {} + + SequenceArg(const Matrix& matrix, + const IVector& vector, + ArgType argType = UNSPECIFIED) + : BufferArg(matrix, argType), startPositions_(vector) {} + + ~SequenceArg() {} + + void* getIdBuf() const { return startPositions_.data(); } + size_t numSeqs() const { return startPositions_.numSeqs(); } + +private: + SequenceIdArg startPositions_; +}; + +// sparse matrix +// valueType_ == float or double +// shape_.ndims() == 2 +class SparseMatrixArg : public BufferArg { +public: + SparseMatrixArg(void* buf, + ValueType valueType, + const TensorShape& shape, + const BufferArg& row, + const BufferArg& col, + size_t nnz, + SparseDataFormat format, + SparseDataType type, + ArgType argType = UNSPECIFIED) + : BufferArg(buf, valueType, shape, argType), + row_(row), + col_(col), + nnz_(nnz), + format_(format), + type_(type) { + CHECK((valueType == VALUE_TYPE_FLOAT) || (valueType == VALUE_TYPE_DOUBLE)); + CHECK_EQ(shape_.ndims(), (size_t)2); + CHECK_EQ(row_.shape().ndims(), (size_t)1); + CHECK_EQ(col_.shape().ndims(), (size_t)1); + if (format == SPARSE_CSR_FORMAT) { + CHECK_EQ(nnz, col.shape()[0]); + } else if (format == SPARSE_CSC_FORMAT) { + CHECK_EQ(nnz, row.shape()[0]); + } + } + + SparseMatrixArg(const CpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED) + : BufferArg(sparse, argType), + row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), + col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + + SparseMatrixArg(const GpuSparseMatrix& sparse, ArgType argType = UNSPECIFIED) + : BufferArg(sparse, argType), + row_(reinterpret_cast(sparse.getRows()), VALUE_TYPE_INT32), + col_(reinterpret_cast(sparse.getCols()), VALUE_TYPE_INT32) {} + + ~SparseMatrixArg() {} + + void* getRowBuf() const { return row_.data(); } + + void* getColBuf() const { return col_.data(); } + + size_t nnz() const { return nnz_; } + + SparseDataFormat dataFormat() const { return format_; } + + SparseDataType dataType() const { return type_; } + +private: + BufferArg row_; + BufferArg col_; + size_t nnz_; + SparseDataFormat format_; + SparseDataType type_; +}; + +} // namespace paddle diff --git a/paddle/function/BufferArgTest.cpp b/paddle/function/BufferArgTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a9ee3ab079e339b86a9db8602c41e419df9dc544 --- /dev/null +++ b/paddle/function/BufferArgTest.cpp @@ -0,0 +1,90 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "BufferArg.h" +#include +#include "Function.h" +#include "paddle/math/MemoryHandle.h" + +namespace paddle { + +TEST(BufferTest, BufferArg) { + TensorShape shape({8, 10}); + CpuMemoryHandle memory(shape.getElements() * + sizeOfValuType(VALUE_TYPE_FLOAT)); + BufferArg buffer(memory.getBuf(), VALUE_TYPE_FLOAT, shape); + EXPECT_EQ(buffer.data(), memory.getBuf()); +} + +TEST(BufferTest, SequenceIdArg) { + TensorShape shape({10}); + CpuMemoryHandle memory(shape.getElements() * + sizeOfValuType(VALUE_TYPE_INT32)); + SequenceIdArg buffer(memory.getBuf(), shape); + EXPECT_EQ(buffer.data(), memory.getBuf()); + EXPECT_EQ(buffer.numSeqs(), 9); +} + +TEST(BufferTest, asArgument) { + MatrixPtr matrix = Matrix::create(100, 200); + VectorPtr vector = Vector::create(100, false); + CpuSparseMatrix sparse(200, 300, 50); + + // prepare arguments + BufferArgs argments; + argments.addArg(*matrix); + argments.addArg(*vector); + argments.addArg(sparse); + + // function + auto function = [=](const BufferArgs& inputs) { + EXPECT_EQ(inputs.size(), 3); + + // check inputs[0] + EXPECT_EQ(inputs[0].shape().ndims(), 2); + EXPECT_EQ(inputs[0].shape()[0], 100); + EXPECT_EQ(inputs[0].shape()[1], 200); + EXPECT_EQ(inputs[0].data(), matrix->getData()); + + EXPECT_EQ(inputs[0].matrix().getHeight(), + matrix->getHeight()); + EXPECT_EQ(inputs[0].matrix().getWidth(), + matrix->getWidth()); + EXPECT_EQ(inputs[0].matrix().getData(), matrix->getData()); + + // check inputs[1] + EXPECT_EQ(inputs[1].shape().ndims(), 1); + EXPECT_EQ(inputs[1].shape()[0], 100); + EXPECT_EQ(inputs[1].data(), vector->getData()); + CpuVector inVector = inputs[1].vector(); + EXPECT_EQ(inVector.getSize(), vector->getSize()); + EXPECT_EQ(inVector.getData(), vector->getData()); + + // check inputs[2] + EXPECT_EQ(inputs[2].shape().ndims(), 2); + EXPECT_EQ(inputs[2].shape()[0], 200); + EXPECT_EQ(inputs[2].shape()[1], 300); + EXPECT_EQ(inputs[2].data(), sparse.getData()); + // CHECK_EQ(inputs[2].sparse().nnz(), 50); + // CHECK_EQ(inputs[2].sparse().dataFormat(), SPARSE_CSR_FORMAT); + // CHECK_EQ(inputs[2].sparse().dataType(), SPARSE_FLOAT_VALUE); + EXPECT_EQ(inputs[2].sparse().getRowBuf(), sparse.getRows()); + EXPECT_EQ(inputs[2].sparse().getColBuf(), sparse.getCols()); + }; + + // call function + function(argments); +} + +} // namespace paddle diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index de85eeca821742e1d39d5ce26f873238d4359cba..75a2acc55ec3d33687f96d2b0398e52b69e8680d 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -3,6 +3,7 @@ file(GLOB cpp_files . *Op.cpp) list(APPEND h_files Function.h) list(APPEND cpp_files Function.cpp) +list(APPEND cpp_files BufferArg.cpp) if(WITH_GPU) file(GLOB cu_files . *OpGpu.cu) @@ -18,8 +19,12 @@ if(WITH_TESTING) # TODO: # file(GLOB test_files . 
*OpTest.cpp) # add_executable(${test_bin} EXCLUDE_FROM_ALL ${test_files}) - add_simple_unittest(CrossMapNormalOpTest) - add_simple_unittest(ContextProjectionOpTest) + # add_simple_unittest(CrossMapNormalOpTest) + add_simple_unittest(TensorShapeTest) + add_simple_unittest(TensorTypeTest) + add_simple_unittest(BufferArgTest) + add_simple_unittest(FunctionTest) + # add_simple_unittest(ContextProjectionOpTest) endif() endif() diff --git a/paddle/function/ContextProjectionOp.cpp b/paddle/function/ContextProjectionOp.cpp index 07907fc1ba7973c728c3a882e4be6b1a7ef7a97a..cb448562ebb37022f727ee65024f06f69d63e9cb 100644 --- a/paddle/function/ContextProjectionOp.cpp +++ b/paddle/function/ContextProjectionOp.cpp @@ -19,17 +19,15 @@ limitations under the License. */ namespace paddle { template <> -void ContextProjectionForward(CpuMatrix* out_mat, - const CpuMatrix* input_mat, - const CpuMatrix* weight_mat, +void ContextProjectionForward(CpuMatrix& out_mat, + const CpuMatrix& input_mat, + const CpuMatrix& weight_mat, const CpuIVector& seq_vec, size_t context_length, int context_start, size_t begin_pad) { const int* starts = seq_vec.getData(); const size_t num_sequences = seq_vec.getSize() - 1; - auto w_mat = const_cast(weight_mat); - auto in_mat = const_cast(input_mat); for (size_t i = 0; i < num_sequences; ++i) { for (size_t j = 0; j < context_length; ++j) { int begin = starts[i] + context_start + j; @@ -39,10 +37,11 @@ void ContextProjectionForward(CpuMatrix* out_mat, if (begin < starts[i]) { int64_t pad_size = std::min(starts[i] - begin, starts[i + 1] - starts[i]); - MatrixPtr mat = out_mat->subMatrix(starts[i], pad_size); - if (w_mat) { - MatrixPtr sub = w_mat->subMatrix(j, pad_size); - mat->addAtOffset(*sub, j * in_mat->getWidth()); + MatrixPtr mat = out_mat.subMatrix(starts[i], pad_size); + if (weight_mat) { + MatrixPtr sub = + const_cast(weight_mat).subMatrix(j, pad_size); + mat->addAtOffset(*sub, j * input_mat.getWidth()); } dst_begin = starts[i] + pad_size; begin = starts[i]; @@ -50,19 +49,22 @@ void ContextProjectionForward(CpuMatrix* out_mat, if (end > starts[i + 1]) { int64_t pad_size = std::min(end - starts[i + 1], starts[i + 1] - starts[i]); - MatrixPtr mat = out_mat->subMatrix(starts[i + 1] - pad_size, pad_size); - if (w_mat) { - MatrixPtr sub = w_mat->subMatrix( - begin_pad + context_start + j - pad_size, pad_size); - mat->addAtOffset(*sub, j * in_mat->getWidth()); + MatrixPtr mat = out_mat.subMatrix(starts[i + 1] - pad_size, pad_size); + if (weight_mat) { + MatrixPtr sub = + const_cast(weight_mat) + .subMatrix(begin_pad + context_start + j - pad_size, + pad_size); + mat->addAtOffset(*sub, j * input_mat.getWidth()); } dst_end = starts[i + 1] - pad_size; end = starts[i + 1]; } if (end <= begin) continue; - MatrixPtr src = in_mat->subMatrix(begin, end - begin); - MatrixPtr dst = out_mat->subMatrix(dst_begin, dst_end - dst_begin); - dst->addAtOffset(*src, j * in_mat->getWidth()); + MatrixPtr src = + const_cast(input_mat).subMatrix(begin, end - begin); + MatrixPtr dst = out_mat.subMatrix(dst_begin, dst_end - dst_begin); + dst->addAtOffset(*src, j * input_mat.getWidth()); } } } @@ -82,40 +84,32 @@ public: begin_pad_ = config.get("begin_pad"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { - CHECK_EQ(3, static_cast(inputs.size())); - CHECK_EQ(1, static_cast(outputs.size())); - CHECK_EQ(0, static_cast(inouts.size())); + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)3, inputs.size()); + 
CHECK_EQ((size_t)1, outputs.size()); - CHECK(outputs[0].getData() && inputs[0].getData() && inputs[2].getData()); - CHECK_EQ(static_cast(outputs[0].dims_.size()), 2); - CHECK_EQ(static_cast(inputs[0].dims_.size()), 2); - CHECK_EQ(static_cast(inputs[1].dims_.size()), 2); - CHECK_EQ(static_cast(inputs[2].dims_.size()), 1); + CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data()); + CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[0].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[2].shape().ndims(), (size_t)1); /// dim of output = dim of input * context_length - CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_); + CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_); /// dim of input == dim of weight - CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]); + CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]); /// input and output has the same batch_size - CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]); - - auto out_mat = std::make_shared::type>( - outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]); - const auto in_mat = std::make_shared::type>( - inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]); - const auto w_mat = - !inputs[1].getData() - ? nullptr - : std::make_shared::type>( - inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]); - typename SequenceT::type seq_vec( - inputs[2].dims_[0], reinterpret_cast(inputs[2].getData())); - - ContextProjectionForward(out_mat.get(), - in_mat.get(), - w_mat.get(), + CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]); + + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + auto out_mat = outputs[0].matrix(); + auto in_mat = inputs[0].matrix(); + auto w_mat = !inputs[1].data() + ? typename Tensor::Matrix(nullptr, 0, 0) + : inputs[1].matrix(); + auto seq_vec = inputs[2].vector(); + ContextProjectionForward(out_mat, + in_mat, + w_mat, seq_vec, context_length_, context_start_, @@ -129,18 +123,17 @@ private: }; template <> -void ContextProjectionBackward(CpuMatrix* out_grad_mat, - CpuMatrix* in_grad_mat, - CpuMatrix* w_grad_mat, +void ContextProjectionBackward(CpuMatrix& out_grad_mat, + CpuMatrix& in_grad_mat, + CpuMatrix& w_grad_mat, const CpuIVector& seq_vec, size_t context_length, int context_start, size_t begin_pad, bool is_padding, size_t total_pad) { - CHECK(out_grad_mat); - size_t input_dim = in_grad_mat ? in_grad_mat->getWidth() - : w_grad_mat ? w_grad_mat->getWidth() : 0; + size_t input_dim = in_grad_mat ? in_grad_mat.getWidth() + : w_grad_mat ? 
w_grad_mat.getWidth() : 0; const int* starts = seq_vec.getData(); size_t num_sequences = seq_vec.getSize() - 1; for (size_t i = 0; i < num_sequences; ++i) { @@ -153,8 +146,8 @@ void ContextProjectionBackward(CpuMatrix* out_grad_mat, int64_t pad_size = std::min(starts[i] - begin, starts[i + 1] - starts[i]); if (is_padding && w_grad_mat) { - MatrixPtr mat = out_grad_mat->subMatrix(starts[i], pad_size); - MatrixPtr sub = w_grad_mat->subMatrix(j, pad_size); + MatrixPtr mat = out_grad_mat.subMatrix(starts[i], pad_size); + MatrixPtr sub = w_grad_mat.subMatrix(j, pad_size); sub->addAtOffset(*mat, j * input_dim); } dst_begin = starts[i] + pad_size; @@ -165,8 +158,8 @@ void ContextProjectionBackward(CpuMatrix* out_grad_mat, std::min(end - starts[i + 1], starts[i + 1] - starts[i]); if (is_padding && w_grad_mat) { MatrixPtr mat = - out_grad_mat->subMatrix(starts[i + 1] - pad_size, pad_size); - MatrixPtr sub = w_grad_mat->subMatrix( + out_grad_mat.subMatrix(starts[i + 1] - pad_size, pad_size); + MatrixPtr sub = w_grad_mat.subMatrix( begin_pad + context_start + j - pad_size, pad_size); sub->addAtOffset(*mat, j * input_dim); } @@ -175,8 +168,8 @@ void ContextProjectionBackward(CpuMatrix* out_grad_mat, } if (end <= begin) continue; if (!in_grad_mat) continue; - MatrixPtr src = in_grad_mat->subMatrix(begin, end - begin); - MatrixPtr dst = out_grad_mat->subMatrix(dst_begin, dst_end - dst_begin); + MatrixPtr src = in_grad_mat.subMatrix(begin, end - begin); + MatrixPtr dst = out_grad_mat.subMatrix(dst_begin, dst_end - dst_begin); src->addAtOffset(*dst, j * input_dim); } } @@ -199,44 +192,36 @@ public: total_pad_ = config.get("total_pad"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { - CHECK_EQ(3, static_cast(inputs.size())); - CHECK_EQ(1, static_cast(outputs.size())); - CHECK_EQ(0, static_cast(inouts.size())); + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)3, inputs.size()); + CHECK_EQ((size_t)1, outputs.size()); - CHECK(outputs[0].getData() && inputs[2].getData()); - CHECK_EQ(static_cast(outputs[0].dims_.size()), 2); - CHECK_EQ(static_cast(inputs[0].dims_.size()), 2); - CHECK_EQ(static_cast(inputs[1].dims_.size()), 2); - CHECK_EQ(static_cast(inputs[2].dims_.size()), 1); + CHECK(outputs[0].data() && inputs[2].data()); + CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[0].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[2].shape().ndims(), (size_t)1); /// dim of input == dim of weight - CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]); + CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]); /// input and output has the same batch_size - CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]); + CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]); /// dim of output = dim of input * context_length - CHECK_EQ(outputs[0].dims_[1], inputs[0].dims_[1] * context_length_); + CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_); - auto out_grad_mat = std::make_shared::type>( - outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]); - auto in_grad_mat = - !inputs[0].getData() - ? nullptr - : std::make_shared::type>( - inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]); - auto w_grad_mat = - !inputs[1].getData() - ? 
nullptr - : std::make_shared::type>( - inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]); - typename SequenceT::type seq_vec( - inputs[2].dims_[0], reinterpret_cast(inputs[2].getData())); + CHECK_EQ(outputs[0].getArgType(), ADD_TO); - ContextProjectionBackward(out_grad_mat.get(), - in_grad_mat ? in_grad_mat.get() : nullptr, - w_grad_mat ? w_grad_mat.get() : nullptr, + auto out_grad_mat = outputs[0].matrix(); + auto in_grad_mat = + !inputs[0].data() ? typename Tensor::Matrix(nullptr, 0, 0) + : inputs[0].matrix(); + auto w_grad_mat = !inputs[1].data() + ? typename Tensor::Matrix(nullptr, 0, 0) + : inputs[1].matrix(); + auto seq_vec = inputs[2].vector(); + ContextProjectionBackward(out_grad_mat, + in_grad_mat, + w_grad_mat, seq_vec, context_length_, context_start_, @@ -253,6 +238,7 @@ private: size_t total_pad_; }; +#if 0 /** * \param inputs[0] input grad. * \param inputs[1] input sequence. @@ -349,6 +335,7 @@ private: size_t begin_pad_; size_t total_pad_; }; +#endif REGISTER_TYPED_FUNC(ContextProjectionForward, CPU, @@ -363,6 +350,7 @@ REGISTER_TYPED_FUNC(ContextProjectionForward, REGISTER_TYPED_FUNC(ContextProjectionBackward, GPU, ContextProjectionBackwardFunc); +#if 0 REGISTER_TYPED_FUNC(ContextProjectionBackwardData, GPU, ContextProjectionBackwardDataFunc); @@ -370,4 +358,5 @@ REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight, GPU, ContextProjectionBackwardWeightFunc); #endif +#endif } // namespace paddle diff --git a/paddle/function/ContextProjectionOp.h b/paddle/function/ContextProjectionOp.h index 93eb050fde35f474750f3c2efa72b7471f654b75..a558df5e072f2f4dcc5c45afa385b3cf88872d26 100644 --- a/paddle/function/ContextProjectionOp.h +++ b/paddle/function/ContextProjectionOp.h @@ -31,14 +31,15 @@ namespace paddle { * \param[in] is_padding whether padding 0 or not. * */ -template -void ContextProjectionForward(typename MatrixT::type* output, - const typename MatrixT::type* input, - const typename MatrixT::type* weight, - const typename SequenceT::type& sequence, - size_t context_length, - int context_start, - size_t begin_pad); +template +void ContextProjectionForward( + typename Tensor::Matrix& output, + const typename Tensor::Matrix& input, + const typename Tensor::Matrix& weight, + const typename Tensor::Vector& sequence, + size_t context_length, + int context_start, + size_t begin_pad); /** * \brief Context Projection Backward. @@ -53,30 +54,31 @@ void ContextProjectionForward(typename MatrixT::type* output, * \param[in] is_padding whether padding 0 or not. 
* */ -template -void ContextProjectionBackward(typename MatrixT::type* out_grad, - typename MatrixT::type* in_grad, - typename MatrixT::type* w_grad, - const typename SequenceT::type& seq_vec, - size_t context_length, - int context_start, - size_t begin_pad, - bool is_padding, - size_t total_pad); +template +void ContextProjectionBackward( + typename Tensor::Matrix& out_grad, + typename Tensor::Matrix& in_grad, + typename Tensor::Matrix& w_grad, + const typename Tensor::Vector& seq_vec, + size_t context_length, + int context_start, + size_t begin_pad, + bool is_padding, + size_t total_pad); -template +template void ContextProjectionBackwardData( - typename MatrixT::type* out_grad, - typename MatrixT::type* in_grad, - const typename SequenceT::type& sequence, + typename Tensor::Matrix& out_grad, + typename Tensor::Matrix& in_grad, + const typename Tensor::Vector& sequence, size_t context_length, int context_start); -template +template void ContextProjectionBackwardWeight( - typename MatrixT::type* out_grad, - typename MatrixT::type* w_grad, - const typename SequenceT::type& seq_vec, + typename Tensor::Matrix& out_grad, + typename Tensor::Matrix& w_grad, + const typename Tensor::Vector& seq_vec, size_t context_length, int context_start, size_t total_pad, diff --git a/paddle/function/ContextProjectionOpGpu.cu b/paddle/function/ContextProjectionOpGpu.cu index 1ec7058f96c8200728e5add051d5fa6a77a97e36..6a4a01a6510416fc1f945305203f55ece7a28f11 100644 --- a/paddle/function/ContextProjectionOpGpu.cu +++ b/paddle/function/ContextProjectionOpGpu.cu @@ -120,20 +120,19 @@ void hl_context_projection_forward(const real* input, } template <> -void ContextProjectionForward(GpuMatrix* output, - const GpuMatrix* input, - const GpuMatrix* weight, +void ContextProjectionForward(GpuMatrix& output, + const GpuMatrix& input, + const GpuMatrix& weight, const GpuIVector& sequence, size_t context_length, int context_start, size_t begin_pad) { - CHECK(input && output); - hl_context_projection_forward(input->getData(), + hl_context_projection_forward(input.getData(), sequence.getData(), - weight ? weight->getData() : nullptr, - output->getData(), + weight ? 
weight.getData() : nullptr, + output.getData(), sequence.getSize() - 1, - input->getWidth(), + input.getWidth(), context_length, context_start, begin_pad); @@ -217,17 +216,16 @@ void hl_context_projection_backward_data(real* out_grad, } template <> -void ContextProjectionBackwardData(GpuMatrix* out_grad, - GpuMatrix* in_grad, +void ContextProjectionBackwardData(GpuMatrix& out_grad, + GpuMatrix& in_grad, const GpuIVector& sequence, size_t context_length, int context_start) { - CHECK(in_grad && out_grad); - hl_context_projection_backward_data(out_grad->getData(), + hl_context_projection_backward_data(out_grad.getData(), sequence.getData(), - in_grad->getData(), + in_grad.getData(), sequence.getSize() - 1, - in_grad->getWidth(), + in_grad.getWidth(), context_length, context_start); } @@ -348,19 +346,18 @@ void hl_context_projection_backward_weight(real* out_grad, template <> void ContextProjectionBackwardWeight( - GpuMatrix* out_grad, - GpuMatrix* w_grad, + GpuMatrix& out_grad, + GpuMatrix& w_grad, const GpuIVector& seq_vec, size_t context_length, int context_start, size_t total_pad, size_t begin_pad) { - CHECK(out_grad && w_grad); - hl_context_projection_backward_weight(out_grad->getData(), + hl_context_projection_backward_weight(out_grad.getData(), seq_vec.getData(), - w_grad->getData(), + w_grad.getData(), seq_vec.getSize() - 1, - w_grad->getWidth(), + w_grad.getWidth(), total_pad, context_length, context_start, @@ -368,16 +365,15 @@ void ContextProjectionBackwardWeight( } template <> -void ContextProjectionBackward(GpuMatrix* out_grad, - GpuMatrix* in_grad, - GpuMatrix* w_grad, +void ContextProjectionBackward(GpuMatrix& out_grad, + GpuMatrix& in_grad, + GpuMatrix& w_grad, const GpuIVector& sequence, size_t context_length, int context_start, size_t begin_pad, bool is_padding, size_t total_pad) { - CHECK(out_grad); if (in_grad) { ContextProjectionBackwardData( out_grad, diff --git a/paddle/function/CrossMapNormalOp.cpp b/paddle/function/CrossMapNormalOp.cpp index 96a7a30eebbf0f01fa89ea91110ddb826fd2f64b..92980c503fdaaaa9ac600070197dba6ba4bfb7a4 100644 --- a/paddle/function/CrossMapNormalOp.cpp +++ b/paddle/function/CrossMapNormalOp.cpp @@ -112,6 +112,8 @@ void CrossMapNormalGrad(real* inputsGrad, } /** + * \brief {o_0, o_1} = calc(i_0) + * * \param inputs[0] input value. * \param outputs[0] output value. * \param outputs[1] denoms. 
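// Illustrative caller-side sketch, not part of this patch: invoking the reworked
// CrossMapNormalFunc through the new calc(inputs, outputs) interface. The sizes and the
// crossMapNormal pointer (a FunctionBase* assumed to be created and init()'d elsewhere)
// are hypothetical; the addArg/ASSIGN_TO pattern mirrors the NormProjectionLayer changes
// later in this patch.
const size_t batch = 2, channels = 4, height = 8, width = 8;
CpuMatrix in(batch, channels * height * width);
CpuMatrix out(batch, channels * height * width);
CpuMatrix denoms(batch, channels * height * width);
TensorShape shape({batch, channels, height, width});

BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(in, shape);                  // i_0: input value
outputs.addArg(out, shape, ASSIGN_TO);     // o_0: output value, assign mode
outputs.addArg(denoms, shape, ASSIGN_TO);  // o_1: denoms, assign mode
crossMapNormal->calc(inputs, outputs);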
@@ -125,27 +127,24 @@ public: pow_ = config.get("pow"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { - CHECK_EQ(1, static_cast(inputs.size())); - CHECK_EQ(2, static_cast(outputs.size())); - CHECK_EQ(0, static_cast(inouts.size())); - - CHECK_EQ(static_cast(inputs[0].dims_.size()), 4); - for (size_t i = 0; i < inputs[0].dims_.size(); i++) { - CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], outputs[1].dims_[i]); - } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)1, inputs.size()); + CHECK_EQ((size_t)2, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK(inputs[0].shape() == outputs[0].shape()); + CHECK(inputs[0].shape() == outputs[1].shape()); - size_t samples = inputs[0].dims_[0]; - size_t channels = inputs[0].dims_[1]; - size_t height = inputs[0].dims_[2]; - size_t width = inputs[0].dims_[3]; + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO); + size_t samples = inputs[0].shape()[0]; + size_t channels = inputs[0].shape()[1]; + size_t height = inputs[0].shape()[2]; + size_t width = inputs[0].shape()[3]; - CrossMapNormal(outputs[0].getData(), - outputs[1].getData(), - inputs[0].getData(), + CrossMapNormal(outputs[0].data(), + outputs[1].data(), + inputs[0].data(), samples, channels, height, @@ -162,6 +161,8 @@ private: }; /** + * \brief {o_0} = calc(i_0, i_1, i_2, i_3) + * * \param inputs[0] input value. * \param inputs[1] output value. * \param inputs[2] output grad. @@ -177,31 +178,29 @@ public: pow_ = config.get("pow"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { - CHECK_EQ(4, static_cast(inputs.size())); - CHECK_EQ(1, static_cast(outputs.size())); - CHECK_EQ(0, static_cast(inouts.size())); - - CHECK_EQ(static_cast(inputs[0].dims_.size()), 4); - for (size_t i = 0; i < inputs[0].dims_.size(); i++) { - CHECK_EQ(inputs[0].dims_[i], inputs[1].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], inputs[2].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], inputs[3].dims_[i]); - CHECK_EQ(inputs[0].dims_[i], outputs[0].dims_[i]); - } - - size_t samples = inputs[0].dims_[0]; - size_t channels = inputs[0].dims_[1]; - size_t height = inputs[0].dims_[2]; - size_t width = inputs[0].dims_[3]; - - CrossMapNormalGrad(outputs[0].getData(), - inputs[0].getData(), - inputs[1].getData(), - inputs[2].getData(), - inputs[3].getData(), + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ((size_t)4, inputs.size()); + CHECK_EQ((size_t)1, outputs.size()); + + CHECK_EQ(inputs[0].shape().ndims(), (size_t)4); + CHECK(inputs[0].shape() == inputs[1].shape()); + CHECK(inputs[0].shape() == inputs[2].shape()); + CHECK(inputs[0].shape() == inputs[3].shape()); + CHECK(inputs[0].shape() == outputs[0].shape()); + + // TODO(hedaoyuan): need support ASSIGN_TO mode. 
+ CHECK_EQ(outputs[0].getArgType(), ADD_TO); + + size_t samples = inputs[0].shape()[0]; + size_t channels = inputs[0].shape()[1]; + size_t height = inputs[0].shape()[2]; + size_t width = inputs[0].shape()[3]; + + CrossMapNormalGrad(outputs[0].data(), + inputs[0].data(), + inputs[1].data(), + inputs[2].data(), + inputs[3].data(), samples, channels, height, diff --git a/paddle/function/Function.cpp b/paddle/function/Function.cpp index 614e76b8ac0c9a9145a27f5b532ea63bef7f90f0..dbe3a4e9f608df6333a5637f2d962a555b04d7c3 100644 --- a/paddle/function/Function.cpp +++ b/paddle/function/Function.cpp @@ -76,6 +76,20 @@ FuncConfig& FuncConfig::set(const std::string& key, bool v) { return *this; } +void BufferArgs::addArg(const Matrix& arg, + const TensorShape& shape, + ArgType argType) { + args_.push_back(std::make_shared(arg, shape, argType)); +} + +void BufferArgs::addArg(const CpuSparseMatrix& arg, ArgType argType) { + args_.push_back(std::make_shared(arg, argType)); +} + +void BufferArgs::addArg(const GpuSparseMatrix& arg, ArgType argType) { + args_.push_back(std::make_shared(arg, argType)); +} + ClassRegistrar FunctionBase::funcRegistrar_; } // namespace paddle diff --git a/paddle/function/Function.h b/paddle/function/Function.h index 9e8cbb8e48c30e80c5057fc53c050b67d3957188..249f8f9cfad58bf596e8cdce9188409b5690f969 100644 --- a/paddle/function/Function.h +++ b/paddle/function/Function.h @@ -16,57 +16,17 @@ limitations under the License. */ #include #include +#include "BufferArg.h" #include "paddle/math/Matrix.h" #include "paddle/utils/ClassRegistrar.h" namespace paddle { -enum DeviceType { - DEVICE_TYPE_UNSPECIFIED = 0, - DEVICE_TYPE_CPU = 1, - DEVICE_TYPE_GPU = 2, -}; - -template -struct MatrixT; - -template <> -struct MatrixT { - using type = CpuMatrix; -}; - -template <> -struct MatrixT { - using type = GpuMatrix; -}; - -template -struct SequenceT; - -template <> -struct SequenceT { - using type = CpuIVector; -}; - -template <> -struct SequenceT { - using type = GpuIVector; -}; - -typedef std::vector Dims; - -class Tensor { -public: - Tensor(real* data, const Dims& dim) : buf_(data), dims_(dim) {} - - real* getData() const { return buf_; } - - real* buf_; - Dims dims_; -}; - -typedef std::vector Arguments; - +/** + * Function Configuration. + * The argument type of Function::init. + * Follow-up will consider moving this data structure to Proto inside. + */ class FuncConfig { public: union value { @@ -86,15 +46,70 @@ protected: std::map valueMap_; }; +/** + * Argument type for Function::calc(). + * A BufferArgs contains a set of BufferArg, + * because Function can have multiple inputs and outputs. + */ +class BufferArgs { +public: + BufferArgs() {} + size_t size() const { return args_.size(); } + + // add argument into BufferArgs + // Tensor can be Matrix, Vector, IVector. + // For inputs, do not need argType. + // For outputs, the argType needs to be specified as ASSIGN_TO or ADD_TO. + template + void addArg(const Tensor& arg, ArgType argType = UNSPECIFIED) { + args_.push_back(std::make_shared(arg, argType)); + } + + // Add arg into BufferArgs and reshape the arg. + // + // For example, arg represents an image buffer, + // but Matrix can only represent a two-dimensional Tensor. + // So need an extra argument to describe the shape of the image buffer. 
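// Illustrative sketch, not part of this patch, of the overload declared just below:
// the same 2-D Matrix added as a plain 2-D tensor versus reshaped into a 4-D image
// buffer, and what the receiving Function observes. The sizes are assumed.
CpuMatrix img(2, 3 * 8 * 8);                  // two flattened 3x8x8 images
BufferArgs args;
args.addArg(img);                             // template overload: shape is {2, 192}
args.addArg(img, TensorShape({2, 3, 8, 8}));  // this overload: shape is {2, 3, 8, 8}
// inside Function::calc the callee reads the logical shape:
CHECK_EQ(args[1].shape().ndims(), (size_t)4);
CHECK_EQ(args[1].shape()[1], (size_t)3);      // channels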
+ void addArg(const Matrix& arg, + const TensorShape& shape, + ArgType argType = UNSPECIFIED); + + void addArg(const CpuSparseMatrix& arg, ArgType argType = UNSPECIFIED); + void addArg(const GpuSparseMatrix& arg, ArgType argType = UNSPECIFIED); + + // get argument + const BufferArg& operator[](size_t num) const { + CHECK_LT(num, args_.size()); + return *args_[num]; + } + +private: + std::vector args_; +}; + +/** + * \brief Base class for Function. + * The basic Function implementation requires override init and calc interfaces. + * + * Function inputs are readonly, Function outputs have two modes: ASSIGN_TO + * and ADD_TO. + * If output.getArgType() == ASSIGN_TO, this is assign mode, and the calculation + * result of Function assigned to the output BufferArg. + * If output.getArgType() == ADD_TO, this is add mode, and the calculation + * result of Function need added to the output BufferArg. + * + * For example: + * ASSIGN_TO: output = Function(inputs) + * ADD_TO: output += Function(inputs) + * If Function has more than one output, each output can have different modes. + */ class FunctionBase { public: virtual ~FunctionBase() {} virtual void init(const FuncConfig& config) {} - virtual void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) {} + virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} static ClassRegistrar funcRegistrar_; }; diff --git a/paddle/function/FunctionTest.cpp b/paddle/function/FunctionTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ce908320a6f6f764e8fdacc96432aca78d7b2df --- /dev/null +++ b/paddle/function/FunctionTest.cpp @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "Function.h" +#include + +namespace paddle { + +template +void FunctionApi(typename Tensor::Matrix& output, + const typename Tensor::Matrix& input); + +template <> +void FunctionApi(CpuMatrix& output, const CpuMatrix& input) { + EXPECT_EQ(output.getHeight(), 100); + EXPECT_EQ(output.getWidth(), 200); +} + +template <> +void FunctionApi(GpuMatrix& output, const GpuMatrix& input) { + EXPECT_EQ(output.getHeight(), 10); + EXPECT_EQ(output.getWidth(), 20); +} + +template +void Function(const BufferArgs& arguments) { + const auto input = arguments[0].matrix(); + auto output = arguments[1].matrix(); + FunctionApi(output, input); +} + +TEST(Function, BufferArgs) { + CpuMatrix cpuInput = CpuMatrix(100, 200); + CpuMatrix cpuOutput = CpuMatrix(100, 200); + BufferArgs cpuArgments; + cpuArgments.addArg(cpuInput); + cpuArgments.addArg(cpuOutput); + Function(cpuArgments); + + GpuMatrix gpuInput = GpuMatrix(10, 20); + GpuMatrix gpuOutput = GpuMatrix(10, 20); + BufferArgs gpuArgments; + gpuArgments.addArg(gpuInput); + gpuArgments.addArg(gpuOutput); + Function(gpuArgments); +} + +} // namespace paddle diff --git a/paddle/function/TensorShape.h b/paddle/function/TensorShape.h new file mode 100644 index 0000000000000000000000000000000000000000..e491e3f1d6b26e14a5273b3b5a38aec941f5a9e5 --- /dev/null +++ b/paddle/function/TensorShape.h @@ -0,0 +1,97 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +namespace paddle { + +/** + * TensorShape used to represent shape of normal tensor. + */ +class TensorShape { +public: + TensorShape() : ndims_(0), nelements_(0) { initDims(0); } + + TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); }; + + TensorShape(std::initializer_list dims) { + ndims_ = dims.size(); + initDims(ndims_); + dims_.assign(dims); + numElements(); + }; + + TensorShape(const TensorShape& t) + : ndims_(t.ndims_), nelements_(t.nelements_) { + initDims(ndims_); + dims_.assign(t.dims_.begin(), t.dims_.end()); + }; + + // get the size of specified dimension + size_t operator[](size_t dim) const { + CHECK_GE(dim, (size_t)0); + CHECK_LT(dim, ndims_); + return dims_[dim]; + } + + // set the size of specified dimension + void setDim(size_t dim, size_t size) { + CHECK_GE(dim, (size_t)0); + CHECK_LT(dim, ndims_); + dims_[dim] = size; + numElements(); + } + + // number of dimensions of the tensor + size_t ndims() const { return ndims_; } + + size_t getElements() const { return nelements_; } + + bool operator==(const TensorShape& t) const { + if (ndims() != t.ndims()) return false; + for (size_t i = 0; i < ndims(); i++) { + if (dims_[i] != t.dims_[i]) return false; + } + + return true; + } + + bool operator!=(const TensorShape& t) const { return !(*this == t); } + +private: + // compute number of elements + void numElements() { + nelements_ = 1; + for (size_t n = 0; n < ndims_; n++) { + nelements_ *= dims_[n]; + } + } + + // init dims_ + void initDims(size_t ndims) { + size_t count = ndims < 4 ? 
4 : ndims; + dims_.assign(count, 1); + } + + // number of dimensions + // ndims_ may be not equeal dims_.size() + size_t ndims_; + // number of elements + size_t nelements_; + std::vector dims_; +}; + +} // namespace paddle diff --git a/paddle/function/TensorShapeTest.cpp b/paddle/function/TensorShapeTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45a2e106e7fc3f0e9e57cf8c2bb549d747f4f49b --- /dev/null +++ b/paddle/function/TensorShapeTest.cpp @@ -0,0 +1,53 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "TensorShape.h" +#include + +namespace paddle { + +TEST(TensorShape, Constructor) { + TensorShape t1; + EXPECT_EQ(t1.ndims(), 0); + EXPECT_EQ(t1.getElements(), 0); + + TensorShape t2(3); + EXPECT_EQ(t2.ndims(), 3); + EXPECT_EQ(t2.getElements(), 1); + + TensorShape t3({8, 10}); + EXPECT_EQ(t3.ndims(), 2); + EXPECT_EQ(t3.getElements(), 80); + + TensorShape t4(t3); + EXPECT_EQ(t4.ndims(), t3.ndims()); + EXPECT_EQ(t4.getElements(), t3.getElements()); + + TensorShape t5({1, 2, 3, 4, 5}); + EXPECT_EQ(t5.ndims(), 5); + EXPECT_EQ(t5.getElements(), 120); +} + +TEST(TensorShape, GetAndSet) { + TensorShape t({1, 2, 3}); + EXPECT_EQ(t.ndims(), 3); + EXPECT_EQ(t.getElements(), 6); + + EXPECT_EQ(t[1], 2); + t.setDim(1, 100); + EXPECT_EQ(t.getElements(), 300); + EXPECT_EQ(t[1], 100); +} + +} // namespace paddle diff --git a/paddle/function/TensorType.h b/paddle/function/TensorType.h new file mode 100644 index 0000000000000000000000000000000000000000..98942cff9e2ea44e78727d66a059ab8cf5f0ef7c --- /dev/null +++ b/paddle/function/TensorType.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/math/Matrix.h" + +namespace paddle { + +enum ValueType { + VALUE_TYPE_INT32 = 0, + VALUE_TYPE_FLOAT = 1, + VALUE_TYPE_DOUBLE = 2, + VALUE_TYPE_BYTE = 3 +}; + +enum DeviceType { + DEVICE_TYPE_UNSPECIFIED = 0, + DEVICE_TYPE_CPU = 1, + DEVICE_TYPE_GPU = 2 +}; + +inline int sizeOfValuType(ValueType valueType) { + if (valueType == VALUE_TYPE_INT32) { + return 4; + } else if (valueType == VALUE_TYPE_FLOAT) { + return 4; + } else if (valueType == VALUE_TYPE_DOUBLE) { + return 8; + } else { + LOG(FATAL) << "Unknown type: " << valueType; + return 0; + } +} + +template +struct DataType; + +template <> +struct DataType { + static const ValueType value = VALUE_TYPE_FLOAT; +}; + +template <> +struct DataType { + static const ValueType value = VALUE_TYPE_DOUBLE; +}; + +template <> +struct DataType { + static const ValueType value = VALUE_TYPE_INT32; +}; + +namespace detail { + +template +struct MatrixT; + +template <> +struct MatrixT { + using type = CpuMatrix; +}; + +template <> +struct MatrixT { + using type = GpuMatrix; +}; + +template <> +struct MatrixT { + using type = void; // Not implemented +}; + +template <> +struct MatrixT { + using type = void; // Not implemented +}; + +template +struct VectorT; + +template <> +struct VectorT { + using type = CpuVector; +}; + +template <> +struct VectorT { + using type = GpuVector; +}; + +template <> +struct VectorT { + using type = CpuIVector; +}; + +template <> +struct VectorT { + using type = GpuIVector; +}; + +} // namespace detail + +template +struct Tensor { + typedef typename detail::MatrixT::type Matrix; + typedef typename detail::VectorT::type Vector; +}; + +} // namespace paddle diff --git a/paddle/function/TensorTypeTest.cpp b/paddle/function/TensorTypeTest.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e50e46f3e99111731d9587f3e4ddfd4b26ae27e9 --- /dev/null +++ b/paddle/function/TensorTypeTest.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "TensorType.h" +#include <gtest/gtest.h> + +namespace paddle { + +TEST(TensorType, Matrix) { + Tensor<real, DEVICE_TYPE_CPU>::Matrix matrix(100, 200); + EXPECT_EQ(matrix.getHeight(), 100); + EXPECT_EQ(matrix.getWidth(), 200); + EXPECT_EQ(matrix.getElementCnt(), 100 * 200); + EXPECT_EQ(matrix.useGpu(), false); + + Tensor<real, DEVICE_TYPE_GPU>::Matrix testGpu(100, 200); + EXPECT_EQ(testGpu.useGpu(), true); +} + +TEST(TensorType, Vector) { + Tensor<real, DEVICE_TYPE_CPU>::Vector cpuVector(100); + Tensor<real, DEVICE_TYPE_GPU>::Vector gpuVector(100); + EXPECT_EQ(cpuVector.useGpu(), false); + EXPECT_EQ(gpuVector.useGpu(), true); + EXPECT_EQ(cpuVector.getSize(), 100); + EXPECT_EQ(gpuVector.getSize(), 100); + + Tensor<int, DEVICE_TYPE_CPU>::Vector cpuIVector(100); + Tensor<int, DEVICE_TYPE_GPU>::Vector gpuIVector(100); + EXPECT_EQ(cpuIVector.useGpu(), false); + EXPECT_EQ(gpuIVector.useGpu(), true); + EXPECT_EQ(cpuIVector.getSize(), 100); + EXPECT_EQ(gpuIVector.getSize(), 100); +} + +TEST(TensorType, EmptyMatrix) { + CpuMatrix empty(nullptr, 0, 0); + CpuMatrix nonEmpty(10, 10); + EXPECT_EQ(empty.isEmpty(), true); + EXPECT_EQ(nonEmpty.isEmpty(), false); + CHECK(nonEmpty); + auto function = [](const CpuMatrix& matrix) { + if (matrix) { + EXPECT_NE(matrix.getData(), nullptr); + } else { + EXPECT_EQ(matrix.getData(), nullptr); + } + }; + function(empty); + function(nonEmpty); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/ContextProjection.cpp b/paddle/gserver/layers/ContextProjection.cpp index ee4db219890a135d786c46827632d02d1db5b760..ebcc87cbf48a3c34a4e625e67f872fed69cdf44f 100644 --- a/paddle/gserver/layers/ContextProjection.cpp +++ b/paddle/gserver/layers/ContextProjection.cpp @@ -110,9 +110,8 @@ void ContextProjection::forward() { size_t input_dim = in_->value->getWidth(); size_t dim = out_->value->getWidth(); CHECK_EQ(dim, input_dim * config_.context_length()); - size_t batch_size = in_->value->getHeight(); - CHECK_EQ(static_cast<int>(forward_.size()), 1) - << "Only one forward function here"; + // size_t batch_size = in_->value->getHeight(); + CHECK_EQ(forward_.size(), (size_t)1) << "Only one forward function here"; REGISTER_TIMER_INFO("ContextProjectionForward", getName().c_str()); bool is_padding = config_.trainable_padding(); @@ -120,14 +119,16 @@ void ContextProjection::forward() { auto w_ptr = state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr; auto start_pos = in_->sequenceStartPositions; - forward_[0]->calc({Tensor(in_->value->getData(), Dims{batch_size, input_dim}), - Tensor(w_ptr ? w_ptr->getData() : nullptr, - Dims{w_ptr ? w_ptr->getHeight() : 0, input_dim}), - Tensor(reinterpret_cast<void*>( - const_cast<int*>(start_pos->getData(useGpu_))), - Dims{start_pos->getSize()})}, - {Tensor(out_->value->getData(), Dims{batch_size, dim})}, - {}); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*in_->value); + inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr, + w_ptr ? w_ptr->getHeight() : 0, + input_dim)); + inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_)); + outputs.addArg(*out_->value, ADD_TO); + forward_[0]->calc(inputs, outputs); if (state_ && config_.context_start() < 0) { CHECK_EQ(1, in_->getNumSequences()); @@ -162,15 +163,17 @@ void ContextProjection::backward(const UpdateCallback& callback) { bool is_padding = config_.trainable_padding(); auto start_pos = in_->sequenceStartPositions; auto w_ptr = is_padding ? weight_->getWGrad() : nullptr; - backward_[0]->calc({Tensor(in_->grad ? in_->grad->getData() : nullptr, - Dims{batch_size, input_dim}), - Tensor(w_ptr ? w_ptr->getData() : nullptr, - Dims{w_ptr ? 
w_ptr->getHeight() : 0, input_dim}), - Tensor(reinterpret_cast( - const_cast(start_pos->getData(useGpu_))), - Dims{start_pos->getSize()})}, - {Tensor(out_->grad->getData(), Dims{batch_size, dim})}, - {}); + + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(CpuMatrix( + in_->grad ? in_->grad->getData() : nullptr, batch_size, input_dim)); + inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr, + w_ptr ? w_ptr->getHeight() : 0, + input_dim)); + inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_)); + outputs.addArg(*out_->grad, ADD_TO); + backward_[0]->calc(inputs, outputs); if (config_.trainable_padding()) { weight_->getParameterPtr()->incUpdate(callback); diff --git a/paddle/gserver/layers/NormProjectionLayer.cpp b/paddle/gserver/layers/NormProjectionLayer.cpp index 262d757c67e105a8d65619eed91de65d34cfe35e..4331009de7e98d2326049e563e46a55a20366507 100644 --- a/paddle/gserver/layers/NormProjectionLayer.cpp +++ b/paddle/gserver/layers/NormProjectionLayer.cpp @@ -59,7 +59,6 @@ bool CMRProjectionNormLayer::init(const LayerMap& layerMap, void CMRProjectionNormLayer::forward(PassType passType) { Layer::forward(passType); - /* malloc memory for the output_ if necessary */ /* note: one sample correspond to one row */ MatrixPtr input = inputLayers_[0]->getOutputValue(); @@ -67,34 +66,36 @@ void CMRProjectionNormLayer::forward(PassType passType) { int size = getSize(); resetOutput(batchSize, size); - MatrixPtr outV = getOutputValue(); - Matrix::resizeOrCreate(denoms_, batchSize, size, /* trans */ false, useGpu_); - dims_ = {batchSize, channels_, imgSizeH_, imgSizeW_}; - forward_[0]->calc( - {Tensor(input->getData(), dims_)}, - {Tensor(outV->getData(), dims_), Tensor(denoms_->getData(), dims_)}, - {}); + shape_ = TensorShape({batchSize, channels_, imgSizeH_, imgSizeW_}); + + // prepare forward arguments + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), shape_); + outputs.addArg(*getOutputValue(), shape_, ASSIGN_TO); + outputs.addArg(*denoms_, shape_, ASSIGN_TO); + + forward_[0]->calc(inputs, outputs); } void CMRProjectionNormLayer::backward(const UpdateCallback& callback) { (void)callback; - if (NULL == inputLayers_[0]->getOutputGrad()) { + if (NULL == getInputGrad(0)) { return; } - /* Do derivation */ - MatrixPtr preOutGrad = inputLayers_[0]->getOutputGrad(); - MatrixPtr localGrad = getOutputGrad(); - MatrixPtr localOutV = getOutputValue(); - MatrixPtr preOutV = inputLayers_[0]->getOutputValue(); - - backward_[0]->calc({Tensor(preOutV->getData(), dims_), - Tensor(localOutV->getData(), dims_), - Tensor(localGrad->getData(), dims_), - Tensor(denoms_->getData(), dims_)}, - {Tensor(preOutGrad->getData(), dims_)}, - {}); + + // prepare backward arguments + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*getInputValue(0), shape_); + inputs.addArg(*getOutputValue(), shape_); + inputs.addArg(*getOutputGrad(), shape_); + inputs.addArg(*denoms_, shape_); + outputs.addArg(*getInputGrad(0), shape_, ADD_TO); + + backward_[0]->calc(inputs, outputs); } } // namespace paddle diff --git a/paddle/gserver/layers/NormProjectionLayer.h b/paddle/gserver/layers/NormProjectionLayer.h index 6b2c5dde0d74db4b292d5006d19ce54d3194017e..2c0d8a3a718c484508b2bf6d4e7861d54a1682bb 100644 --- a/paddle/gserver/layers/NormProjectionLayer.h +++ b/paddle/gserver/layers/NormProjectionLayer.h @@ -41,6 +41,6 @@ public: void backward(const UpdateCallback& callback = nullptr); protected: - Dims dims_; + TensorShape shape_; }; } // namespace paddle diff --git 
a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 90813a89969c2525f7029f1c2609bed116c910c4..3ae237bc7de895293c15eedc811cf8a2011a7c52 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1311,7 +1311,9 @@ void GpuMatrix::paramReluForward(Matrix& data, Matrix& W) { real* w = W.getData(); size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); - size_t partial_sum = numElements / (W.getHeight() * W.getWidth()); + size_t paraSize = W.getHeight() * W.getWidth(); + CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init + size_t partial_sum = numElements / paraSize; real* output = getData(); hl_param_relu_forward(output, input, w, numElements, numSamples, partial_sum); } @@ -1324,7 +1326,9 @@ void GpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) { real* wgrad = data_; size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); - size_t partial_sum = numElements / (this->getHeight() * this->getWidth()); + size_t paraSize = this->getHeight() * this->getWidth(); + CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init + size_t partial_sum = numElements / paraSize; hl_param_relu_backward_w( wgrad, ograd, input, numElements, numSamples, partial_sum); } @@ -1336,7 +1340,9 @@ void GpuMatrix::paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) { real* w = W.getData(); size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); - size_t partial_sum = numElements / (W.getHeight() * W.getWidth()); + size_t paraSize = W.getHeight() * W.getWidth(); + CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init + size_t partial_sum = numElements / paraSize; hl_param_relu_backward_diff( ograd, input, w, diff, numElements, numSamples, partial_sum); } @@ -3764,7 +3770,9 @@ void CpuMatrix::paramReluForward(Matrix& data, Matrix& W) { real* w = W.getData(); size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); - size_t partial_sum = numElements / (W.getHeight() * W.getWidth()); + size_t paraSize = W.getHeight() * W.getWidth(); + CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init + size_t partial_sum = numElements / paraSize; for (size_t n = 0, k = 0; n < numSamples; ++n) { for (size_t i = 0; i < numElements; ++i, ++k) { data_[k] = input[k] > 0 ? input[k] : input[k] * w[i / partial_sum]; @@ -3778,7 +3786,9 @@ void CpuMatrix::paramReluBackwardW(Matrix& oGrad, Matrix& data) { real* wgrad = data_; size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); - size_t partial_sum = numElements / (this->getHeight() * this->getWidth()); + size_t paraSize = this->getHeight() * this->getWidth(); + CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init + size_t partial_sum = numElements / paraSize; for (size_t n = 0, k = 0; n < numSamples; ++n) { for (size_t i = 0; i < numElements; ++i, ++k) { wgrad[i / partial_sum] += ograd[k] * (input[k] > 0 ? 
0 : input[k]); @@ -3793,7 +3803,9 @@ void CpuMatrix::paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) { real* w = W.getData(); size_t numElements = data.getWidth(); size_t numSamples = data.getHeight(); - size_t partial_sum = numElements / (W.getHeight() * W.getWidth()); + size_t paraSize = W.getHeight() * W.getWidth(); + CHECK(!(numElements % paraSize)); // this check from ParameterReluLayer::init + size_t partial_sum = numElements / paraSize; for (size_t n = 0, k = 0; n < numSamples; ++n) { for (size_t i = 0; i < numElements; ++i, ++k) { diff[k] += ograd[k] * (input[k] > 0 ? 1 : w[i / partial_sum]); diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index ceac0212d25a53ca77403b57aa66d2607ed41c5a..dd24f8821d49768354840e0381742218ab9a0204 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -1091,6 +1091,10 @@ public: TensorCpuApply(*this, expr); } } + + bool isEmpty() const { return data_ == nullptr; } + + explicit operator bool() const { return !isEmpty(); } }; inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) { diff --git a/paddle/math/tests/test_Matrix.cpp b/paddle/math/tests/test_Matrix.cpp index 6899769144dd89156b2ffdb644c47ef0025d624b..a4084bdf7c6953651bfd9714fd8a5c930f774fe6 100644 --- a/paddle/math/tests/test_Matrix.cpp +++ b/paddle/math/tests/test_Matrix.cpp @@ -224,10 +224,11 @@ void testParamReluBackwardW(int height, int width, int w_height, int w_width) { } TEST(Matrix, paramRelu) { - for (auto height : {10, 100}) { - for (auto width : {10, 100}) { + for (auto height : {10, 40, 100}) { + for (auto width : {10, 40, 100}) { for (auto w_height : {1, 2}) { for (auto w_width : {1, 2}) { + if (width % (w_height * w_width)) continue; testParamReluForward(height, width, w_height, w_width); testParamReluBackwardW(height, width, w_height, w_width); } diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 3a780d26c050ac5870824f2ef35c87edc61900a2..f0c49791d7e2a67220eafca3e1347f30958877a7 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -773,10 +773,11 @@ void testParamReluBackwardDiff(int height, } TEST(Matrix, paramReluBackwardDiff) { - for (auto height : {10, 100}) { - for (auto width : {10, 100}) { + for (auto height : {10, 40, 100}) { + for (auto width : {10, 40, 100}) { for (auto w_height : {1, 2}) { for (auto w_width : {1, 2}) { + if (width % (w_height * w_width)) continue; testParamReluBackwardDiff(height, width, w_height, w_width); } } diff --git a/paddle/scripts/travis/build_and_test.sh b/paddle/scripts/travis/build_and_test.sh index c0b3cc294a346609f3de5b0a3307c48b8867cd5d..57af47695cd182867e616ceaf3c0685d9a8fc3d7 100755 --- a/paddle/scripts/travis/build_and_test.sh +++ b/paddle/scripts/travis/build_and_test.sh @@ -6,14 +6,14 @@ if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then export PYTHONPATH=/opt/python/2.7.12/lib/python2.7/site-packages export PYTHONHOME=/opt/python/2.7.12 export PATH=/opt/python/2.7.12/bin:${PATH} - cmake .. -DON_TRAVIS=ON -DWITH_COVERAGE=ON -DCOVERALLS_UPLOAD=ON + cmake .. -DON_TRAVIS=ON -DWITH_COVERAGE=ON -DCOVERALLS_UPLOAD=ON ${EXTRA_CMAKE_OPTS} NRPOC=`nproc` make -j $NPROC make coveralls sudo make install elif [[ "$TRAVIS_OS_NAME" == "osx" ]]; then export PYTHONPATH=/usr/local/lib/python2.7/site-packages - cmake .. -DON_TRAVIS=ON + cmake .. 
-DON_TRAVIS=ON ${EXTRA_CMAKE_OPTS} NPROC=`sysctl -n hw.ncpu` make -j $NPROC fi diff --git a/paddle/scripts/travis/common.sh b/paddle/scripts/travis/common.sh index 9b6e420ca7931f0d17da461c7579bf4dc69e18e0..f05c7530a3b0632948e4b18c477d6dc6aad04c03 100755 --- a/paddle/scripts/travis/common.sh +++ b/paddle/scripts/travis/common.sh @@ -2,3 +2,5 @@ set -e mkdir -p ../../../build cd ../../../build +mkdir -p $HOME/third_party +EXTRA_CMAKE_OPTS="-DTHIRD_PARTY_PATH=${HOME}/third_party" diff --git a/paddle/scripts/travis/docs.sh b/paddle/scripts/travis/docs.sh index 8690fe1d40c935e119fefbc02f3a228d76d8c0f9..bdafb145bcd4e5990f382bb890f804687c474f7c 100755 --- a/paddle/scripts/travis/docs.sh +++ b/paddle/scripts/travis/docs.sh @@ -4,7 +4,7 @@ source ./common.sh # Compile Documentation only. -cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON +cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_DOC=ON ${EXTRA_CMAKE_OPTS} make paddle_docs paddle_docs_cn # check websites for broken links
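A note on the TensorShape pieces above: TensorShapeTest.cpp pins down the intended contract — a default-constructed shape reports zero dimensions and zero elements, a rank-only constructor fills every dimension with 1, and setDim() keeps the cached element count consistent. The sketch below is a minimal standalone illustration of that contract, not the actual header; the helper names (initDims, numElements) and the 4-slot minimum are taken from the fragment at the top of this diff, the class name ShapeSketch and everything else is an assumption made for the example.

#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Minimal illustration of the contract exercised by TensorShapeTest.cpp;
// not the actual paddle::TensorShape implementation.
class ShapeSketch {
public:
  ShapeSketch() : ndims_(0), nelements_(0) { initDims(0); }

  explicit ShapeSketch(size_t ndims) : ndims_(ndims), nelements_(1) {
    initDims(ndims);  // every dimension defaults to 1, so getElements() == 1
  }

  ShapeSketch(std::initializer_list<size_t> dims) : ndims_(dims.size()) {
    initDims(ndims_);
    std::copy(dims.begin(), dims.end(), dims_.begin());
    numElements();
  }

  size_t ndims() const { return ndims_; }
  size_t getElements() const { return nelements_; }
  size_t operator[](size_t dim) const { return dims_[dim]; }

  void setDim(size_t dim, size_t size) {
    dims_[dim] = size;
    numElements();  // keep the cached element count in sync, as the test expects
  }

private:
  void numElements() {
    nelements_ = 1;
    for (size_t i = 0; i < ndims_; ++i) nelements_ *= dims_[i];
  }

  void initDims(size_t ndims) {
    size_t count = ndims < 4 ? 4 : ndims;  // keep at least 4 slots, as in the diff
    dims_.assign(count, 1);
  }

  size_t ndims_;              // number of dimensions
  size_t nelements_;          // number of elements
  std::vector<size_t> dims_;  // dims_.size() may exceed ndims_
};

Against this sketch the test's expectations hold: ShapeSketch({8, 10}).getElements() is 80, and calling setDim(1, 100) on {1, 2, 3} recomputes the element count to 300.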
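The TensorType.h hunk is essentially a compile-time dispatch table: given a value type and a DeviceType tag, Tensor<...>::Matrix and Tensor<...>::Vector resolve to the concrete CPU or GPU classes, which is what lets TensorTypeTest.cpp reach CpuMatrix, GpuMatrix, CpuIVector and GpuIVector through one alias. Below is a self-contained sketch of the same trait pattern; FakeCpuMatrix and FakeGpuMatrix are stand-ins invented for this note, not Paddle classes.

#include <type_traits>

// Stand-in types so the sketch compiles on its own; in Paddle these would be
// CpuMatrix, GpuMatrix and friends.
struct FakeCpuMatrix {};
struct FakeGpuMatrix {};

enum DeviceType { DEVICE_TYPE_CPU = 1, DEVICE_TYPE_GPU = 2 };

namespace detail {

// Primary template left undefined: an unsupported (type, device) pair fails
// to compile instead of silently picking a wrong class.
template <typename VType, DeviceType Device>
struct MatrixT;

template <>
struct MatrixT<float, DEVICE_TYPE_CPU> {
  using type = FakeCpuMatrix;
};

template <>
struct MatrixT<float, DEVICE_TYPE_GPU> {
  using type = FakeGpuMatrix;
};

}  // namespace detail

template <typename VType, DeviceType Device>
struct Tensor {
  using Matrix = typename detail::MatrixT<VType, Device>::type;
};

// Device-generic code: the same function template works for CPU and GPU.
template <DeviceType Device>
typename Tensor<float, Device>::Matrix makeMatrix() {
  using M = typename Tensor<float, Device>::Matrix;
  return M{};
}

static_assert(std::is_same<Tensor<float, DEVICE_TYPE_CPU>::Matrix,
                           FakeCpuMatrix>::value,
              "the CPU tag resolves to the CPU matrix type");
static_assert(std::is_same<Tensor<float, DEVICE_TYPE_GPU>::Matrix,
                           FakeGpuMatrix>::value,
              "the GPU tag resolves to the GPU matrix type");

int main() {
  makeMatrix<DEVICE_TYPE_CPU>();
  makeMatrix<DEVICE_TYPE_GPU>();
  return 0;
}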
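In the ContextProjection.cpp and NormProjectionLayer.cpp hunks, the old calc({Tensor(...), ...}, {...}, {}) calls become: build a BufferArgs list for the inputs, build one for the outputs where each output carries an ArgType (ADD_TO to accumulate, ASSIGN_TO to overwrite), then call calc(inputs, outputs). The sketch below only imitates that calling convention with simplified stand-in classes — Buffer, Args and addFunc are invented for this note and are not the real BufferArgs/Function API, whose declarations are not part of this diff.

#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Simplified imitation of the calling convention used in the hunks above:
// inputs are appended in order, and each output is tagged with how the
// function should combine its result into the existing buffer.
enum ArgType { UNSPECIFIED = 0, ASSIGN_TO = 1, ADD_TO = 2 };

struct Buffer {
  explicit Buffer(std::vector<float> d) : data(std::move(d)), argType(UNSPECIFIED) {}
  std::vector<float> data;
  ArgType argType;
};

class Args {
public:
  void addArg(Buffer& buf, ArgType argType = UNSPECIFIED) {
    buf.argType = argType;
    args_.push_back(&buf);
  }
  Buffer& operator[](size_t i) const { return *args_[i]; }
  size_t size() const { return args_.size(); }

private:
  std::vector<Buffer*> args_;
};

// A toy "function": out = in[0] + in[1], honoring the output's ArgType.
void addFunc(const Args& inputs, const Args& outputs) {
  assert(inputs.size() == 2 && outputs.size() == 1);
  Buffer& out = outputs[0];
  for (size_t i = 0; i < out.data.size(); ++i) {
    float v = inputs[0].data[i] + inputs[1].data[i];
    out.data[i] = (out.argType == ADD_TO) ? out.data[i] + v : v;
  }
}

int main() {
  Buffer a({1, 2, 3}), b({10, 20, 30}), out({100, 100, 100});
  Args inputs, outputs;
  inputs.addArg(a);
  inputs.addArg(b);
  outputs.addArg(out, ADD_TO);   // accumulate, as ContextProjection's outputs do
  addFunc(inputs, outputs);
  assert(out.data[0] == 111.f);  // 100 (previous contents) + 1 + 10
  return 0;
}

The point mirrored here is that the accumulate-versus-assign decision travels with the output argument rather than being baked into each function, which is why outputs.addArg(*out_->value, ADD_TO) in ContextProjection and the ASSIGN_TO outputs in CMRProjectionNormLayer can share one calc signature.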
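The Matrix.cpp hunks all guard the same arithmetic: partial_sum = numElements / paraSize is the number of consecutive input columns that share one PReLU weight, and the column-to-weight index i / partial_sum only stays in range when numElements is an exact multiple of paraSize. For example, width 10 with two weights gives partial_sum = 5 and weight indices 0..1, while width 10 with four weights would truncate partial_sum to 2 and send column 9 to weight index 4, one past the end. That is what the new CHECK(!(numElements % paraSize)) enforces, and why the test loops now skip non-divisible combinations with if (width % (w_height * w_width)) continue;. A small standalone illustration of the arithmetic:

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustration of the partial_sum arithmetic guarded in the Matrix.cpp hunks:
// each PReLU weight covers `partial_sum` consecutive columns, so the data
// width must be an exact multiple of the number of weights.
int main() {
  const size_t numElements = 10;  // data width per sample
  const std::vector<size_t> paraSizes = {1, 2, 5, 4};

  for (size_t paraSize : paraSizes) {
    if (numElements % paraSize != 0) {
      // The case the new CHECK(!(numElements % paraSize)) rejects.
      std::printf("width %zu not divisible by %zu weights -> rejected\n",
                  numElements, paraSize);
      continue;
    }
    size_t partial_sum = numElements / paraSize;
    for (size_t i = 0; i < numElements; ++i) {
      size_t w_index = i / partial_sum;  // weight used for column i
      assert(w_index < paraSize);        // always in range when divisible
    }
    std::printf("width %zu, %zu weights -> partial_sum %zu\n",
                numElements, paraSize, partial_sum);
  }
  return 0;
}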
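Finally, the Matrix.h hunk adds isEmpty() plus an explicit operator bool so call sites can write CHECK(matrix) or if (matrix), as the EmptyMatrix test in TensorTypeTest.cpp does. The sketch below shows the idiom on a stand-in class (Handle is not a Paddle type): explicit keeps the conversion available in boolean contexts while blocking accidental implicit conversions elsewhere.

#include <cassert>

// Stand-in for the Matrix change in paddle/math/Matrix.h: emptiness means
// "no data pointer", surfaced through isEmpty() and an explicit operator bool.
class Handle {
public:
  explicit Handle(const float* data = nullptr) : data_(data) {}

  bool isEmpty() const { return data_ == nullptr; }

  // explicit: usable in if()/CHECK()-style boolean contexts, but it does not
  // participate in implicit conversions such as `int x = h;`.
  explicit operator bool() const { return !isEmpty(); }

private:
  const float* data_;
};

int main() {
  float buf[4] = {0.f, 0.f, 0.f, 0.f};
  Handle empty;
  Handle full(buf);

  assert(empty.isEmpty());
  assert(!full.isEmpty());
  if (full) {             // contextual conversion: allowed despite `explicit`
    assert(!full.isEmpty());
  }
  // int broken = full;   // implicit conversion: rejected at compile time
  return 0;
}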