diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b6a80ca43cf131c6886455cb5a86a61246ac17c..c5d7f2c7ec76dcc7befcd16798d26a7d54a19328 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF) option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) +option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) @@ -107,6 +108,7 @@ include(configure) # add paddle env configuration include_directories("${PROJ_ROOT}") include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") +include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") set(EXTERNAL_LIBS ${GFLAGS_LIBRARIES} @@ -126,9 +128,12 @@ endif(WITH_GPU) add_subdirectory(proto) add_subdirectory(paddle) -add_subdirectory(go/master/c) add_subdirectory(python) -add_subdirectory(go/pserver/cclient) + +if(WITH_GOLANG) + # TODO: add go/master/c back when fixed + add_subdirectory(go/pserver/cclient) +endif(WITH_GOLANG) if(WITH_DOC) add_subdirectory(doc) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 5e507e78f74eee885922f502f35e3c15fafb622d..e8425aedbdd269d54035a0457fa37e0ba834427a 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -40,6 +40,10 @@ if(NOT CMAKE_CROSSCOMPILING) endif() endif() +if(NOT WITH_GOLANG) + add_definitions(-DPADDLE_WITHOUT_GOLANG) +endif(NOT WITH_GOLANG) + if(NOT WITH_GPU) add_definitions(-DPADDLE_ONLY_CPU) add_definitions(-DHPPL_STUB_FUNC) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 43cd6b398b1caac55b938d576b96eb0282c00fda..69e8164a00d1fb57b79c63ba88c2846d30d80cd2 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -11,56 +11,164 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -# To simplify the build process of PaddlePaddle, we defined couple of -# fundamental abstractions, e.g., how to build library, binary and -# test in C++, CUDA and Go. +# generic.cmake defines CMake functions that look like Bazel's +# build rules (https://bazel.build/).
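+# The rules are implemented with cmake_parse_arguments; see +# https://cmake.org/cmake/help/v3.0/module/CMakeParseArguments.html.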
# +# # ------------------------------------------- -# C++ CUDA C++ Go +# C++ CUDA C++ Go # ------------------------------------------- -# cc_library nv_library go_library -# cc_binary nv_binary go_binary -# cc_test nv_test go_test +# cc_library nv_library go_library +# cc_binary nv_binary go_binary +# cc_test nv_test go_test # ------------------------------------------- +# +# To build a static library example.a from example.cc using the system +# compiler (like GCC): +# +# cc_library(example SRCS example.cc) +# +# To build a static library example.a from multiple source files +# example{1,2,3}.cc: +# +# cc_library(example SRCS example1.cc example2.cc example3.cc) +# +# To build a shared library example.so from example.cc: +# +# cc_library(example SHARED SRCS example.cc) +# +# To build a library using Nvidia's NVCC from .cu file(s), use the nv_ +# prefixed version: +# +# nv_library(example SRCS example.cu) +# +# To specify that a library new_example.a depends on other libraries: +# +# cc_library(new_example SRCS new_example.cc DEPS example) +# +# Static libraries can be composed of other static libraries: +# +# cc_library(composed DEPS dependent1 dependent2 dependent3) +# +# To build an executable binary file from some source files and +# dependent libraries: +# +# cc_binary(example SRCS main.cc something.cc DEPS example1 example2) +# +# To build an executable binary file using NVCC, use the nv_ prefixed +# version: +# +# nv_binary(example SRCS main.cc something.cu DEPS example1 example2) +# +# To build a unit test binary, which is an executable binary with +# GoogleTest linked: +# +# cc_test(example_test SRCS example_test.cc DEPS example) +# +# To build a unit test binary using NVCC, use the nv_ prefixed version: +# +# nv_test(example_test SRCS example_test.cu DEPS example) # -# cmake_parse_arguments can help us to achieve this goal. -# https://cmake.org/cmake/help/v3.0/module/CMakeParseArguments.html +# Executable and test binaries often depend on pre-defined external +# libraries like glog and gflags defined in /cmake/external/*.cmake: # +# cc_test(example_test SRCS example_test.cc DEPS example glog gflags) if(NOT APPLE) find_package(Threads REQUIRED) link_libraries(${CMAKE_THREAD_LIBS_INIT}) endif(NOT APPLE)
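+# A source-less cc_library call is handled by merge_static_libs below, which +# bundles existing static archives into one ("libtool -static" on macOS, +# "ar -x" plus re-archiving elsewhere). A minimal usage sketch, assuming +# dependent1 and dependent2 are already-defined static library targets: +# +# cc_library(composed DEPS dependent1 dependent2) # emits libcomposed.a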
-# cc_library parses tensor.cc and figures out that target also depend on tensor.h. -# cc_library(tensor -# SRCS -# tensor.cc -# DEPS -# variant) +function(merge_static_libs TARGET_NAME) + set(libs ${ARGN}) + list(REMOVE_DUPLICATES libs) + + # First get the file names of the libraries to be merged + foreach(lib ${libs}) + get_target_property(libtype ${lib} TYPE) + if(NOT libtype STREQUAL "STATIC_LIBRARY") + message(FATAL_ERROR "merge_static_libs can only process static libraries") + endif() + set(libfiles ${libfiles} $<TARGET_FILE:${lib}>) + endforeach() + + if(APPLE) # Use OSX's libtool to merge archives + add_custom_target(${TARGET_NAME}_archive + COMMAND libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles} + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${libs} + ) + add_library(${TARGET_NAME} STATIC IMPORTED GLOBAL) + set_property(TARGET ${TARGET_NAME} PROPERTY + IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a") + add_dependencies(${TARGET_NAME} ${TARGET_NAME}_archive) + else() # general UNIX: use "ar" to extract objects and re-add to a common lib + foreach(lib ${libs}) + set(objlistfile ${lib}.objlist) # list of objects in the input library + set(objdir ${lib}.objdir) + + add_custom_command(OUTPUT ${objdir} + COMMAND ${CMAKE_COMMAND} -E make_directory ${objdir}) + + add_custom_command(OUTPUT ${objlistfile} + COMMAND ${CMAKE_AR} -x "$<TARGET_FILE:${lib}>" + COMMAND ${CMAKE_AR} -t "$<TARGET_FILE:${lib}>" > ../${objlistfile} + DEPENDS ${lib} ${objdir} + WORKING_DIRECTORY ${objdir}) + + # Empty dummy source file that goes into merged library + set(mergebase ${lib}.mergebase.c) + add_custom_command(OUTPUT ${mergebase} + COMMAND ${CMAKE_COMMAND} -E touch ${mergebase} + DEPENDS ${objlistfile}) + + list(APPEND mergebases "${mergebase}") + endforeach() + + # We need a target for the output merged library + add_library(${TARGET_NAME} STATIC ${mergebases}) + set(outlibfile "$<TARGET_FILE:${TARGET_NAME}>") + + foreach(lib ${libs}) + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + COMMAND ${CMAKE_AR} ru ${outlibfile} @"../${lib}.objlist" # per-library object list produced above + WORKING_DIRECTORY ${lib}.objdir) + endforeach() + + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + COMMAND ${CMAKE_RANLIB} ${outlibfile}) + endif() +endfunction(merge_static_libs) + function(cc_library TARGET_NAME) - set(options OPTIONAL) + set(options STATIC static SHARED shared) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(cc_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if (${cc_library_OPTIONAL} STREQUAL "SHARED") - add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) - else() - add_library(${TARGET_NAME} STATIC ${cc_library_SRCS}) - endif() - if (cc_library_DEPS) - add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) - endif() + if (cc_library_SRCS) + if (cc_library_SHARED OR cc_library_shared) # build *.so + add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) + else() + add_library(${TARGET_NAME} STATIC ${cc_library_SRCS}) + endif() + if (cc_library_DEPS) + add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) + endif() + else(cc_library_SRCS) + if (cc_library_DEPS) + merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) + else() + message(FATAL_ERROR "Please specify source file or library in cc_library.") + endif() + endif(cc_library_SRCS) endfunction(cc_library)
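+# The go_ rules in the table above follow the same pattern; for example, +# go/pserver/cclient/CMakeLists.txt in this change builds the pserver C +# client as a static archive with: +# +# go_library(paddle_pserver_cclient STATIC)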
-# cc_binary parses tensor.cc and figures out that target also depend on tensor.h. -# cc_binary(tensor -# SRCS -# tensor.cc) function(cc_binary TARGET_NAME) - set(options OPTIONAL) + set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(cc_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -71,13 +179,6 @@ function(cc_binary TARGET_NAME) endif() endfunction(cc_binary) -# The dependency to target tensor implies that if any of -# tensor{.h,.cc,_test.cc} is changed, tensor_test need to be re-built. -# cc_test(tensor_test -# SRCS -# tensor_test.cc -# DEPS -# tensor) function(cc_test TARGET_NAME) if(WITH_TESTING) set(options "") @@ -91,28 +192,28 @@ function(cc_test TARGET_NAME) endif() endfunction(cc_test) -# Suppose that ops.cu includes global functions that take Tensor as -# their parameters, so ops depend on tensor. This implies that if -# any of tensor.{h.cc}, ops.{h,cu} is changed, ops need to be re-built. -# nv_library(ops -# SRCS -# ops.cu -# DEPS -# tensor) function(nv_library TARGET_NAME) if (WITH_GPU) - set(options OPTIONAL) + set(options STATIC static SHARED shared) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(nv_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if (${nv_library_OPTIONAL} STREQUAL "SHARED") - cuda_add_library(${TARGET_NAME} SHARED ${nv_library_SRCS}) - else() - cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) - endif() - if (nv_library_DEPS) - add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) - endif() + if(nv_library_SRCS) + if (nv_library_SHARED OR nv_library_shared) # build *.so + cuda_add_library(${TARGET_NAME} SHARED ${nv_library_SRCS}) + else() + cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) + endif() + if (nv_library_DEPS) + add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) + endif() + else(nv_library_SRCS) + if (nv_library_DEPS) + merge_static_libs(${TARGET_NAME} ${nv_library_DEPS}) + else() + message(FATAL_ERROR "Please specify source file or library in nv_library.") + endif() + endif(nv_library_SRCS) endif() endfunction(nv_library) @@ -130,13 +231,6 @@ function(nv_binary TARGET_NAME) endif() endfunction(nv_binary) -# The dependency to target tensor implies that if any of -# ops{.h,.cu,_test.cu} is changed, ops_test need to be re-built.
-# nv_test(ops_test -# SRCS -# ops_test.cu -# DEPS -# ops) function(nv_test TARGET_NAME) if (WITH_GPU AND WITH_TESTING) set(options "") diff --git a/cmake/util.cmake b/cmake/util.cmake index 8c9143462227e7081142f6be250b1a45e4b6d51b..87ad9d91d8701c56255c1e7f224764998df634a7 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -84,6 +84,7 @@ function(link_paddle_exe TARGET_NAME) paddle_parameter paddle_proto paddle_cuda + paddle_optimizer ${EXTERNAL_LIBS} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS} diff --git a/go/pserver/cclient/CMakeLists.txt b/go/pserver/cclient/CMakeLists.txt index 7967af51ee9a94c9e40bf6403fe819ff462d9219..fff7ae78582732c1b7af7a757c340804e91316d6 100644 --- a/go/pserver/cclient/CMakeLists.txt +++ b/go/pserver/cclient/CMakeLists.txt @@ -11,13 +11,4 @@ include(flags) go_library(paddle_pserver_cclient STATIC) -if(PROJ_ROOT) - add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a - COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/libpaddle_pserver_cclient.h ${PROJ_ROOT}/paddle/trainer/ - COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/libpaddle_pserver_cclient.a ${PROJ_ROOT}/paddle/trainer/ - WORKING_DIRECTORY ${PROJ_ROOT}/paddle - DEPENDS paddle_pserver_cclient) - add_custom_target(paddle_pserver_cclient_lib ALL DEPENDS ${PROJ_ROOT}/paddle/trainer/libpaddle_pserver_cclient.a) -endif(PROJ_ROOT) - add_subdirectory(test) diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 47ca1833967ee705d6558b1dad06a6335b30f03a..573bd937a351a6f308974e14f3bc92cbe1b541bc 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(gserver) add_subdirectory(pserver) add_subdirectory(trainer) add_subdirectory(scripts) +add_subdirectory(optimizer) add_subdirectory(strings) # Do not build go directory until go cmake is working smoothly. @@ -19,8 +20,8 @@ find_package(Boost QUIET) if(Boost_FOUND) include_directories(${Boost_INCLUDE_DIRS}) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}) - add_subdirectory(majel) + add_subdirectory(platform) + add_subdirectory(framework) endif() if(WITH_C_API) diff --git a/paddle/api/CMakeLists.txt b/paddle/api/CMakeLists.txt index c9433a38de4d005ebe229c55916401a5f82e9ef3..f2315e31cc06d8b5fea7a9fd203a697bac603a90 100644 --- a/paddle/api/CMakeLists.txt +++ b/paddle/api/CMakeLists.txt @@ -16,7 +16,7 @@ set(API_HEADER Internal.h) add_library(paddle_api STATIC ${API_SOURCES}) -add_dependencies(paddle_api gen_proto_cpp paddle_pserver_cclient_lib) +add_dependencies(paddle_api gen_proto_cpp paddle_trainer_lib) INCLUDE(${SWIG_USE_FILE}) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 7565ea51fe3e71bf81a28e6e4b5a2bbdd085798c..5fb3d1c73bc56e921f13aafd27c25224e259b3fe 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -842,7 +842,8 @@ public: int passCount, bool useSparseUpdater); static ParameterUpdater* createNewRemoteUpdater( - OptimizationConfig* config, const std::string pserverSpec); + OptimizationConfig* config, + const std::string pserverSpec) throw(UnsupportError); ~ParameterUpdater(); /** diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index eaf8518ae2beaa93bc40ee944c984d142d2bb951..1aaefdfb8107a2eaa0432211fd7df4f5f12d537f 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -15,7 +15,9 @@ limitations under the License. 
*/ #include "PaddleAPI.h" #include "PaddleAPIPrivate.h" +#ifndef PADDLE_WITHOUT_GOLANG #include "paddle/trainer/NewRemoteParameterUpdater.h" +#endif #include "paddle/trainer/RemoteParameterUpdater.h" #include "paddle/trainer/ThreadParameterUpdater.h" @@ -30,11 +32,16 @@ ParameterUpdater *ParameterUpdater::createLocalUpdater( } ParameterUpdater *ParameterUpdater::createNewRemoteUpdater( - OptimizationConfig *config, const std::string pserverSpec) { + OptimizationConfig *config, + const std::string pserverSpec) throw(UnsupportError) { +#ifndef PADDLE_WITHOUT_GOLANG auto updater = new ParameterUpdater(); updater->m->updater.reset(new paddle::NewRemoteParameterUpdater( config->m->getConfig(), pserverSpec)); return updater; +#else + throw UnsupportError(); +#endif } ParameterUpdater *ParameterUpdater::createRemoteUpdater( diff --git a/paddle/framework/.clang-format b/paddle/framework/.clang-format new file mode 100644 index 0000000000000000000000000000000000000000..29282dc87e2c499988c17d90d47d44cd5cf7f115 --- /dev/null +++ b/paddle/framework/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +Standard: Cpp11 +... diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..673cfa19ac35116288a2481b85858b6f88f3378e --- /dev/null +++ b/paddle/framework/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library(ddim SRCS ddim.cc) +cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) + +nv_test(dim_test SRCS dim_test.cu DEPS ddim) diff --git a/paddle/majel/ddim.cc b/paddle/framework/ddim.cc similarity index 94% rename from paddle/majel/ddim.cc rename to paddle/framework/ddim.cc index f32408ed53074234873ec0ea8ee7f4e449e5e908..3f949a6595ea326b97ac567daf9b35a68c8cf7f8 100644 --- a/paddle/majel/ddim.cc +++ b/paddle/framework/ddim.cc @@ -1,6 +1,7 @@ -#include "paddle/majel/ddim.h" +#include "paddle/framework/ddim.h" -namespace majel { +namespace paddle { +namespace framework { ///@cond HIDDEN @@ -66,7 +67,7 @@ DDim make_ddim(const std::vector<int>& dims) { ///@cond HIDDEN // XXX For some reason, putting this in an anonymous namespace causes errors class DynamicMutableIndexer : public boost::static_visitor<int&> { -public: + public: DynamicMutableIndexer(int idx) : idx_(idx) {} template <int D> @@ -74,12 +75,12 @@ public: return dim[idx_]; } -private: + private: int idx_; }; class DynamicConstIndexer : public boost::static_visitor<int> { -public: + public: DynamicConstIndexer(int idx) : idx_(idx) {} template <int D> @@ -87,7 +88,7 @@ public: return dim[idx_]; } -private: + private: int idx_; }; @@ -213,10 +214,11 @@ struct DDimPrinter : boost::static_visitor<void> { ///\endcond -std::ostream& operator<<(std::ostream& os, const majel::DDim& ddim) { +std::ostream& operator<<(std::ostream& os, const DDim& ddim) { DDimPrinter printer(os); boost::apply_visitor(printer, ddim); return os; } -} // namespace majel +} // namespace framework +} // namespace paddle diff --git a/paddle/majel/ddim.h b/paddle/framework/ddim.h similarity index 79% rename from paddle/majel/ddim.h rename to paddle/framework/ddim.h index 7be756f8c098ba5aa3a5ff4380c90f4b90a55bb7..223c4180bee45e21547364441476b27051daca56 100644 --- a/paddle/majel/ddim.h +++ b/paddle/framework/ddim.h @@ -5,20 +5,14 @@ #include <stdexcept> #include <vector> -#include "paddle/majel/dim.h" +#include "paddle/framework/dim.h" -namespace majel { +namespace paddle { +namespace framework { namespace {
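+// DDimVar can hold any of the statically sized Dim<1>..Dim<9>; DDim (below) +// wraps it so that the rank can be chosen at runtime.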
-typedef boost::variant<Dim<1>, - Dim<2>, - Dim<3>, - Dim<4>, - Dim<5>, - Dim<6>, - Dim<7>, - Dim<8>, - Dim<9>> +typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, + Dim<8>, Dim<9>> DDimVar; } @@ -95,14 +89,15 @@ ssize_t product(const DDim& ddim); int arity(const DDim& ddim); -std::ostream& operator<<(std::ostream&, const majel::DDim&); +std::ostream& operator<<(std::ostream&, const DDim&); -} // namespace majel +} // namespace framework +} // namespace paddle namespace boost { template <typename T> -T get(const majel::DDim& in) { +T get(const paddle::framework::DDim& in) { return boost::get<T>(in.var); } diff --git a/paddle/majel/ddim_test.cc b/paddle/framework/ddim_test.cc similarity index 59% rename from paddle/majel/ddim_test.cc rename to paddle/framework/ddim_test.cc index a5b8a7c4d26740c1c4169547e76a0cf5558facc0..e5c84d7abe9d476f285c8c5cd904d2e570eb0e4f 100644 --- a/paddle/majel/ddim_test.cc +++ b/paddle/framework/ddim_test.cc @@ -4,18 +4,18 @@ #include <vector> #include "gtest/gtest.h" -#include "paddle/majel/ddim.h" +#include "paddle/framework/ddim.h" TEST(DDim, Equality) { // construct a DDim from an initialization list - majel::DDim ddim = majel::make_ddim({9, 1, 5}); + paddle::framework::DDim ddim = paddle::framework::make_ddim({9, 1, 5}); EXPECT_EQ(ddim[0], 9); EXPECT_EQ(ddim[1], 1); EXPECT_EQ(ddim[2], 5); // construct a DDim from a vector std::vector<int> vec({9, 1, 5}); - majel::DDim vddim = majel::make_ddim(vec); + paddle::framework::DDim vddim = paddle::framework::make_ddim(vec); EXPECT_EQ(ddim[0], 9); EXPECT_EQ(ddim[1], 1); EXPECT_EQ(ddim[2], 5); @@ -23,43 +23,43 @@ TEST(DDim, Equality) { // mutate a DDim ddim[1] = 2; EXPECT_EQ(ddim[1], 2); - majel::set(ddim, 0, 6); - EXPECT_EQ(majel::get(ddim, 0), 6); + paddle::framework::set(ddim, 0, 6); + EXPECT_EQ(paddle::framework::get(ddim, 0), 6); // vectorize a DDim - std::vector<int> res_vec = majel::vectorize(vddim); + std::vector<int> res_vec = paddle::framework::vectorize(vddim); EXPECT_EQ(res_vec[0], 9); EXPECT_EQ(res_vec[1], 1); EXPECT_EQ(res_vec[2], 5); - majel::Dim<3> d(3, 2, 1); - res_vec = majel::vectorize(majel::DDim(d)); + paddle::framework::Dim<3> d(3, 2, 1); + res_vec = paddle::framework::vectorize(paddle::framework::DDim(d)); EXPECT_EQ(res_vec[0], 3); EXPECT_EQ(res_vec[1], 2); EXPECT_EQ(res_vec[2], 1); // add two DDims - majel::DDim ddim_sum = ddim + vddim; + paddle::framework::DDim ddim_sum = ddim + vddim; EXPECT_EQ(ddim_sum[0], 15); EXPECT_EQ(ddim_sum[1], 3); EXPECT_EQ(ddim_sum[2], 10); // multiply two DDims - majel::DDim ddim_mul = ddim * vddim; + paddle::framework::DDim ddim_mul = ddim * vddim; EXPECT_EQ(ddim_mul[0], 54); EXPECT_EQ(ddim_mul[1], 2); EXPECT_EQ(ddim_mul[2], 25); // arity of a DDim - EXPECT_EQ(majel::arity(ddim), 3); + EXPECT_EQ(paddle::framework::arity(ddim), 3); // product of a DDim - EXPECT_EQ(majel::product(vddim), 45); + EXPECT_EQ(paddle::framework::product(vddim), 45); } TEST(DDim, Print) { // print a DDim std::stringstream ss; - majel::DDim ddim = majel::make_ddim({2, 3, 4}); + paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 3, 4}); ss << ddim; EXPECT_EQ("2, 3, 4", ss.str()); } diff --git a/paddle/majel/dim.h b/paddle/framework/dim.h similarity index 96% rename from paddle/majel/dim.h rename to paddle/framework/dim.h index c4b0c6aea683384d4657dd5db6f419b9e1108704..bcde291d12d429a3f2cd41fa6d0ee606c7c9c92f 100644 --- a/paddle/majel/dim.h +++ b/paddle/framework/dim.h @@ -5,10 +5,11 @@ #include <stdexcept> #include <type_traits> -#include "paddle/majel/detail/cuda_assert.h" -#include "paddle/majel/detail/hostdevice.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" -namespace majel { +namespace paddle { +namespace framework {
// Statically sized, statically indexed dimension template <int i> @@ -74,7 +75,7 @@ struct Dim<1> { throw std::invalid_argument("Index out of range."); } #else - MAJEL_ASSERT(idx < size.head); + PADDLE_ASSERT(idx < size.head); #endif } @@ -131,7 +132,7 @@ HOSTDEVICE int& indexer(Dim<i>& dim, int idx) { throw std::invalid_argument("Tried to access a negative dimension"); } #else - MAJEL_ASSERT(idx >= 0); + PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; @@ -146,7 +147,7 @@ HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) { throw std::invalid_argument("Invalid index"); } #else - MAJEL_ASSERT(idx == 0); + PADDLE_ASSERT(idx == 0); #endif return dim.head; } @@ -158,7 +159,7 @@ HOSTDEVICE int indexer(const Dim<i>& dim, int idx) { throw std::invalid_argument("Tried to access a negative dimension"); } #else - MAJEL_ASSERT(idx >= 0); + PADDLE_ASSERT(idx >= 0); #endif if (idx == 0) { return dim.head; @@ -173,7 +174,7 @@ HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) { throw std::invalid_argument("Invalid index"); } #else - MAJEL_ASSERT(idx == 0); + PADDLE_ASSERT(idx == 0); #endif return dim.head; } @@ -411,7 +412,7 @@ HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) { // XXX For some reason, overloading fails to resolve this correctly template <int i> typename std::enable_if<(i > 1), std::ostream&>::type operator<<( - std::ostream& os, const majel::Dim<i>& d) { + std::ostream& os, const Dim<i>& d) { os << d.head << ", " << d.tail; return os; } @@ -420,7 +421,7 @@ typename std::enable_if<(i > 1), std::ostream&>::type operator<<( // XXX I wish this could be an overload instead of a template template <int i> typename std::enable_if<(i == 1), std::ostream&>::type operator<<( - std::ostream& os, const majel::Dim<i>& d) { + std::ostream& os, const Dim<i>& d) { os << d.head; return os; } @@ -448,4 +449,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) { return result; } -} // namespace majel +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/dim_test.cu b/paddle/framework/dim_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..809bf04826637195425a32c054c94e00ef940df9 --- /dev/null +++ b/paddle/framework/dim_test.cu @@ -0,0 +1,128 @@ +#include <thrust/device_vector.h> +#include <sstream> + +#include "paddle/framework/dim.h" +#include "gtest/gtest.h" + +__global__ void test(paddle::framework::Dim<2>* o) { + o[0] = paddle::framework::make_dim(5, 6); +} + +__global__ void dyn_idx_gpu(int* o) { + auto d = paddle::framework::make_dim(5, 6); + o[0] = d[1]; +} + +TEST(Dim, Equality) { + // construct a Dim on the CPU + auto a = paddle::framework::make_dim(3, 4); + EXPECT_EQ(paddle::framework::get<0>(a), 3); + EXPECT_EQ(paddle::framework::get<1>(a), 4); + + // construct a Dim on the GPU + thrust::device_vector<paddle::framework::Dim<2>> t(2); + test<<<1,1>>>(thrust::raw_pointer_cast(t.data())); + a = t[0]; + EXPECT_EQ(paddle::framework::get<0>(a), 5); + EXPECT_EQ(paddle::framework::get<1>(a), 6); + + // linearization + auto b = paddle::framework::make_dim(7, 8); + EXPECT_EQ(paddle::framework::linearize(a, b), 83); + + // product + EXPECT_EQ(paddle::framework::product(a), 30); + + // mutate a Dim + paddle::framework::get<1>(b) = 10; + EXPECT_EQ(paddle::framework::get<0>(b), 7); + EXPECT_EQ(paddle::framework::get<1>(b), 10); + + // dynamic access + paddle::framework::get(b, 0) = 8; + b[1] = 11; + EXPECT_EQ(paddle::framework::get<0>(b), 8); + EXPECT_EQ(paddle::framework::get<1>(b), 11); + EXPECT_EQ(paddle::framework::get(b, 0), 8); + EXPECT_EQ(b[1], 11); + + // dynamic access on GPU + thrust::device_vector<int> r(1); +
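// launch one GPU thread; reading r[0] afterwards copies the result back to the host +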
dyn_idx_gpu<<<1,1>>>(thrust::raw_pointer_cast(r.data())); + int res = r[0]; + EXPECT_EQ(res, 6); + + // ex_prefix_mul + paddle::framework::Dim<3> c = paddle::framework::ex_prefix_mul(paddle::framework::Dim<3>(3, 4, 5)); + EXPECT_EQ(paddle::framework::get<0>(c), 1); + EXPECT_EQ(paddle::framework::get<1>(c), 3); + EXPECT_EQ(paddle::framework::get<2>(c), 12); + + // contiguous_strides + c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10)); + EXPECT_EQ(paddle::framework::get<0>(c), 1); + EXPECT_EQ(paddle::framework::get<1>(c), 0); + EXPECT_EQ(paddle::framework::get<2>(c), 10); + c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1)); + EXPECT_EQ(paddle::framework::get<0>(c), 1); + EXPECT_EQ(paddle::framework::get<1>(c), 10); + EXPECT_EQ(paddle::framework::get<2>(c), 0); + c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10)); + EXPECT_EQ(paddle::framework::get<0>(c), 0); + EXPECT_EQ(paddle::framework::get<1>(c), 1); + EXPECT_EQ(paddle::framework::get<2>(c), 10); + c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4)); + EXPECT_EQ(paddle::framework::get<0>(c), 1); + EXPECT_EQ(paddle::framework::get<1>(c), 2); + EXPECT_EQ(paddle::framework::get<2>(c), 6); + + // generate from an index + auto size = paddle::framework::make_dim(4, 5, 2); + c = paddle::framework::Dim<3>(14, size); + EXPECT_EQ(paddle::framework::get<0>(c), 2); + EXPECT_EQ(paddle::framework::get<1>(c), 3); + EXPECT_EQ(paddle::framework::get<2>(c), 0); + c = paddle::framework::Dim<3>(25, size); + EXPECT_EQ(paddle::framework::get<0>(c), 1); + EXPECT_EQ(paddle::framework::get<1>(c), 1); + EXPECT_EQ(paddle::framework::get<2>(c), 1); +} + +TEST(Dim, Bool) { + auto a = paddle::framework::make_dim(3, 4); + auto b = paddle::framework::make_dim(5, 6); + auto c = paddle::framework::make_dim(3, 4); + + // in_bounds check + EXPECT_TRUE(paddle::framework::contained(a, b)); + EXPECT_FALSE(paddle::framework::contained(b, a)); + + // comparison + EXPECT_TRUE(a == a); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a == c); + + // contiguous check + int x = 4, y = 5, z = 2; + paddle::framework::Dim<3> sizef(x, y, z); + paddle::framework::Dim<3> stridea(1, x, x*y); + paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y); + paddle::framework::Dim<3> stridec(1, x, 2*x*y); + EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea)); + EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb)); + EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec)); +} + +TEST(Dim, Print) { + { + std::stringstream ss; + auto a = paddle::framework::make_dim(2, 3); + ss << a; + EXPECT_EQ(ss.str(), "2, 3"); + } + { + std::stringstream ss; + ss << paddle::framework::make_dim(8); + EXPECT_EQ(ss.str(), "8"); + } +} diff --git a/paddle/majel/README.md b/paddle/framework/tensor.md similarity index 100% rename from paddle/majel/README.md rename to paddle/framework/tensor.md diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 65b9d1d53f9210b08cdc8bbd9d93b03305e582e4..bb4f48364b9b454af7d37fe4d3c340666e53285c 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -68,14 +68,12 @@ public: numOutputs_ = 1; } - virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {} - // input can be INPUT and INPUT_GRAD // filter can be FILTER and FILTER_GRAD // output can be OUTPUT and OUTPUT_GRAD - void check(const TensorShape& input, - const TensorShape& filter, - const TensorShape& output) { + void checkShape(const TensorShape& input, + const 
TensorShape& filter, + const TensorShape& output) { // inputs and outputs arguments should be 4-dimensional. CHECK_EQ(input.ndims(), (size_t)4); CHECK_EQ(output.ndims(), (size_t)4); diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index c7a57801ed6098260af5ba22be82ac4ea7c2e601..a40e5d9d2e76605525f0956445fc43c693933cf8 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -117,15 +117,23 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // TODO(hedaoyuan): Need to define some index macros, // to avoid useing 0 and 1. const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - check(input, filter, output); real beta; if (outputs[0].getArgType() == ADD_TO) { @@ -209,16 +217,24 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& input = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); // Since the implementation of Col2ImFunctor is ADD_TO, // this function only supports ADD_TO mode. 
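// (ADD_TO accumulates into the existing output buffer rather than overwriting it.)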
CHECK_EQ(outputs[0].getArgType(), ADD_TO); const TensorShape& output = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& input = outputs[0].shape(); - check(input, filter, output); size_t batchSize = input[0]; size_t inputChannels = input[1]; @@ -295,13 +311,21 @@ public: ConvFunctionBase::init(config); } + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { + const TensorShape& output = inputs[0].shape(); + const TensorShape& input = inputs[1].shape(); + const TensorShape& filter = outputs[0].shape(); + checkShape(input, filter, output); + } + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); const TensorShape& output = inputs[0].shape(); const TensorShape& input = inputs[1].shape(); const TensorShape& filter = outputs[0].shape(); - check(input, filter, output); real beta; if (outputs[0].getArgType() == ADD_TO) { diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp index 1d204f99e0e127688eeda28b46715a37c1100c4e..4348f0f775e9442c50a3c45b9a8e6dad5c6b198d 100644 --- a/paddle/function/NaiveConvOp.cpp +++ b/paddle/function/NaiveConvOp.cpp @@ -54,8 +54,8 @@ public: T inValue; const int inH = inStartH + fH; const int inW = inStartW + fW; - if ((inH >= 0 && inH < inputHeight) && - (inW >= 0 && inW < inputWidth)) { + if ((inH >= 0 && inH < (int)inputHeight) && + (inW >= 0 && inW < (int)inputWidth)) { size_t offsetInput = batch * inputChannels * inputHeight * inputWidth + inC * inputHeight * inputWidth + inH * inputWidth + inW; @@ -90,14 +90,19 @@ public: ConvFunctionBase::init(config); } - void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - CHECK_EQ(numInputs_, inputs.size()); - CHECK_EQ(numOutputs_, outputs.size()); + virtual void check(const BufferArgs& inputs, + const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - check(input, filter, output); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + check(inputs, outputs); size_t batchSize = inputs[0].shape()[0]; size_t inputChannels = inputs[0].shape()[1]; diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index 01158d1dce8d711c67b1ecf29bb644e42ccf6ff5..3e930380226bce58cc90704b4c4cfa36e9f70968 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -214,7 +214,6 @@ void RecurrentGradientMachine::init( inFrameLines_[i].linkName = subModelConfig->in_links(i).link_name(); inFrameLines_[i].inLayer = rootNetwork_->getLayer(subModelConfig->in_links(i).layer_name()); - inFrameLines_[i].hasSubseq = subModelConfig->in_links(i).has_subseq(); } outFrameLines_.resize(subModelConfig->out_links_size()); @@ -241,11 +240,8 @@ void RecurrentGradientMachine::init( rootNetwork_->getLayer(memoryConfig.boot_layer_name()); LayerConfig scatterConfig = *agentConfig; - memoryFrameLines_[i].is_sequence = memoryConfig.is_sequence(); memoryFrameLines_[i].rootAgent.reset( - memoryConfig.is_sequence() - ? 
new SequenceScatterAgentLayer(scatterConfig) - : new ScatterAgentLayer(scatterConfig)); + new ScatterAgentLayer(scatterConfig)); memoryFrameLines_[i].rootAgent->init(LayerMap(), parameterMap_); memoryFrameLines_[i].bootLayer = memoryFrameLines_[i].rootAgent; @@ -267,9 +263,7 @@ void RecurrentGradientMachine::init( if (subModelConfig->has_generator()) { memoryFrameLines_[i].scatterAgents.resize(2); for (auto& agent : memoryFrameLines_[i].scatterAgents) { - agent.reset(memoryConfig.is_sequence() - ? new SequenceScatterAgentLayer(*agentConfig) - : new ScatterAgentLayer(*agentConfig)); + agent.reset(new ScatterAgentLayer(*agentConfig)); agent->init(LayerMap(), parameterMap_); } } @@ -297,8 +291,6 @@ void RecurrentGradientMachine::init( if (subModelConfig->evaluator_names_size() > 0) { evaluator_.reset(frames_[0]->makeEvaluator()); } - - targetInfoInlinkId_ = subModelConfig->target_inlinkid(); } void RecurrentGradientMachine::resizeOrCreateFrames(int numFrames) { @@ -376,108 +368,102 @@ void RecurrentGradientMachine::prefetch(const std::vector& inArgs) { LOG(FATAL) << "should not use this function"; } -void RecurrentGradientMachine::forward(const std::vector& inArgs, - std::vector* outArgs, - PassType passType) { - if (inFrameLines_.empty() && passType == PASS_TEST) { - generateSequence(); - return; - } // else forward.. - - const Argument& input = inFrameLines_[0].inLayer->getOutput(); - CHECK(input.sequenceStartPositions); - int batchSize = input.getBatchSize(); - size_t numSequences = input.getNumSequences(); - const int* starts = input.sequenceStartPositions->getData(false); - bool hasSubseq = input.hasSubseq(); - - // In case of !hasSubseq or targetInfoInlinkId_ == -1, all inlinks share the - // same inframe info - bool shareInlinkInfo = !hasSubseq || targetInfoInlinkId_ == -1; - - // Defaultly, share info with the first inlink - if (shareInlinkInfo) { - targetInfoInlinkId_ = 0; - } - - // check hasSubseq in both config and input are the same - CHECK_EQ(hasSubseq, inFrameLines_[0].hasSubseq); - - CHECK_EQ(starts[numSequences], batchSize); - CHECK(input.sequenceStartPositions); - - // check other inputs has same sequence length and start - for (size_t i = 1; i < inFrameLines_.size(); ++i) { - const Argument& input1 = inFrameLines_[i].inLayer->getOutput(); - CHECK_EQ((size_t)input1.getNumSequences(), numSequences); - // check all inputs should have same hasSubseq flag - CHECK_EQ(input.hasSubseq(), inFrameLines_[0].hasSubseq); - - // if shareInlinkInfo, checks: - // 1. all inlinks have same number of total tokens - // 2. all inlinks have same number of tokens for each sentence of each - // sample. 
If hasSubseq, one sample has multiple sentence, else, one - // sample is one sentence - if (shareInlinkInfo) { - CHECK_EQ(input1.getBatchSize(), batchSize); - CHECK(std::equal(starts, - starts + numSequences + 1, - input1.sequenceStartPositions->getData(false))); +void RecurrentGradientMachine::checkInputConsistency( + int inlinkId, const std::vector& seqInfo) { + if (commonSeqInfo_.empty()) { + commonSeqInfo_.resize(seqInfo.size()); + for (size_t i = 0; i < seqInfo.size(); ++i) { + commonSeqInfo_[i].topLevelLength = seqInfo[i].topLevelLength; + commonSeqInfo_[i].seqId = seqInfo[i].seqId; + } + } else { + CHECK_EQ(commonSeqInfo_.size(), seqInfo.size()) + << " RecurrentGroup " << subModelName_ << " input " << inlinkId + << " has mismatched number of sequences"; + for (size_t i = 0; i < seqInfo.size(); ++i) { + CHECK_EQ(commonSeqInfo_[i].topLevelLength, seqInfo[i].topLevelLength) + << " RecurrentGroup " << subModelName_ << " input " << inlinkId + << " has mismatched sequence length"; + CHECK_EQ(commonSeqInfo_[i].seqId, seqInfo[i].seqId) + << " RecurrentGroup " << subModelName_ << " input " << inlinkId + << " has mismatched sequence length"; } } +} - if (hasSubseq) { - CHECK(input.subSequenceStartPositions); - size_t numSubSequences = input.getNumSubSequences(); - const int* subStarts = input.subSequenceStartPositions->getData(false); - CHECK_EQ(subStarts[numSubSequences], batchSize); - // if hasSubseq, check other inputs has same sub-sequence and sub-start - for (size_t i = 1; i < inFrameLines_.size(); ++i) { - const Argument& input1 = inFrameLines_[i].inLayer->getOutput(); - CHECK_EQ((size_t)input1.getNumSubSequences(), numSubSequences); - if (shareInlinkInfo) { - CHECK(std::equal(subStarts, - subStarts + numSubSequences + 1, - input1.subSequenceStartPositions->getData(false))); - } +void RecurrentGradientMachine::calcNumSequencesAtEachStep() { + int numSequences = commonSeqInfo_.size(); + numSeqs_.resize(maxSequenceLength_); + for (int i = 0; i < numSequences; ++i) { + for (int j = 0; j < commonSeqInfo_[i].topLevelLength; ++j) { + numSeqs_[j] = i + 1; } } +} +void RecurrentGradientMachine::reorganizeInput(PassType passType) { info_.clear(); info_.resize(inFrameLines_.size()); + commonSeqInfo_.clear(); seqInfos_.clear(); seqInfos_.resize(inFrameLines_.size()); + for (size_t i = 0; i < inFrameLines_.size(); i++) { + const Argument& input = inFrameLines_[i].inLayer->getOutput(); + if (!input.hasSeq()) { + continue; + } + input.getSeqInfo(&seqInfos_[i]); + checkInputConsistency(i, seqInfos_[i]); + } + CHECK(!commonSeqInfo_.empty()) + << "At least one input needs to be sequence or subsequence"; + maxSequenceLength_ = commonSeqInfo_[0].topLevelLength; + + calcNumSequencesAtEachStep(); + + for (size_t i = 0; i < inFrameLines_.size(); ++i) { + const Argument& input = inFrameLines_[i].inLayer->getOutput(); + if (!input.hasSeq()) { + seqInfos_[i] = commonSeqInfo_; + } + createInFrameInfo(i, input, passType); + } + { AsyncGpuBlock asyncGpuBlock; - // if shareInlinkInfo, only calculate info of the first inlink - // else, calculate info for each inlink - if (shareInlinkInfo) { - input.getSeqInfo(&seqInfos_[0]); - maxSequenceLength_ = seqInfos_[0][0].topLevelLength; - createInFrameInfo(0, input, passType); - } else { - for (size_t i = 0; i < inFrameLines_.size(); i++) { - const Argument& input1 = inFrameLines_[i].inLayer->getOutput(); - input1.getSeqInfo(&seqInfos_[i]); - maxSequenceLength_ = seqInfos_[i][0].topLevelLength; - createInFrameInfo(i, input1, passType); - } - } // inFrameLine select rows in 
real layer one time for (size_t i = 0; i < inFrameLines_.size(); i++) { - int curInlinkId = shareInlinkInfo ? 0 : i; selectRowsOneTime(inFrameLines_[i].inLayer, - info_[curInlinkId].allIds, + info_[i].allIds, &(inFrameLines_[i].outArg), passType); } } - resizeOrCreateFrames(maxSequenceLength_); - resizeBootFrame(numSequences); +} + +void RecurrentGradientMachine::reorganizeOutput(PassType passType) { + calcSequenceStartPositions(); + for (size_t i = 0; i < outFrameLines_.size(); ++i) { + Info info; + auto& outFrameLine = outFrameLines_[i]; + ICpuGpuVectorPtr sequenceStartPositions; + ICpuGpuVectorPtr subSequenceStartPositions; + createOutFrameInfo( + outFrameLine, info, sequenceStartPositions, subSequenceStartPositions); + auto gatherAgent = + dynamic_cast(outFrameLine.agentLayer.get()); + CHECK_NOTNULL(gatherAgent); + gatherAgent->copyIdAndSequenceInfo(sequenceStartPositions, + subSequenceStartPositions, + info.allIds, + info.idIndex); + } +} +void RecurrentGradientMachine::connectFrames(PassType passType) { for (auto& memoryFrameLine : memoryFrameLines_) { if (memoryFrameLine.rootAgent) { auto scatterAgent = @@ -487,8 +473,9 @@ void RecurrentGradientMachine::forward(const std::vector& inArgs, memoryFrameLine.outArg, memoryFrameLine.allIds, /* idIndex */ 0, - memoryFrameLine.allIds->getSize()); - if (memoryFrameLine.is_sequence) { // memoryConfig is sequence + memoryFrameLine.allIds->getSize(), + /* handleBackward */ true); + if (memoryFrameLine.sequenceStartPositions) { int size = memoryFrameLine.sequenceStartPositions->getSize(); scatterAgent->setSequenceStartPositions( memoryFrameLine.sequenceStartPositions, @@ -501,28 +488,26 @@ void RecurrentGradientMachine::forward(const std::vector& inArgs, for (auto& outFrameLine : outFrameLines_) { auto gatherAgent = dynamic_cast(outFrameLine.agentLayer.get()); - CHECK_NOTNULL(gatherAgent); - gatherAgent->copyIdAndSequenceInfo(input, - info_[targetInfoInlinkId_].allIds, - info_[targetInfoInlinkId_].idIndex); + gatherAgent->clearRealLayers(); } - for (int i = 0; i < maxSequenceLength_; ++i) { - int idSize = 0; // connect in_links for (size_t j = 0; j < inFrameLines_.size(); ++j) { - Info& info = info_[shareInlinkInfo ? 0 : j]; + Info& info = info_[j]; // idSize denotes the sum number of tokens in each length i - idSize = info.idIndex[i + 1] - info.idIndex[i]; + int idIndex = info.idIndex.empty() ? 0 : info.idIndex[i]; + int idSize = info.idIndex.empty() ? numSeqs_[i] + : info.idIndex[i + 1] - info.idIndex[i]; InFrameLine inFrameLine = inFrameLines_[j]; auto scatterAgent = dynamic_cast(inFrameLine.agents[i].get()); scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer, inFrameLine.outArg, info.allIds, - info.idIndex[i], - idSize); - if (hasSubseq) { + idIndex, + idSize, + i == 0); + if (info.sequenceStartPositions) { // size: the length of subsequence int size = info.seqStartPosIndex[i + 1] - info.seqStartPosIndex[i]; scatterAgent->setSequenceStartPositions( @@ -536,11 +521,6 @@ void RecurrentGradientMachine::forward(const std::vector& inArgs, dynamic_cast(outFrameLine.agentLayer.get()); gatherAgent->addRealLayer(outFrameLine.frames[i]); } - // connect memory links - // Adopt info_[0].idIndex because seq which has_subseq=True - // doesn't support Memory with !hasSubseq bootlayer; - // And inlinks that !hasSubSeq must have same inlink length. 
- idSize = info_[0].idIndex[i + 1] - info_[0].idIndex[i]; for (auto& memoryFrameLine : memoryFrameLines_) { NeuralNetwork::connect( memoryFrameLine.agents[i], @@ -548,6 +528,28 @@ void RecurrentGradientMachine::forward(const std::vector& inArgs, numSeqs_[i] /*height of agent*/); } } +} + +void RecurrentGradientMachine::forward(const std::vector& inArgs, + std::vector* outArgs, + PassType passType) { + /* inArgs and outArgs are not used. + The inputs are inFrameLines_[i].inLayer. + The outputs are outFramesLines_[i].agentLayer + */ + + if (inFrameLines_.empty() && passType == PASS_TEST) { + generateSequence(); + return; + } // else forward.. + + reorganizeInput(passType); + int numSequences = commonSeqInfo_.size(); + + resizeOrCreateFrames(maxSequenceLength_); + resizeBootFrame(numSequences); + + connectFrames(passType); REGISTER_TIMER_INFO("RecurrentFwTime", "RecurrentFwTime"); // forward @@ -558,16 +560,12 @@ void RecurrentGradientMachine::forward(const std::vector& inArgs, const std::vector inArgs; std::vector outArgs; frames_[i]->forward(inArgs, &outArgs, passType); - if (hasSubseq) { - for (auto& outFrameLine : outFrameLines_) { - CHECK(outFrameLine.frames[i]->getOutput().sequenceStartPositions) - << "In hierachical RNN, all out links should be from sequences."; - } - } } if (evaluator_ && passType == PASS_TEST) { this->eval(evaluator_.get()); } + + reorganizeOutput(passType); } void RecurrentGradientMachine::backward(const UpdateCallback& callback) { @@ -634,76 +632,228 @@ void RecurrentGradientMachine::removeBeamSearchStatisticsCallbacks() { this->beamSearchStatistics_ = nullptr; } } + +namespace { +void lenToStarts(std::vector& starts) { + int pos = 0; + starts.back() = 0; + for (auto& start : starts) { + int tmp = start; + start = pos; + pos += tmp; + } + starts.back() = pos; +} +} + +void RecurrentGradientMachine::calcSequenceStartPositions() { + std::vector starts(commonSeqInfo_.size() + 1); + for (auto& seqInfo : commonSeqInfo_) { + starts[seqInfo.seqId] = seqInfo.topLevelLength; + } + lenToStarts(starts); + ICpuGpuVector::resizeOrCreate(sequenceStartPositions_, starts.size(), false); + std::copy(starts.begin(), + starts.end(), + sequenceStartPositions_->getMutableData(false)); +} + +void RecurrentGradientMachine::checkOutputConsistency( + OutFrameLine& outFrameLine) { + bool hasSeq = outFrameLine.frames[0]->getOutput().hasSeq(); + for (int i = 0; i < maxSequenceLength_; ++i) { + LayerPtr frame = outFrameLine.frames[i]; + CHECK_EQ(hasSeq, frame->getOutput().hasSeq()); + int numSequences = frame->getOutput().getNumSequences(); + CHECK_EQ(numSeqs_[i], numSequences); + } +} + +void RecurrentGradientMachine::createOutFrameInfo( + OutFrameLine& outFrameLine, + Info& info, + ICpuGpuVectorPtr& sequenceStartPositions, + ICpuGpuVectorPtr& subSequenceStartPositions) { + checkOutputConsistency(outFrameLine); + + if (!outFrameLine.frames[0]->getOutput().hasSeq()) { + createOutFrameInfo_seq( + outFrameLine, info, sequenceStartPositions, subSequenceStartPositions); + } else { + createOutFrameInfo_subseq( + outFrameLine, info, sequenceStartPositions, subSequenceStartPositions); + } +} + +void RecurrentGradientMachine::createOutFrameInfo_seq( + OutFrameLine& outFrameLine, + Info& info, + ICpuGpuVectorPtr& sequenceStartPositions, + ICpuGpuVectorPtr& subSequenceStartPositions) { + std::vector allIds; + info.idIndex.resize(1, 0); // first idIndex = 0 + + const int* starts = sequenceStartPositions_->getData(false); + + for (int i = 0; i < maxSequenceLength_; ++i) { + LayerPtr frame = 
outFrameLine.frames[i]; + size_t numSequences = frame->getOutput().getNumSequences(); + for (size_t j = 0; j < numSequences; ++j) { + int seqStart = starts[commonSeqInfo_[j].seqId]; + int seqLength = commonSeqInfo_[j].topLevelLength; + allIds.push_back(reversed_ ? (seqStart + seqLength - 1 - i) + : (seqStart + i)); + } + info.idIndex.push_back(allIds.size()); + } + sequenceStartPositions = sequenceStartPositions_; + copyScattedId(allIds, &info.allIds, allIds.size()); + CHECK_EQ(info.idIndex.size(), static_cast(maxSequenceLength_ + 1)); +} + +void RecurrentGradientMachine::createOutFrameInfo_subseq( + OutFrameLine& outFrameLine, + Info& info, + ICpuGpuVectorPtr& sequenceStartPositions, + ICpuGpuVectorPtr& subSequenceStartPositions) { + size_t numSequences = commonSeqInfo_.size(); + std::vector allIds; + info.idIndex.resize(1, 0); // first idIndex = 0 + + const int* starts = sequenceStartPositions_->getData(false); + std::vector subStarts(starts[numSequences] + 1); + for (int i = 0; i < maxSequenceLength_; ++i) { + LayerPtr frame = outFrameLine.frames[i]; + size_t numSequences = frame->getOutput().getNumSequences(); + const int* seqStarts = + frame->getOutput().sequenceStartPositions->getData(false); + for (size_t j = 0; j < numSequences; ++j) { + subStarts[starts[commonSeqInfo_[j].seqId] + i] = + seqStarts[j + 1] - seqStarts[j]; + } + } + lenToStarts(subStarts); + + for (int i = 0; i < maxSequenceLength_; ++i) { + LayerPtr frame = outFrameLine.frames[i]; + size_t numSequences = frame->getOutput().getNumSequences(); + for (size_t j = 0; j < numSequences; ++j) { + int pos = starts[commonSeqInfo_[j].seqId] + i; + int subSeqStart = subStarts[pos]; + int subSeqEnd = subStarts[pos + 1]; + for (int k = subSeqStart; k < subSeqEnd; ++k) { + allIds.push_back(k); + } + } + info.idIndex.push_back(allIds.size()); + } + + ICpuGpuVector::resizeOrCreate( + subSequenceStartPositions, subStarts.size(), false); + int* cpuSubSequenceStartPositions = + subSequenceStartPositions->getMutableData(false); + std::copy(subStarts.begin(), subStarts.end(), cpuSubSequenceStartPositions); + ICpuGpuVector::resizeOrCreate( + sequenceStartPositions, numSequences + 1, false); + int* cpuSequenceStartPositions = + sequenceStartPositions->getMutableData(false); + for (size_t i = 0; i <= numSequences; ++i) { + cpuSequenceStartPositions[i] = subStarts[starts[i]]; + } + copyScattedId(allIds, &info.allIds, allIds.size()); + CHECK_EQ(info.idIndex.size(), static_cast(maxSequenceLength_ + 1)); +} + /* create scattered id infomation for all realLayer of inFrameLines one time. * If hasSubseq, will also create scattered sequenceStartPositions infomation * for all realLayer of inFrameLines one time. 
*/ - void RecurrentGradientMachine::createInFrameInfo(int inlinkId, const Argument& input, PassType passType) { - bool hasSubseq = input.hasSubseq(); - // numSequences: # samples(sequences) in a batch - size_t numSequences = input.getNumSequences(); + if (!input.hasSeq()) { + createInFrameInfo_nonseq(inlinkId, input, passType); + } else if (!input.hasSubseq()) { + createInFrameInfo_seq(inlinkId, input, passType); + } else { + createInFrameInfo_subseq(inlinkId, input, passType); + } +} + +void RecurrentGradientMachine::createInFrameInfo_nonseq(int inlinkId, + const Argument& input, + PassType passType) { std::vector allIds; auto& seqInfo = seqInfos_[inlinkId]; - - numSeqs_.clear(); Info* inlinkInfo = &info_[inlinkId]; inlinkInfo->idIndex.clear(); - inlinkInfo->idIndex.push_back(0); // first idIndex = 0 + for (size_t i = 0; i < seqInfo.size(); ++i) { + allIds.push_back(seqInfo[i].seqId); + } + // copy and check scatterId + copyScattedId(allIds, &inlinkInfo->allIds, input.getBatchSize()); +} +void RecurrentGradientMachine::createInFrameInfo_seq(int inlinkId, + const Argument& input, + PassType passType) { + std::vector allIds; + auto& seqInfo = seqInfos_[inlinkId]; + Info* inlinkInfo = &info_[inlinkId]; + inlinkInfo->idIndex.resize(1, 0); // first idIndex = 0 + + for (int i = 0; i < maxSequenceLength_; ++i) { + for (int j = 0; j < numSeqs_[i]; ++j) { + int seqLength = seqInfo[j].topLevelLength; + int seqStart = seqInfo[j].seqStart; + allIds.push_back(reversed_ ? (seqStart + seqLength - 1 - i) + : (seqStart + i)); + } + inlinkInfo->idIndex.push_back(allIds.size()); + } + + // copy and check scatterId + copyScattedId(allIds, &inlinkInfo->allIds, input.getBatchSize()); + CHECK_EQ(inlinkInfo->idIndex.size(), + static_cast(maxSequenceLength_ + 1)); +} +void RecurrentGradientMachine::createInFrameInfo_subseq(int inlinkId, + const Argument& input, + PassType passType) { + std::vector allIds; + + auto& seqInfo = seqInfos_[inlinkId]; + + Info* inlinkInfo = &info_[inlinkId]; + inlinkInfo->idIndex.resize(1, 0); // first idIndex = 0 std::vector sequenceStartPositions; const int* subSequenceStartPositions = nullptr; - if (hasSubseq) { // for sequenceScatterAgentLayer - subSequenceStartPositions = input.subSequenceStartPositions->getData(false); - inlinkInfo->seqStartPosIndex.clear(); - inlinkInfo->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0 - } - // maxSequenceLength_: max topLevelLength in allsamples + subSequenceStartPositions = input.subSequenceStartPositions->getData(false); + inlinkInfo->seqStartPosIndex.clear(); + inlinkInfo->seqStartPosIndex.push_back(0); // first seqStartPosIndex = 0 for (int i = 0; i < maxSequenceLength_; ++i) { - if (hasSubseq) { - sequenceStartPositions.push_back(0); // first element = 0 - } - int numSeqs = 0; - for (size_t j = 0; j < numSequences; ++j) { - int seqLength = seqInfo[j].topLevelLength; - if (i >= seqLength) { - break; - } - ++numSeqs; - if (hasSubseq) { - int subSeqStart = subSequenceStartPositions[seqInfo[j].subSeqStart + i]; - int subSeqEnd = - subSequenceStartPositions[seqInfo[j].subSeqStart + i + 1]; - for (int k = subSeqStart; k < subSeqEnd; ++k) { - allIds.push_back(k); - } - sequenceStartPositions.push_back(sequenceStartPositions.back() + - subSeqEnd - subSeqStart); - } else { - int seqStart = seqInfo[j].seqStart; - allIds.push_back(reversed_ ? 
(seqStart + seqLength - 1 - i) - : (seqStart + i)); + sequenceStartPositions.push_back(0); // first element = 0 + for (int j = 0; j < numSeqs_[i]; ++j) { + int subSeqStart = subSequenceStartPositions[seqInfo[j].subSeqStart + i]; + int subSeqEnd = subSequenceStartPositions[seqInfo[j].subSeqStart + i + 1]; + for (int k = subSeqStart; k < subSeqEnd; ++k) { + allIds.push_back(k); } + sequenceStartPositions.push_back(sequenceStartPositions.back() + + subSeqEnd - subSeqStart); } inlinkInfo->idIndex.push_back(allIds.size()); - numSeqs_.push_back(numSeqs); - if (hasSubseq) { - inlinkInfo->seqStartPosIndex.push_back(sequenceStartPositions.size()); - } - } - if (hasSubseq) { - // inFrameLine create sequenceStartPositions one time - CHECK_EQ( - sequenceStartPositions.size(), - static_cast(maxSequenceLength_ + input.getNumSubSequences())); - CHECK_EQ(inlinkInfo->seqStartPosIndex.size(), - static_cast(maxSequenceLength_ + 1)); - createSeqPos(sequenceStartPositions, &inlinkInfo->sequenceStartPositions); + inlinkInfo->seqStartPosIndex.push_back(sequenceStartPositions.size()); } + // inFrameLine create sequenceStartPositions one time + CHECK_EQ( + sequenceStartPositions.size(), + static_cast(maxSequenceLength_ + input.getNumSubSequences())); + CHECK_EQ(inlinkInfo->seqStartPosIndex.size(), + static_cast(maxSequenceLength_ + 1)); + createSeqPos(sequenceStartPositions, &inlinkInfo->sequenceStartPositions); // copy and check scatterId copyScattedId(allIds, &inlinkInfo->allIds, input.getBatchSize()); @@ -717,11 +867,11 @@ void RecurrentGradientMachine::createMemoryFrameInfo( const Argument& input = (*memoryFrameLine).rootLayer->getOutput(); size_t numSequences = input.getNumSequences(); std::vector allIds; - bool seqFlag = (*memoryFrameLine).is_sequence; + bool seqFlag = input.hasSeq(); + CHECK(!input.hasSubseq()) + << "Subsequence boot layer for memory is not supported"; if (seqFlag) { // for sequenceScatterAgentLayer - CHECK(input.sequenceStartPositions) - << "boot layer must be a sequence when is_sequence = true"; std::vector sequenceStartPositions; sequenceStartPositions.push_back(0); // first element = 0 const int* starts = input.sequenceStartPositions->getData(false); @@ -804,8 +954,7 @@ size_t RecurrentGradientMachine::getGenBatchSize() { for (auto& memoryFrameLine : memoryFrameLines_) { if (!memoryFrameLine.rootLayer) continue; Argument& bootArg = memoryFrameLine.rootLayer->getOutput(); - size_t batchSize = memoryFrameLine.is_sequence ? 
bootArg.getNumSequences() - : bootArg.getBatchSize(); + size_t batchSize = bootArg.getNumSequences(); if (numSequences) { CHECK_EQ(numSequences, batchSize); } else { @@ -845,12 +994,7 @@ void RecurrentGradientMachine::generateSequence() { if (memoryFrameLine.rootAgent) { auto scatterAgent = dynamic_cast(memoryFrameLine.rootAgent.get()); - bool seqFlag = memoryFrameLine.is_sequence; - scatterAgent->setRealLayer(memoryFrameLine.rootLayer, ids, seqFlag); - if (seqFlag) { - CHECK(memoryFrameLine.rootLayer->getOutput().sequenceStartPositions) - << "boot layer must be a sequence when is_sequence = true"; - } + scatterAgent->setRealLayer(memoryFrameLine.rootLayer, ids); } NeuralNetwork::connect( memoryFrameLine.agents[0], memoryFrameLine.bootLayer, ids.size()); @@ -858,6 +1002,7 @@ void RecurrentGradientMachine::generateSequence() { // boot layer forward AsyncGpuBlock asyncGpuBlock; + for (auto& memoryFrameLine : memoryFrameLines_) { memoryFrameLine.bootLayer->forward(PASS_TEST); } @@ -930,8 +1075,7 @@ void RecurrentGradientMachine::oneWaySearch(size_t batchSize) { auto scatterAgent = dynamic_cast( memoryFrameLine.scatterAgents[machineCur].get()); scatterAgent->setRealLayer(memoryFrameLine.frames[machinePrev], - scatterIds, - memoryFrameLine.is_sequence); + scatterIds); scatterAgent->forward(PASS_TEST); NeuralNetwork::connect(memoryFrameLine.agents[machineCur], memoryFrameLine.scatterAgents[machineCur]); @@ -1003,8 +1147,7 @@ void RecurrentGradientMachine::connectPrevFrame(int stepId, auto scatterAgent = dynamic_cast( memoryFrameLine.scatterAgents[machineCur].get()); scatterAgent->setRealLayer(memoryFrameLine.frames[machinePrev], - isOutIds ? topIds_ : machineIds_, - memoryFrameLine.is_sequence); + isOutIds ? topIds_ : machineIds_); scatterAgent->forward(PASS_TEST); NeuralNetwork::connect(memoryFrameLine.agents[machineCur], memoryFrameLine.scatterAgents[machineCur]); diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h index c2bc52709ab42bbe21dcc3951f23f2e0b5e6793d..8d94d7e2df216c4657d759c16dd6b1f2848996e0 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.h +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.h @@ -284,6 +284,16 @@ public: } protected: + std::vector commonSeqInfo_; + ICpuGpuVectorPtr sequenceStartPositions_; + void calcSequenceStartPositions(); + void checkInputConsistency(int inlinkId, + const std::vector& seqInfo); + void reorganizeInput(PassType passType); + void reorganizeOutput(PassType passType); + void connectFrames(PassType passType); + void calcNumSequencesAtEachStep(); + void resizeOrCreateFrames(int numFrames); void resizeBootFrame(int numSequences); @@ -295,8 +305,7 @@ protected: std::string linkName; LayerPtr inLayer; std::vector agents; // Scatter Agents to reform batch input - bool hasSubseq; - Argument outArg; // scatter output argument + Argument outArg; // scatter output argument }; std::vector inFrameLines_; @@ -318,7 +327,6 @@ protected: std::vector agents; std::vector scatterAgents; // scatter agent used by beam search Argument outArg; // scatter output argument - bool is_sequence; // Different memoryFrameLine have different element as follows IVectorPtr allIds; // scattered id of realLayer ICpuGpuVectorPtr @@ -330,22 +338,27 @@ protected: // and all outFrameLines(outlinks) share the info with one inFrameLine, // which is assigned by targetInfoInlinkId_. 
struct Info { - IVectorPtr allIds; // scattered id of realLayer - std::vector idIndex; // index of allIds + // The original positions in the original batch + IVectorPtr allIds; // scattered id of realLayer [batchSize] + + // index of allIds for each step [maxSequenceLength_] + // idIndex[i] is the total length of the first i sequences + std::vector idIndex; + ICpuGpuVectorPtr sequenceStartPositions; // scattered sequenceStartPositions std::vector seqStartPosIndex; // index of sequenceStartPositions }; - std::vector info_; + std::vector info_; // for input // numSeqs_[i] is the number sequences which is longer than i (for sequence // data) or has more than i subsequences (for subsequence data) + // Equivalently, numSeqs_[i] is the number of sequences at step i; std::vector numSeqs_; std::vector> seqInfos_; - // the id of inlink which share info with outlinks - int targetInfoInlinkId_; + void checkOutputConsistency(OutFrameLine& outFrameLine); /* create scattered id infomation for all realLayer of inFrameLines one time. * If hasSubseq, will also create scattered sequenceStartPositions infomation @@ -354,6 +367,28 @@ protected: void createInFrameInfo(int inlinks_id, const Argument& input, PassType passType); + void createInFrameInfo_nonseq(int inlinks_id, + const Argument& input, + PassType passType); + void createInFrameInfo_seq(int inlinks_id, + const Argument& input, + PassType passType); + void createInFrameInfo_subseq(int inlinks_id, + const Argument& input, + PassType passType); + + void createOutFrameInfo(OutFrameLine& outFrameLine, + Info& info, + ICpuGpuVectorPtr& sequenceStartPositions, + ICpuGpuVectorPtr& subSequenceStartPositions); + void createOutFrameInfo_seq(OutFrameLine& outFrameLine, + Info& info, + ICpuGpuVectorPtr& sequenceStartPositions, + ICpuGpuVectorPtr& subSequenceStartPositions); + void createOutFrameInfo_subseq(OutFrameLine& outFrameLine, + Info& info, + ICpuGpuVectorPtr& sequenceStartPositions, + ICpuGpuVectorPtr& subSequenceStartPositions); void createMemoryFrameInfo(MemoryFrameLine* memoryFrameLine, PassType passType); @@ -386,9 +421,7 @@ protected: NeuralNetwork* rootNetwork_; bool reversed_; - // if hasSubseq: max number of sentences(subseq)in batchsize samples - // else: max number of tokens in batchsize samples(sentences) - int maxSequenceLength_; + int maxSequenceLength_; // Max top-level length bool useGpu_; bool stopBeamSearch_; diff --git a/paddle/gserver/layers/AgentLayer.cpp b/paddle/gserver/layers/AgentLayer.cpp index 7b1b99b135e35e5fe41dbb3d053a96e3e31e5cf1..31463823b3fc04cc24068d95887a9d3ed25a6168 100644 --- a/paddle/gserver/layers/AgentLayer.cpp +++ b/paddle/gserver/layers/AgentLayer.cpp @@ -36,14 +36,23 @@ void AgentLayer::forward(PassType passType) { Layer::forward(passType); Argument& realOutput = realLayer_->getOutput(); - int realHeight = realOutput.getBatchSize(); - CHECK_LE(numSamples_, realHeight); + int realNumSequences = realOutput.getNumSequences(); + CHECK_LE(numSamples_, realNumSequences); // get Arguments from real layers - if (numSamples_ > 0 && numSamples_ < realHeight) { - if (realOutput.ids) { - output_.ids = - IVector::create(realOutput.ids->getData(), numSamples_, useGpu_); + if (numSamples_ > 0 && numSamples_ < realNumSequences) { + if (realOutput.hasSeq()) { + int numRows = + realOutput.sequenceStartPositions->getData(false)[numSamples_]; + output_.subArgFrom(realOutput, + /* offset */ 0, + numRows, + getSize(), + useGpu_, + /* trans */ false, + /* seqFlag */ true, + /* seqStart */ 0, + /* seqSize */ numSamples_ + 1); } else { 
output_.subArgFrom( realOutput, /* offset */ 0, numSamples_, getSize(), useGpu_); @@ -53,34 +62,6 @@ void AgentLayer::forward(PassType passType) { } } -void SequenceAgentLayer::forward(PassType passType) { - Layer::forward(passType); - - Argument& realOutput = realLayer_->getOutput(); - int realNumSequences = realOutput.getNumSequences(); - CHECK_LE(numSamples_, realNumSequences); - - // get Arguments from real layers - if (numSamples_ > 0 && numSamples_ < realNumSequences) { - int numRows = - realOutput.sequenceStartPositions->getData(false)[numSamples_]; - CHECK(!realOutput.ids) << "Not supported"; - output_.subArgFrom(realOutput, - /* offset */ 0, - numRows, - getSize(), - useGpu_, - /* trans */ false, - /* seqFlag */ true, - /* seqStart */ 0, - /* seqSize */ numSamples_ + 1); - } else { - output_ = realOutput; - } -} - -REGISTER_LAYER(sequence_agent, SequenceAgentLayer); - bool GatherAgentLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { CHECK_EQ(config_.inputs_size(), 0); @@ -91,18 +72,26 @@ bool GatherAgentLayer::init(const LayerMap& layerMap, return true; } -void GatherAgentLayer::copyIdAndSequenceInfo(const Argument& input, - const IVectorPtr& ids, - const std::vector& idIndex) { - output_.sequenceStartPositions = input.sequenceStartPositions; - output_.subSequenceStartPositions = input.subSequenceStartPositions; - realLayers_.clear(); +void GatherAgentLayer::copyIdAndSequenceInfo( + ICpuGpuVectorPtr sequenceStartPositions, + ICpuGpuVectorPtr subSequenceStartPositions, + const IVectorPtr& ids, + const std::vector& idIndex) { + output_.sequenceStartPositions = sequenceStartPositions; + output_.subSequenceStartPositions = subSequenceStartPositions; allIds_ = ids; idIndex_ = idIndex; } void GatherAgentLayer::forward(PassType passType) { Layer::forward(passType); + forwardIds(passType); + forwardValue(passType); +} + +void GatherAgentLayer::forwardValue(PassType passType) { + MatrixPtr valueReal = realLayers_[0]->getOutputValue(); + if (!valueReal) return; int height = allIds_->getSize(); int width = this->getSize(); @@ -147,7 +136,9 @@ void ScatterAgentLayer::forward(PassType passType) { CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); int width = this->getSize(); - if (realOutArg_.value || realOutArg_.ids) { + if (realOutArg_.hasSeq()) { + forwardSequence(passType); + } else if (realOutArg_.value || realOutArg_.ids) { output_.subArgFrom( realOutArg_, /* offset */ idIndex_, idSize_, width, useGpu_); } else { // used in generation @@ -174,7 +165,7 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) { if (realGrad) { // for agent in inFrameLines and memoryFrameLines, // only first scatterAgentLayer should do addToRows in backward - if (idIndex_ == 0) { + if (handleBackward_) { outputGrad->addToRows(*realGrad, *ids_); } } @@ -183,12 +174,14 @@ void ScatterAgentLayer::backward(const UpdateCallback& callback) { REGISTER_LAYER(gather_agent, GatherAgentLayer); REGISTER_LAYER(scatter_agent, ScatterAgentLayer); -void SequenceGatherAgentLayer::forward(PassType passType) { - Layer::forward(passType); +void GatherAgentLayer::forwardIds(PassType passType) { int height = 0; - int* starts = output_.subSequenceStartPositions->getMutableData(false); IVectorPtr idReal = realLayers_[0]->getOutputLabel(); - if (idReal) { + + if (!idReal) return; + + if (output_.subSequenceStartPositions) { + int* starts = output_.subSequenceStartPositions->getMutableData(false); // Gather generator.idsVec // if is beam search generation result. Get first result. 
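+    // (a trailing -1 in the generated ids marks beam-search output)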
if (idReal->getData()[idReal->getSize() - 1] == -1) { @@ -212,13 +205,11 @@ void SequenceGatherAgentLayer::forward(PassType passType) { ->copyFrom(*realLayers_[i]->getOutputLabel()); } } else { - // Gather output.value, same as GatherAgentLayer - CHECK(output_.subSequenceStartPositions); - GatherAgentLayer::forward(passType); + LOG(FATAL) << "Not implemented"; } } -void SequenceScatterAgentLayer::forward(PassType passType) { +void ScatterAgentLayer::forwardSequence(PassType passType) { Layer::forward(passType); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); @@ -241,6 +232,7 @@ void SequenceScatterAgentLayer::forward(PassType passType) { /* seqStart */ seqStartPosIndex_, /* seqSize */ numSequences_); } else { + // Putting the generation logic here is really an ugly hack! // used in generation int height = 0; size_t numSequences = ids_->getSize(); @@ -284,7 +276,4 @@ void SequenceScatterAgentLayer::forward(PassType passType) { } } -REGISTER_LAYER(sequence_gather_agent, SequenceGatherAgentLayer); -REGISTER_LAYER(sequence_scatter_agent, SequenceScatterAgentLayer); - } // namespace paddle diff --git a/paddle/gserver/layers/AgentLayer.h b/paddle/gserver/layers/AgentLayer.h index b6dac7ae6fec2d61c60c9548d466233efe9febd5..461b84b17e556b53e0734bff8e37a0d529a3290e 100644 --- a/paddle/gserver/layers/AgentLayer.h +++ b/paddle/gserver/layers/AgentLayer.h @@ -49,18 +49,6 @@ public: void backward(const UpdateCallback& callback = nullptr) override {} }; -/** - * like AgentLayer, but use first *numSamples* sequences - */ -class SequenceAgentLayer : public AgentLayer { -public: - explicit SequenceAgentLayer(const LayerConfig& config) : AgentLayer(config) {} - ~SequenceAgentLayer() {} - - void forward(PassType passType) override; - void backward(const UpdateCallback& callback = nullptr) override {} -}; - /** * Like AgentLayer, but it can gather many real layers. Each real * layer give a few rows of a sequence, after gather all real layers, @@ -83,7 +71,10 @@ public: const ParameterMap& parameterMap) override; // call before addRealLayer - void copyIdAndSequenceInfo(const Argument& input, + void clearRealLayers() { realLayers_.clear(); } + + void copyIdAndSequenceInfo(ICpuGpuVectorPtr sequenceStartPositions, + ICpuGpuVectorPtr subSequenceStartPositions, const IVectorPtr& allIds, const std::vector& idIndex); @@ -92,24 +83,8 @@ public: void forward(PassType passType) override; void backward(const UpdateCallback& callback) override; -}; - -/** - * Like GatherAgentLayer, but select a few sequence in real layer. - * *ids* in addRealLayer() are the ids of selected sequence. - * It's used to reorder sequence output. - */ -class SequenceGatherAgentLayer : public GatherAgentLayer { -public: - explicit SequenceGatherAgentLayer(const LayerConfig& config) - : GatherAgentLayer(config) {} - virtual ~SequenceGatherAgentLayer() {} - - void forward(PassType passType); - void backward(const UpdateCallback& callback) { - // same as GatherAgentLayer - GatherAgentLayer::backward(callback); - } + void forwardValue(PassType passType); + void forwardIds(PassType passType); }; /** @@ -129,6 +104,11 @@ protected: int idSize_; int seqStartPosIndex_; int numSequences_; // number of sequences in this scatterAgentLayer + bool handleBackward_; + + // use to store expanded cpuStartPositions or subSequenceStartPositions + // of real layer. 
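+  // (used by ScatterAgentLayer::forwardSequence, which takes over the work
+  // of the removed SequenceScatterAgentLayer)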
+ ICpuGpuVectorPtr inputStartPos_; public: explicit ScatterAgentLayer(const LayerConfig& config) : Layer(config) {} @@ -147,19 +127,15 @@ public: * false(default) in ScatterAgentLayer, and * true in SequenceScatterAgentLayer. */ - void setRealLayer(LayerPtr layer, - const std::vector& ids, - bool copyId = false) { + void setRealLayer(LayerPtr layer, const std::vector& ids) { realLayer_ = layer; IVector::resizeOrCreate(ids_, ids.size(), useGpu_); ids_->copyFrom(ids.data(), ids.size()); - if (copyId) { - if (useGpu_) { - IVector::resizeOrCreate(cpuIds_, ids.size(), false); - cpuIds_->copyFrom(ids.data(), ids.size()); - } else { - cpuIds_ = ids_; - } + if (useGpu_) { + IVector::resizeOrCreate(cpuIds_, ids.size(), false); + cpuIds_->copyFrom(ids.data(), ids.size()); + } else { + cpuIds_ = ids_; } } @@ -169,12 +145,14 @@ public: const Argument& outArg, const IVectorPtr& ids, int idIndex, - int idSize) { + int idSize, + bool handleBackward) { realLayer_ = layer; realOutArg_ = outArg; ids_ = ids; idIndex_ = idIndex; idSize_ = idSize; + handleBackward_ = handleBackward; } void setSequenceStartPositions(const ICpuGpuVectorPtr& sequenceStartPositions, @@ -187,28 +165,8 @@ public: void forward(PassType passType) override; void backward(const UpdateCallback& callback) override; -}; -/** - * Like ScatterAgentLayer, but select a few sequence in real layer. - * *ids* in setRealLayer() or setRealLayerAndOutput() are the ids of - * selected sequence. It's used to reorder sequence input. - */ -class SequenceScatterAgentLayer : public ScatterAgentLayer { -protected: - // use to store expanded cpuStartPositions or subSequenceStartPositions - // of real layer. - ICpuGpuVectorPtr inputStartPos_; - -public: - explicit SequenceScatterAgentLayer(const LayerConfig& config) - : ScatterAgentLayer(config) {} - virtual ~SequenceScatterAgentLayer() {} - - void forward(PassType passType); - void backward(const UpdateCallback& callback) { - ScatterAgentLayer::backward(callback); - } + void forwardSequence(PassType passType); }; } // namespace paddle diff --git a/paddle/gserver/layers/FeatureMapExpandLayer.cpp b/paddle/gserver/layers/FeatureMapExpandLayer.cpp index b3850f543af74abbddaac5bb0a32851f2d3297d0..8a2ae6b49fcc13ed22eca2a33c8296827812bff9 100644 --- a/paddle/gserver/layers/FeatureMapExpandLayer.cpp +++ b/paddle/gserver/layers/FeatureMapExpandLayer.cpp @@ -40,6 +40,7 @@ namespace paddle { class FeatureMapExpandLayer : public Layer { private: int numFilters_; + bool asRowVector_; public: explicit FeatureMapExpandLayer(const LayerConfig& config) : Layer(config) {} @@ -62,6 +63,7 @@ bool FeatureMapExpandLayer::init(const LayerMap& layerMap, CHECK_EQ(inputLayers_.size(), 1UL); numFilters_ = config_.num_filters(); + asRowVector_ = config_.user_arg() != "as_col_vec"; return true; } @@ -76,16 +78,30 @@ void FeatureMapExpandLayer::forward(PassType passType) { { AsyncGpuBlock asyncGpuBlock; - for (size_t i = 0; i < batchSize; i++) { - MatrixPtr outVTmp = - Matrix::create(outputV->getData() + i * imgSize * numFilters_, - numFilters_, - imgSize, - false, - useGpu_); - MatrixPtr inVTmp = Matrix::create( - inputV->getData() + i * imgSize, 1, imgSize, false, useGpu_); - outVTmp->addRowVector(*inVTmp); + if (asRowVector_) { + for (size_t i = 0; i < batchSize; i++) { + MatrixPtr outVTmp = + Matrix::create(outputV->getData() + i * imgSize * numFilters_, + numFilters_, + imgSize, + false, + useGpu_); + MatrixPtr inVTmp = Matrix::create( + inputV->getData() + i * imgSize, 1, imgSize, false, useGpu_); + outVTmp->addRowVector(*inVTmp); 
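+        // addRowVector adds the 1 x imgSize input row to every one of the
+        // numFilters_ output rows, replicating the feature map row-wise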
+ } + } else { + for (size_t i = 0; i < batchSize; i++) { + MatrixPtr outVTmp = + Matrix::create(outputV->getData() + i * imgSize * numFilters_, + imgSize, + numFilters_, + false, + useGpu_); + MatrixPtr inVTmp = Matrix::create( + inputV->getData() + i * imgSize, imgSize, 1, false, useGpu_); + outVTmp->addColVector(*inVTmp); + } } } /* activation */ { @@ -102,24 +118,38 @@ void FeatureMapExpandLayer::backward(const UpdateCallback& callback) { MatrixPtr outGrad = getOutputGrad(); size_t batchSize = getInput(0).getBatchSize(); int imgSize = inGrad->getWidth(); + /* Do activation */ { + REGISTER_TIMER_INFO("BpAvtTimer", getName().c_str()); + backwardActivation(); + } { AsyncGpuBlock asyncGpuBlock; - for (size_t i = 0; i < batchSize; i++) { - MatrixPtr outGradTmp = - Matrix::create(outGrad->getData() + i * imgSize * numFilters_, - numFilters_, - imgSize, - false, - useGpu_); - MatrixPtr inGradTmp = Matrix::create( - inGrad->getData() + i * imgSize, 1, imgSize, false, useGpu_); - inGradTmp->collectBias(*outGradTmp, 1); + if (asRowVector_) { + for (size_t i = 0; i < batchSize; i++) { + MatrixPtr outGradTmp = + Matrix::create(outGrad->getData() + i * imgSize * numFilters_, + numFilters_, + imgSize, + false, + useGpu_); + MatrixPtr inGradTmp = Matrix::create( + inGrad->getData() + i * imgSize, 1, imgSize, false, useGpu_); + inGradTmp->collectBias(*outGradTmp, 1); + } + } else { + for (size_t i = 0; i < batchSize; i++) { + MatrixPtr outGradTmp = + Matrix::create(outGrad->getData() + i * imgSize * numFilters_, + imgSize, + numFilters_, + false, + useGpu_); + MatrixPtr inGradTmp = Matrix::create( + inGrad->getData() + i * imgSize, imgSize, 1, false, useGpu_); + inGradTmp->sumRows(*outGradTmp, 1, 1); + } } } - /* Do derivation */ { - REGISTER_TIMER_INFO("BpAvtTimer", getName().c_str()); - backwardActivation(); - } } } // namespace paddle. 
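Both forward branches above write the same values and differ only in the layout of the per-sample output buffer. A minimal sketch of the arithmetic on raw arrays (illustrative only; expandFeatureMap and its signature are not part of this patch):

    #include <cstddef>

    // Expand one sample of imgSize values into numFilters * imgSize values.
    void expandFeatureMap(const float* in, float* out, std::size_t imgSize,
                          std::size_t numFilters, bool asRowVector) {
      if (asRowVector) {
        // output viewed as (numFilters x imgSize): every row is a copy of `in`
        for (std::size_t f = 0; f < numFilters; ++f)
          for (std::size_t j = 0; j < imgSize; ++j)
            out[f * imgSize + j] = in[j];
      } else {
        // output viewed as (imgSize x numFilters): every column is a copy of `in`
        for (std::size_t j = 0; j < imgSize; ++j)
          for (std::size_t f = 0; f < numFilters; ++f)
            out[j * numFilters + f] = in[j];
      }
    }

The backward pass reduces the gradient over the replicated dimension, which is what the diff computes with collectBias (collapsing the numFilters_ rows of the row layout) and sumRows (collapsing the numFilters_ columns of the column layout).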
diff --git a/paddle/gserver/layers/PrintLayer.cpp b/paddle/gserver/layers/PrintLayer.cpp
index de198af111be4200dd1b240f6de9464e3f43b06d..a97fa6bf78fce27a4e0cf329bf3309ba4a439965 100644
--- a/paddle/gserver/layers/PrintLayer.cpp
+++ b/paddle/gserver/layers/PrintLayer.cpp
@@ -22,10 +22,33 @@ public:
   void forward(PassType passType) override {
     Layer::forward(passType);
+    std::vector<std::string> vals;
     for (size_t i = 0; i != inputLayers_.size(); ++i) {
-      getInput(i).printValueString(LOG(INFO),
-                                   "layer=" + inputLayers_[i]->getName() + " ");
+      std::ostringstream s;
+      getInput(i).printValueString(s, "");
+      vals.push_back(s.str());
     }
+    size_t pos = 0;
+    size_t i = 0;
+    std::ostringstream s;
+    const std::string& format = config_.user_arg();
+    while (true) {
+      size_t pos1 = format.find("%s", pos);
+      if (pos1 == std::string::npos) break;
+      if (i >= vals.size()) {
+        break;
+      }
+      s << format.substr(pos, pos1 - pos) << vals[i];
+      pos = pos1 + 2;
+      ++i;
+    }
+    if (i != inputLayers_.size()) {
+      LOG(ERROR) << "Number of values in the format (" << format
+                 << ") does not match the number of inputs ("
+                 << inputLayers_.size() << ") at " << getName();
+    }
+    s << format.substr(pos);
+    LOG(INFO) << s.str();
   }

   void backward(const UpdateCallback& callback) override {}
diff --git a/paddle/gserver/layers/SequencePoolLayer.cpp b/paddle/gserver/layers/SequencePoolLayer.cpp
index 235d9a9b0f0653df5c0b671092df9e195f08fc48..4179a9e7e0cb58fcb49bff712e62b9f3fea373bd 100644
--- a/paddle/gserver/layers/SequencePoolLayer.cpp
+++ b/paddle/gserver/layers/SequencePoolLayer.cpp
@@ -46,6 +46,9 @@ void SequencePoolLayer::forward(PassType passType) {
   Layer::forward(passType);

   const Argument& input = getInput(0);
+  CHECK(input.hasSeq() || input.hasSubseq())
+      << "Input should be a sequence or subsequence for layer " << getName();
+
   newBatchSize_ = type_ ?
input.getNumSubSequences() : input.getNumSequences(); size_t dim = getSize(); // check diff --git a/paddle/gserver/tests/rnn_data_provider.py b/paddle/gserver/tests/rnn_data_provider.py index 3afd45c72f4dd071ddca569caac8716fe102299b..913365a5a4037d14fcba1e1546508ba89668e0d6 100644 --- a/paddle/gserver/tests/rnn_data_provider.py +++ b/paddle/gserver/tests/rnn_data_provider.py @@ -95,3 +95,22 @@ def process_unequalength_seq(settings, file_name): words1 = reduce(lambda x, y: x + y, d[0]) words2 = reduce(lambda x, y: x + y, d[1]) yield words1, words2, d[2] + + +########################################################### +data3 = [ + [[[1, 2], [4, 5, 2]], [1, 2], 0], + [[[0, 2], [2, 5], [0, 1, 2]], [2, 3, 0], 1], +] + + +# Used for sequence_nest_mixed_inputs.conf +@provider( + input_types=[ + integer_value_sub_sequence(10), integer_value_sequence(10), + integer_value(2) + ], + should_shuffle=False) +def process_mixed(settings, file_name): + for d in data3: + yield d diff --git a/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf b/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf index ad14a2c927c89c9b480af5ad565c37e8b2e54469..afdacfffd7aecfe2f4762f04a987126381bcea34 100644 --- a/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf +++ b/paddle/gserver/tests/sequence_nest_rnn_multi_input.conf @@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import * define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list', test_list=None, module='rnn_data_provider', - obj='process_subseq2') + obj='process_subseq') settings(batch_size=2, learning_rate=0.01) @@ -57,7 +57,7 @@ def outer_step(wid, x): last = last_seq(input=inner_rnn_output, name="outer_rnn_state") # "return last" should also work. But currently RecurrentGradientMachine - # does not handle it, and will report error: In hierachical RNN, all out + # does not handle it, and will report error: In hierachical RNN, all out # links should be from sequences now. return inner_rnn_output diff --git a/paddle/gserver/tests/sequence_rnn_matched_inputs.py b/paddle/gserver/tests/sequence_rnn_matched_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..e2635b4400b13517bac716a5a0affeb16c218b09 --- /dev/null +++ b/paddle/gserver/tests/sequence_rnn_matched_inputs.py @@ -0,0 +1,85 @@ +# edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
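+# NOTE: this config and sequence_rnn_mixed_inputs.py (below) describe the
+# same network over the process_mixed data, where every sample carries a
+# sub-sequence input, a plain sequence input and a per-sample label. Here
+# the sequence and non-sequence inputs are expanded with expand_layer to
+# match the sub-sequence before entering recurrent_group; the "mixed"
+# variant feeds them in unexpanded. test_RecurrentGradientMachine.cpp
+# checks that the two configurations agree to within 1e-6.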
+ +from paddle.trainer_config_helpers import * + +######################## data source ################################ +define_py_data_sources2( + train_list='gserver/tests/Sequence/dummy.list', + test_list=None, + module='rnn_data_provider', + obj='process_mixed') + +settings(batch_size=2, learning_rate=0.01) +######################## network configure ################################ +dict_dim = 10 +word_dim = 2 +hidden_dim = 2 +label_dim = 2 + +data1 = data_layer(name="word1", size=dict_dim) +data2 = data_layer(name="word2", size=dict_dim) +label = data_layer(name="label", size=label_dim) + +encoding = embedding_layer(input=data2, size=word_dim) + +subseq = embedding_layer(input=data1, size=word_dim) +seq = embedding_layer(input=data2, size=word_dim) +nonseq = embedding_layer(input=label, size=word_dim) + + +# This hierarchical RNN is designed to be equivalent to the simple RNN in +# sequence_rnn_multi_unequalength_inputs.conf +def outer_step(subseq, seq, nonseq, encoding): + outer_mem = memory(name="outer_rnn_state", size=hidden_dim) + + def inner_step(subseq, seq, nonseq): + inner_mem = memory( + name="inner_rnn_state", size=hidden_dim, boot_layer=outer_mem) + + out = fc_layer( + input=[subseq, seq, nonseq, inner_mem], + size=hidden_dim, + act=TanhActivation(), + bias_attr=True, + name='inner_rnn_state') + return out + + decoder = recurrent_group( + step=inner_step, name='inner', input=[subseq, seq, nonseq]) + last = last_seq(name="outer_rnn_state", input=decoder) + context = simple_attention( + encoded_sequence=encoding, encoded_proj=encoding, decoder_state=last) + return context + + +out = recurrent_group( + name="outer", + step=outer_step, + input=[ + subseq, expand_layer( + seq, expand_as=subseq, + expand_level=ExpandLevel.FROM_SEQUENCE), expand_layer( + nonseq, + expand_as=subseq, + expand_level=ExpandLevel.FROM_NO_SEQUENCE), + StaticInput(encoding) + ]) + +rep = last_seq(input=out) +prob = fc_layer( + size=label_dim, input=rep, act=SoftmaxActivation(), bias_attr=True) + +outputs(classification_cost(input=prob, label=label)) diff --git a/paddle/gserver/tests/sequence_rnn_mixed_inputs.py b/paddle/gserver/tests/sequence_rnn_mixed_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..84a66e294495c01e03dc83b38a531e482bed1292 --- /dev/null +++ b/paddle/gserver/tests/sequence_rnn_mixed_inputs.py @@ -0,0 +1,79 @@ +# edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
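+# NOTE: counterpart of sequence_rnn_matched_inputs.py above; the
+# sub-sequence, sequence and non-sequence inputs enter the outer
+# recurrent_group directly, exercising the mixed-input reorganization
+# added to RecurrentGradientMachine in this change.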
+ +from paddle.trainer_config_helpers import * + +######################## data source ################################ +define_py_data_sources2( + train_list='gserver/tests/Sequence/dummy.list', + test_list=None, + module='rnn_data_provider', + obj='process_mixed') + +settings(batch_size=2, learning_rate=0.01) +######################## network configure ################################ +dict_dim = 10 +word_dim = 2 +hidden_dim = 2 +label_dim = 2 + +data1 = data_layer(name="word1", size=dict_dim) +data2 = data_layer(name="word2", size=dict_dim) +label = data_layer(name="label", size=label_dim) + +encoding = embedding_layer(input=data2, size=word_dim) + + +# This hierarchical RNN is designed to be equivalent to the simple RNN in +# sequence_rnn_multi_unequalength_inputs.conf +def outer_step(subseq, seq, nonseq, encoding): + outer_mem = memory(name="outer_rnn_state", size=hidden_dim) + + def inner_step(data1, data2, label): + inner_mem = memory( + name="inner_rnn_state", size=hidden_dim, boot_layer=outer_mem) + + subseq = embedding_layer(input=data1, size=word_dim) + seq = embedding_layer(input=data2, size=word_dim) + nonseq = embedding_layer(input=label, size=word_dim) + + print_layer(input=[data1, seq, label, inner_mem]) + out = fc_layer( + input=[subseq, seq, nonseq, inner_mem], + size=hidden_dim, + act=TanhActivation(), + bias_attr=True, + name='inner_rnn_state') + return out + + decoder = recurrent_group( + step=inner_step, name='inner', + input=[subseq, StaticInput(seq), nonseq]) + last = last_seq(name="outer_rnn_state", input=decoder) + context = simple_attention( + encoded_sequence=encoding, encoded_proj=encoding, decoder_state=last) + return context + + +out = recurrent_group( + name="outer", + step=outer_step, + input=[data1, data2, StaticInput(label), StaticInput(encoding)]) + +rep = last_seq(input=out) +prob = fc_layer( + size=label_dim, input=rep, act=SoftmaxActivation(), bias_attr=True) + +outputs(classification_cost(input=prob, label=label)) diff --git a/paddle/gserver/tests/sequence_rnn_multi_input.conf b/paddle/gserver/tests/sequence_rnn_multi_input.conf index 40d031741573251aa94d2a0f355470c53c51de7e..9fae974f3079c49ad03d6ba34e30190f325414e8 100644 --- a/paddle/gserver/tests/sequence_rnn_multi_input.conf +++ b/paddle/gserver/tests/sequence_rnn_multi_input.conf @@ -19,7 +19,7 @@ from paddle.trainer_config_helpers import * define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list', test_list=None, module='rnn_data_provider', - obj='process_seq2') + obj='process_seq') settings(batch_size=2, learning_rate=0.01) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 6adffcf53b7966bd6f3d02970e5f07cc9802f469..297756025bcad79d49ec321414ed2e91f1c0758a 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -1598,12 +1598,15 @@ TEST(Layer, FeatureMapExpandLayer) { /* paraSize= */ 0}); config.layerConfig.add_inputs(); for (auto useGpu : {false, true}) { - testLayerGrad(config, - "featmap_expand", - /*batch_size*/ 100, - /* trans= */ false, - useGpu, - /* useWeight */ true); + for (auto asRowVec : {false, true}) { + config.layerConfig.set_user_arg(asRowVec ? 
"as_row_vec" : "as_col_vec"); + testLayerGrad(config, + "featmap_expand", + /*batch_size*/ 100, + /* trans= */ false, + useGpu, + /* useWeight */ true); + } } } diff --git a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp index 4a846397e6cf3100f948af46874b0739e32bf4a5..6b19eb0ce520a625ac68582d5c1e11c168127dc7 100644 --- a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp +++ b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp @@ -155,6 +155,15 @@ TEST(RecurrentGradientMachine, rnn_multi_unequalength_input) { } } +TEST(RecurrentGradientMachine, rnn_mixed_input) { + for (bool useGpu : {false, true}) { + test("gserver/tests/sequence_rnn_mixed_inputs.py", + "gserver/tests/sequence_rnn_matched_inputs.py", + 1e-6, + useGpu); + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/paddle/majel/.gitignore b/paddle/majel/.gitignore deleted file mode 100644 index 1f5acdebb56971202b63d2485e2ac5042786f13c..0000000000000000000000000000000000000000 --- a/paddle/majel/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -build -third-party \ No newline at end of file diff --git a/paddle/majel/detail/cuda_assert.h b/paddle/majel/detail/cuda_assert.h deleted file mode 100644 index 9490d0ae3eff01bdb4403de710b7bfd878e87f03..0000000000000000000000000000000000000000 --- a/paddle/majel/detail/cuda_assert.h +++ /dev/null @@ -1,32 +0,0 @@ -#pragma once - -#define STRINGIFY(x) #x -#define TOSTRING(x) STRINGIFY(x) - -#if defined(__APPLE__) && defined(__CUDA_ARCH__) && !defined(NDEBUG) -#include -#define MAJEL_ASSERT(e) \ - do { \ - if (!(e)) { \ - printf( \ - "%s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, TOSTRING(e)); \ - asm("trap;"); \ - } \ - } while (0) - -#define MAJEL_ASSERT_MSG(e, m) \ - do { \ - if (!(e)) { \ - printf("%s:%d Assertion `%s` failed (%s).\n", \ - __FILE__, \ - __LINE__, \ - TOSTRING(e), \ - m); \ - asm("trap;"); \ - } \ - } while (0) -#else -#include -#define MAJEL_ASSERT(e) assert(e) -#define MAJEL_ASSERT_MSG(e, m) assert((e) && (m)) -#endif diff --git a/paddle/majel/dim_test.cu b/paddle/majel/dim_test.cu deleted file mode 100644 index a7d81e595bea7fa6326ea350e2702e1ef8f5caa4..0000000000000000000000000000000000000000 --- a/paddle/majel/dim_test.cu +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include - -#include "paddle/majel/dim.h" -#include "gtest/gtest.h" - -__global__ void test(majel::Dim<2>* o) { - o[0] = majel::make_dim(5, 6); -} - -__global__ void dyn_idx_gpu(int* o) { - auto d = majel::make_dim(5, 6); - o[0] = d[1]; -} - -TEST(Dim, Equality) { - // construct a Dim on the CPU - auto a = majel::make_dim(3, 4); - EXPECT_EQ(majel::get<0>(a), 3); - EXPECT_EQ(majel::get<1>(a), 4); - - // construct a Dim on the GPU - thrust::device_vector> t(2); - test<<<1,1>>>(thrust::raw_pointer_cast(t.data())); - a = t[0]; - EXPECT_EQ(majel::get<0>(a), 5); - EXPECT_EQ(majel::get<1>(a), 6); - - // linearization - auto b = majel::make_dim(7, 8); - EXPECT_EQ(majel::linearize(a, b), 83); - - // product - EXPECT_EQ(majel::product(a), 30); - - // mutate a Dim - majel::get<1>(b) = 10; - EXPECT_EQ(majel::get<0>(b), 7); - EXPECT_EQ(majel::get<1>(b), 10); - - // dynamic access - majel::get(b, 0) = 8; - b[1] = 11; - EXPECT_EQ(majel::get<0>(b), 8); - EXPECT_EQ(majel::get<1>(b), 11); - EXPECT_EQ(majel::get(b, 0), 8); - EXPECT_EQ(b[1], 11); - - // dynamic access on GPU - thrust::device_vector r(1); - dyn_idx_gpu<<<1,1>>>(thrust::raw_pointer_cast(r.data())); - int res = r[0]; - EXPECT_EQ(res, 6); - - // 
ex_prefix_mul - majel::Dim<3> c = majel::ex_prefix_mul(majel::Dim<3>(3, 4, 5)); - EXPECT_EQ(majel::get<0>(c), 1); - EXPECT_EQ(majel::get<1>(c), 3); - EXPECT_EQ(majel::get<2>(c), 12); - - // contiguous_strides - c = majel::contiguous_strides(majel::Dim<3>(10, 1, 10)); - EXPECT_EQ(majel::get<0>(c), 1); - EXPECT_EQ(majel::get<1>(c), 0); - EXPECT_EQ(majel::get<2>(c), 10); - c = majel::contiguous_strides(majel::Dim<3>(10, 10, 1)); - EXPECT_EQ(majel::get<0>(c), 1); - EXPECT_EQ(majel::get<1>(c), 10); - EXPECT_EQ(majel::get<2>(c), 0); - c = majel::contiguous_strides(majel::Dim<3>(1, 10, 10)); - EXPECT_EQ(majel::get<0>(c), 0); - EXPECT_EQ(majel::get<1>(c), 1); - EXPECT_EQ(majel::get<2>(c), 10); - c = majel::contiguous_strides(majel::Dim<3>(2, 3, 4)); - EXPECT_EQ(majel::get<0>(c), 1); - EXPECT_EQ(majel::get<1>(c), 2); - EXPECT_EQ(majel::get<2>(c), 6); - - // generate from an index - auto size = majel::make_dim(4, 5, 2); - c = majel::Dim<3>(14, size); - EXPECT_EQ(majel::get<0>(c), 2); - EXPECT_EQ(majel::get<1>(c), 3); - EXPECT_EQ(majel::get<2>(c), 0); - c = majel::Dim<3>(25, size); - EXPECT_EQ(majel::get<0>(c), 1); - EXPECT_EQ(majel::get<1>(c), 1); - EXPECT_EQ(majel::get<2>(c), 1); -} - -TEST(Dim, Bool) { - auto a = majel::make_dim(3, 4); - auto b = majel::make_dim(5, 6); - auto c = majel::make_dim(3, 4); - - // in_bounds check - EXPECT_TRUE(majel::contained(a, b)); - EXPECT_FALSE(majel::contained(b, a)); - - // comparison - EXPECT_TRUE(a == a); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a == c); - - // contiguous check - int x = 4, y = 5, z = 2; - majel::Dim<3> sizef(x, y, z); - majel::Dim<3> stridea(1, x, x*y); - majel::Dim<3> strideb(2, 2*x, 2*x*y); - majel::Dim<3> stridec(1, x, 2*x*y); - EXPECT_TRUE(majel::contiguous(sizef, stridea)); - EXPECT_FALSE(majel::contiguous(sizef, strideb)); - EXPECT_FALSE(majel::contiguous(sizef, stridec)); -} - -TEST(Dim, Print) { - { - std::stringstream ss; - auto a = majel::make_dim(2, 3); - ss << a; - EXPECT_EQ(ss.str(), "2, 3"); - } - { - std::stringstream ss; - ss << majel::make_dim(8); - EXPECT_EQ(ss.str(), "8"); - } -} diff --git a/paddle/majel/place.cc b/paddle/majel/place.cc deleted file mode 100644 index ca50b37843e0ba047f8f8b8d24a3d3c131587382..0000000000000000000000000000000000000000 --- a/paddle/majel/place.cc +++ /dev/null @@ -1,49 +0,0 @@ -#include "paddle/majel/place.h" - -namespace majel { - -namespace detail { - -class PlacePrinter : public boost::static_visitor<> { -private: - std::ostream& os_; - -public: - PlacePrinter(std::ostream& os) : os_(os) {} - - void operator()(const CpuPlace&) { os_ << "CpuPlace"; } - - void operator()(const GpuPlace& p) { os_ << "GpuPlace(" << p.device << ")"; } -}; - -} // namespace detail - -static Place the_default_place; - -void set_place(const Place& place) { the_default_place = place; } - -const Place& get_place() { return the_default_place; } - -const GpuPlace default_gpu() { return GpuPlace(0); } - -const CpuPlace default_cpu() { return CpuPlace(); } - -bool is_gpu_place(const Place& p) { - return boost::apply_visitor(IsGpuPlace(), p); -} - -bool is_cpu_place(const Place& p) { - return !boost::apply_visitor(IsGpuPlace(), p); -} - -bool places_are_same_class(const Place& p1, const Place& p2) { - return is_gpu_place(p1) == is_gpu_place(p2); -} - -std::ostream& operator<<(std::ostream& os, const majel::Place& p) { - majel::detail::PlacePrinter printer(os); - boost::apply_visitor(printer, p); - return os; -} - -} // namespace majel diff --git a/paddle/majel/place.h b/paddle/majel/place.h deleted file mode 100644 index 
ad3dc3fe0b80ac5dc10a59910c580d7912469cd4..0000000000000000000000000000000000000000 --- a/paddle/majel/place.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once -#include -#include - -namespace majel { - -struct CpuPlace { - CpuPlace() {} // WORKAROUND: for some reason, omitting this constructor - // causes errors with boost 1.59 and OSX - // needed for variant equality comparison - inline bool operator==(const CpuPlace&) const { return true; } - - inline bool operator!=(const CpuPlace&) const { return false; } -}; - -struct GpuPlace { - GpuPlace(int d) : device(d) {} - - // needed for variant equality comparison - inline bool operator==(const GpuPlace& o) const { return device == o.device; } - - inline bool operator!=(const GpuPlace& o) const { return !(*this == o); } - - GpuPlace() : GpuPlace(0) {} - int device; -}; - -class IsGpuPlace : public boost::static_visitor { -public: - bool operator()(const CpuPlace&) const { return false; } - - bool operator()(const GpuPlace& gpu) const { return true; } -}; - -typedef boost::variant Place; - -void set_place(const Place&); - -const Place& get_place(); - -const GpuPlace default_gpu(); -const CpuPlace default_cpu(); - -bool is_gpu_place(const Place&); -bool is_cpu_place(const Place&); -bool places_are_same_class(const Place&, const Place&); - -std::ostream& operator<<(std::ostream&, const majel::Place&); - -} // namespace majel diff --git a/paddle/majel/place_test.cc b/paddle/majel/place_test.cc deleted file mode 100644 index 6a099ae6b6e4f63a6ce845ab17eaab6e12c2c0b0..0000000000000000000000000000000000000000 --- a/paddle/majel/place_test.cc +++ /dev/null @@ -1,40 +0,0 @@ -#include "paddle/majel/place.h" -#include -#include "gtest/gtest.h" - -TEST(Place, Equality) { - majel::CpuPlace cpu; - majel::GpuPlace g0(0), g1(1), gg0(0); - - EXPECT_EQ(cpu, cpu); - EXPECT_EQ(g0, g0); - EXPECT_EQ(g1, g1); - EXPECT_EQ(g0, gg0); - - EXPECT_NE(g0, g1); - - EXPECT_TRUE(majel::places_are_same_class(g0, gg0)); - EXPECT_FALSE(majel::places_are_same_class(g0, cpu)); -} - -TEST(Place, Default) { - EXPECT_TRUE(majel::is_gpu_place(majel::get_place())); - EXPECT_TRUE(majel::is_gpu_place(majel::default_gpu())); - EXPECT_TRUE(majel::is_cpu_place(majel::default_cpu())); - - majel::set_place(majel::CpuPlace()); - EXPECT_TRUE(majel::is_cpu_place(majel::get_place())); -} - -TEST(Place, Print) { - { - std::stringstream ss; - ss << majel::GpuPlace(1); - EXPECT_EQ("GpuPlace(1)", ss.str()); - } - { - std::stringstream ss; - ss << majel::CpuPlace(); - EXPECT_EQ("CpuPlace", ss.str()); - } -} diff --git a/paddle/math/Vector.cpp b/paddle/math/Vector.cpp index eaa1cdce305c2f9d7a517e9e8c8606dc1f70780b..c519ca500afb1dbfdff6e8d211786f4e18ccf1fd 100644 --- a/paddle/math/Vector.cpp +++ b/paddle/math/Vector.cpp @@ -908,12 +908,13 @@ const T* CpuGpuVectorT::getData(bool useGpu) const { // Operation will change data and need to reset sync_ & syncFlag_. #define MUTABLE_VECTOR_OP(OP, useGpu, args...) 
\ do { \ - setSync(useGpu); \ if (useGpu) { \ copyToGpu(); \ + setSync(useGpu); \ return gpuVectorT_->OP(args); \ } else { \ copyToCpu(); \ + setSync(useGpu); \ return cpuVectorT_->OP(args); \ } \ } while (0) @@ -1030,7 +1031,7 @@ void CpuGpuVectorT::copyToCpu() { case DATA_AT_GPU: CHECK(gpuVectorT_); this->resizeOrCreate(gpuVectorT_->getSize(), false); - cpuVectorT_->copyFrom(*gpuVectorT_, HPPL_STREAM_DEFAULT); + cpuVectorT_->copyFrom(*gpuVectorT_); setSync(SYNCED); break; case DATA_AT_CPU: @@ -1049,7 +1050,7 @@ void CpuGpuVectorT::copyToGpu() { case DATA_AT_CPU: CHECK(cpuVectorT_); this->resizeOrCreate(cpuVectorT_->getSize(), true); - gpuVectorT_->copyFrom(*cpuVectorT_, HPPL_STREAM_DEFAULT); + gpuVectorT_->copyFrom(*cpuVectorT_); setSync(SYNCED); break; case DATA_AT_GPU: diff --git a/paddle/optimizer/CMakeLists.txt b/paddle/optimizer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4536f62ec7c2c3423d91e309dee993d4212160fe --- /dev/null +++ b/paddle/optimizer/CMakeLists.txt @@ -0,0 +1,18 @@ +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + +set(OPITMIZER_SRCS + adadelta_optimizer.cc + adagrad_optimizer.cc + adam_optimizer.cc + optimizer.cc + parameter_optimizer.cc + sgd_optimizer.cc + ) + +add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) +add_dependencies(paddle_optimizer gen_proto_cpp) + +if(WITH_TESTING) + add_simple_unittest(serialization_test) + add_simple_unittest(parameter_optimizer_test) +endif() diff --git a/paddle/optimizer/adadelta_optimizer.cc b/paddle/optimizer/adadelta_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..465ad5e0d2089121a0f11ab916afe0420cbcfab7 --- /dev/null +++ b/paddle/optimizer/adadelta_optimizer.cc @@ -0,0 +1,55 @@ +#include "adadelta_optimizer.h" +#include +#include + +namespace paddle { +namespace optimizer { + +void AdadeltaOptimizer::Update(const Tensor* gradient) { + num_sample_passed_ += 1; + double learning_rate = lr_policy_->LearningRate(num_sample_passed_); + Tensor& param = *parameter_; + const Tensor& grad = *gradient; + Tensor& accum_g = *accum_gradient_; + Tensor& accum_d = *accum_delta_; + Tensor& update_d = *update_delta_; + for (size_t i = 0; i < param.size(); ++i) { + accum_g[i] = rho_ * accum_g[i] + (1.0 - rho_) * grad[i] * grad[i]; + + update_d[i] = std::sqrt(accum_d[i] + epsilon_) / + std::sqrt(accum_g[i] + epsilon_) * grad[i]; + + accum_d[i] = rho_ * accum_d[i] + (1.0 - rho_) * update_d[i] * update_d[i]; + + param[i] -= learning_rate * update_d[i] + learning_rate * decay_ * param[i]; + } +} + +const char* AdadeltaOptimizer::SerializeState(int* state_len) { + AdadeltaOptimizerState state; + // TODO(zhihong) : add lr_policy serialization + state.set_num_sample_passed(num_sample_passed_); + + TensorToProto(*parameter_, state.mutable_parameter()); + TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); + TensorToProto(*accum_delta_, state.mutable_accum_delta()); + TensorToProto(*update_delta_, state.mutable_update_delta()); + auto str = state.SerializeAsString(); + *state_len = str.size(); + return str.c_str(); +} + +void AdadeltaOptimizer::DeserializeState(const std::string& str) { + AdadeltaOptimizerState state; + state.ParseFromString(str); + // TODO(zhihong) : add lr_policy DeserializeState + num_sample_passed_ = state.num_sample_passed(); + + ProtoToTensor(state.parameter(), parameter_); + ProtoToTensor(state.accum_gradient(), accum_gradient_); + ProtoToTensor(state.accum_delta(), accum_delta_); + ProtoToTensor(state.update_delta(), update_delta_); +} 
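+// NOTE: SerializeState returns str.c_str() of a function-local std::string,
+// so the buffer is gone when the call returns; callers must copy the state
+// out immediately. The same pattern appears in the other optimizers below.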
+ +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/adadelta_optimizer.h b/paddle/optimizer/adadelta_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..1d5eab097f57d049855dd171a1aa6f74c48ae0e7 --- /dev/null +++ b/paddle/optimizer/adadelta_optimizer.h @@ -0,0 +1,39 @@ +#pragma once + +#include "parameter_optimizer.h" + +namespace paddle { +namespace optimizer { + +class AdadeltaOptimizer : public ParameterOptimizer { +public: + AdadeltaOptimizer( + Tensor *parameter, LrPolicy *lr, double rho, double epsilon, double decay) + : ParameterOptimizer(parameter, lr), + accum_gradient_(new Tensor(parameter->size())), + accum_delta_(new Tensor(parameter->size())), + update_delta_(new Tensor(parameter->size())), + rho_(rho), + epsilon_(epsilon), + decay_(decay) {} + + ~AdadeltaOptimizer() { + if (accum_gradient_) delete accum_gradient_; + if (accum_delta_) delete accum_delta_; + if (update_delta_) delete update_delta_; + } + void Update(const Tensor *gradient); + const char *SerializeState(int *state_len); + void DeserializeState(const std::string &state); + +private: + Tensor *accum_gradient_; + Tensor *accum_delta_; + Tensor *update_delta_; + double rho_; + double epsilon_; + double decay_; +}; + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/adagrad_optimizer.cc b/paddle/optimizer/adagrad_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..bdaa7877d2bc58c17c51b977852d4b6fec511ed2 --- /dev/null +++ b/paddle/optimizer/adagrad_optimizer.cc @@ -0,0 +1,42 @@ +#include + +#include "adagrad_optimizer.h" + +namespace paddle { +namespace optimizer { + +void AdagradOptimizer::Update(const Tensor* gradient) { + num_sample_passed_ += 1; + double learning_rate = lr_policy_->LearningRate(num_sample_passed_); + Tensor& param = *parameter_; + Tensor& accum_g = *accum_gradient_; + const Tensor& grad = *gradient; + for (size_t i = 0; i < param.size(); ++i) { + accum_g[i] += grad[i] * grad[i]; + param[i] += learning_rate * grad[i] / std::sqrt(accum_g[i] + epsilon_) + + learning_rate * decay_ * param[i]; + } +} +const char* AdagradOptimizer::SerializeState(int* state_len) { + AdagradOptimizerState state; + // TODO(zhihong) : add lr_policy serialization + state.set_num_sample_passed(num_sample_passed_); + + TensorToProto(*parameter_, state.mutable_parameter()); + TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); + auto str = state.SerializeAsString(); + *state_len = str.size(); + return str.c_str(); +} + +void AdagradOptimizer::DeserializeState(const std::string& str) { + AdagradOptimizerState state; + state.ParseFromString(str); + // TODO(zhihong) : add lr_policy DeserializeState + num_sample_passed_ = state.num_sample_passed(); + ProtoToTensor(state.parameter(), parameter_); + ProtoToTensor(state.accum_gradient(), accum_gradient_); +} + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/adagrad_optimizer.h b/paddle/optimizer/adagrad_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..15d0a965ad0c6967e73b14b465168fa66eb8fba3 --- /dev/null +++ b/paddle/optimizer/adagrad_optimizer.h @@ -0,0 +1,32 @@ +#pragma once + +#include "parameter_optimizer.h" + +namespace paddle { +namespace optimizer { + +class AdagradOptimizer : public ParameterOptimizer { +public: + AdagradOptimizer(Tensor *parameter, + LrPolicy *lr, + double epsilon, + double decay) + : ParameterOptimizer(parameter, lr), + accum_gradient_(new Tensor(parameter->size())), 
+ epsilon_(epsilon), + decay_(decay) {} + ~AdagradOptimizer() { + if (accum_gradient_) delete accum_gradient_; + } + void Update(const Tensor *gradient); + const char *SerializeState(int *state_len); + void DeserializeState(const std::string &state); + +private: + Tensor *accum_gradient_; + double epsilon_; + double decay_; +}; + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/adam_optimizer.cc b/paddle/optimizer/adam_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..ceab7397d87349c64ca9e5d11990cb38068421be --- /dev/null +++ b/paddle/optimizer/adam_optimizer.cc @@ -0,0 +1,48 @@ +#include "adam_optimizer.h" +#include + +namespace paddle { +namespace optimizer { + +void AdamOptimizer::Update(const Tensor *gradient) { + num_sample_passed_ += 1; + double learning_rate = lr_policy_->LearningRate(num_sample_passed_); + double coef1 = 1.0 - std::pow(beta_1_, num_sample_passed_); + double coef2 = 1.0 - std::pow(beta_2_, num_sample_passed_); + learning_rate *= std::sqrt(coef2) / coef1; + Tensor ¶m = *parameter_; + const Tensor &grad = *gradient; + Tensor &m = *momentums_; + Tensor &v = *velocitys_; + for (size_t i = 0; i < param.size(); ++i) { + m[i] = beta_1_ * m[i] + (1.0 - beta_1_) * grad[i]; + v[i] = beta_2_ * v[i] + (1.0 - beta_2_) * grad[i] * grad[i]; + param[i] -= + learning_rate * (m[i] / std::sqrt(v[i] + epsilon_) + decay_ * param[i]); + } +} + +const char *AdamOptimizer::SerializeState(int *state_len) { + AdamOptimizerState state; + // TODO(zhihong) : add lr_policy serialization + state.set_num_sample_passed(num_sample_passed_); + TensorToProto(*parameter_, state.mutable_parameter()); + TensorToProto(*momentums_, state.mutable_momentums()); + TensorToProto(*velocitys_, state.mutable_velocitys()); + auto str = state.SerializeAsString(); + *state_len = str.size(); + return str.c_str(); +} + +void AdamOptimizer::DeserializeState(const std::string &str) { + AdamOptimizerState state; + state.ParseFromString(str); + // TODO(zhihong) : add lr_policy DeserializeState + num_sample_passed_ = state.num_sample_passed(); + + ProtoToTensor(state.parameter(), parameter_); + ProtoToTensor(state.momentums(), momentums_); + ProtoToTensor(state.velocitys(), velocitys_); +} +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/adam_optimizer.h b/paddle/optimizer/adam_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..0ea4c8bb8470504282b4d6c12039791ce896e401 --- /dev/null +++ b/paddle/optimizer/adam_optimizer.h @@ -0,0 +1,41 @@ +#pragma once + +#include "parameter_optimizer.h" + +namespace paddle { +namespace optimizer { + +class AdamOptimizer : public ParameterOptimizer { +public: + AdamOptimizer(Tensor *parameter, + LrPolicy *lr, + double beta_1, + double beta_2, + double epsilon, + double decay) + : ParameterOptimizer(parameter, lr), + momentums_(new Tensor(parameter->size())), + velocitys_(new Tensor(parameter->size())), + beta_1_(beta_1), + beta_2_(beta_2), + epsilon_(epsilon), + decay_(decay) {} + ~AdamOptimizer() { + if (momentums_) delete momentums_; + if (velocitys_) delete velocitys_; + } + void Update(const Tensor *gradient); + const char *SerializeState(int *state_len); + void DeserializeState(const std::string &state); + +private: + Tensor *momentums_; + Tensor *velocitys_; + double beta_1_; + double beta_2_; + double epsilon_; + double decay_; +}; + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/lr_policy.h b/paddle/optimizer/lr_policy.h new file 
mode 100644 index 0000000000000000000000000000000000000000..d8e33ad37ab4c019a36f63f34babe65cf8c8fb16 --- /dev/null +++ b/paddle/optimizer/lr_policy.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include "OptimizerConfig.pb.h" + +namespace paddle { +namespace optimizer { + +class LrPolicy { +public: + virtual ~LrPolicy() {} + virtual double LearningRate(const uint64_t num_sample_passed) = 0; + virtual const char *SerializeState(int *state_len) = 0; + virtual void DeserializeState(const std::string &state) = 0; +}; + +// constant learning rate policy +class ConstLr final : public LrPolicy { +public: + ConstLr(double lr) : learning_rate(lr){}; + double LearningRate(const uint64_t num_sample_passed) { + return learning_rate; + } + const char *SerializeState(int *state_len) { return nullptr; } + void DeserializeState(const std::string &state) {} + +private: + double learning_rate; +}; + +class LinearLr final : public LrPolicy { +public: + LinearLr(double lr, double lr_decay_a, double lr_decay_b) + : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {} + double LearningRate(const uint64_t num_sample_passed) { + return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b); + } + const char *SerializeState(int *state_len) { + // TODO(zhihong) : add lr_policy serialization + return nullptr; + } + void DeserializeState(const std::string &state) { + // TODO(zhihong) : add lr_policy serialization + } + +private: + double learning_rate; + double lr_decay_a; + double lr_decay_b; +}; + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/optimizer.cc b/paddle/optimizer/optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..54662dc37891d3211950453b210db4b475837df4 --- /dev/null +++ b/paddle/optimizer/optimizer.cc @@ -0,0 +1,83 @@ +#include "optimizer.h" +#include + +#include "parameter_optimizer.h" + +using namespace paddle; +using namespace paddle::optimizer; + +template +struct EnumToType {}; + +template +struct TypeToEnum {}; + +#define MATCH_ENUM_TYPE(TYPE, ENUM) \ + template <> \ + struct TypeToEnum { \ + static paddle_element_type v() { return ENUM; }; \ + static constexpr TYPE value = ENUM; \ + }; \ + template <> \ + struct EnumToType { \ + typedef TYPE Type; \ + } + +MATCH_ENUM_TYPE(int32_t, PADDLE_ELEMENT_TYPE_INT32); +MATCH_ENUM_TYPE(uint32_t, PADDLE_ELEMENT_TYPE_UINT32); +MATCH_ENUM_TYPE(int64_t, PADDLE_ELEMENT_TYPE_INT64); +MATCH_ENUM_TYPE(uint64_t, PADDLE_ELEMENT_TYPE_UINT64); +// TODO(zhihong): only implement below type, need to fix +MATCH_ENUM_TYPE(float, PADDLE_ELEMENT_TYPE_FLOAT32); +MATCH_ENUM_TYPE(double, PADDLE_ELEMENT_TYPE_FLOAT64); + +struct paddle_optimizer { + paddle::optimizer::ParameterOptimizer* impl; +}; + +paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto, + const int config_proto_len, + const paddle_element_type data_type, + void* param_buffer, + int num_bytes, + const char* state, + const int state_len) { + paddle_optimizer* optimizer = new paddle_optimizer; + std::string config(config_proto, config_proto + config_proto_len); + Tensor* parameter = + new Tensor(reinterpret_cast(param_buffer), num_bytes); + optimizer->impl = ParameterOptimizer::Create(config, parameter); + if (state != nullptr) { + std::string s(state, state + state_len); + optimizer->impl->DeserializeState(s); + } + return optimizer; +} + +int paddle_release_optimizer(paddle_optimizer* o) { + if (o != nullptr) delete o->impl; + return PADDLE_SUCCESS; +} + +int paddle_update_parameter(paddle_optimizer* 
o,
+                            const paddle_element_type data_type,
+                            const void* grad_buffer,
+                            int num_bytes) {
+  // TODO(zhihong): data_type is not used yet; runtime data type dispatch needs adding
+  auto grad_type = reinterpret_cast<const float*>(grad_buffer);
+  Tensor* gradient = new Tensor(const_cast<float*>(grad_type), num_bytes);
+  o->impl->Update(gradient);
+  return PADDLE_SUCCESS;
+}
+
+int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer) {
+  int param_size = 0;
+  *param_buffer = (void*)o->impl->get_weight(&param_size);
+  return param_size;
+}
+
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state) {
+  int state_len = 0;
+  *state = o->impl->SerializeState(&state_len);
+  return state_len;
+}
diff --git a/paddle/optimizer/optimizer.h b/paddle/optimizer/optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..aabf7a458dd30092ed1e522c4d88c6cfe63fcce1
--- /dev/null
+++ b/paddle/optimizer/optimizer.h
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <stdbool.h>
+#include <stdint.h>
+
+/**
+ * @brief optimizer library, independent of the other modules;
+ * it is used in:
+ * Case A, the gradient optimized locally on the trainer.
+ *
+ * Case B, the gradient optimized on the parameter server.
+ */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef enum {
+  PADDLE_ELEMENT_TYPE_INT32 = 0,
+  PADDLE_ELEMENT_TYPE_UINT32 = 1,
+  PADDLE_ELEMENT_TYPE_INT64 = 2,
+  PADDLE_ELEMENT_TYPE_UINT64 = 3,
+  PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
+  PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
+} paddle_element_type;
+
+/**
+ * @brief execution status codes
+ */
+const int32_t PADDLE_SUCCESS = 0;
+const int32_t PADDLE_ERROR = -1;
+
+typedef struct paddle_optimizer paddle_optimizer;
+/**
+ * this group of interfaces is called in this order:
+ * 1. create optimizer with config
+ * 2. set weights
+ * 3. update_parameter
+ * 4. get_weights
+ * 5. release optimizer
+ */
+
+/**
+ * @brief create an optimizer from a protobuf config
+ * @param config_proto, optimizer protobuf; see OptimizerConfig.proto for details
+ * @return the optimizer instance
+ */
+paddle_optimizer* paddle_create_optimizer(const unsigned char* config_proto,
+                                          const int config_proto_len,
+                                          const paddle_element_type data_type,
+                                          void* param_buffer,
+                                          int num_bytes,
+                                          const char* state,
+                                          const int state_len);
+
+/**
+ * @brief release the optimizer
+ * @param optimizer
+ * @return execution status
+ */
+int paddle_release_optimizer(paddle_optimizer* o);
+
+/**
+ * @brief apply a gradient to the parameter held by the optimizer
+ * @param data_type of gradient and parameter
+ * @param gradient, calculated by the caller.
+ *        TODO(zhihong): just pass the loss to reduce communication overhead.
+ *        See the Project Adam (OSDI'14) paper for details
+ * @param num_bytes, gradient size
+ * @return execution status
+ */
+int paddle_update_parameter(paddle_optimizer* o,
+                            const paddle_element_type data_type,
+                            const void* gradient,
+                            int num_bytes);
+
+/**
+ * @brief get the parameter buffer from the optimizer
+ * @param param_buffer, initialized parameter buffer
+ * @return content length
+ */
+int paddle_optimizer_get_weights(paddle_optimizer* o, void** param_buffer);
+
+/**
+ * @brief save the optimizer training state
+ * @param state, receives the buffer produced by SerializeState
+ * @return state buffer length
+ */
+int paddle_optimizer_get_state(paddle_optimizer* o, const char** state);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/paddle/optimizer/parameter_optimizer.cc b/paddle/optimizer/parameter_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f6218037925649e741d17f49af972ce2d50f8d3d
--- /dev/null
+++ b/paddle/optimizer/parameter_optimizer.cc
@@ -0,0 +1,74 @@
+#include <glog/logging.h>
+#include "adadelta_optimizer.h"
+#include "adagrad_optimizer.h"
+#include "adam_optimizer.h"
+#include "lr_policy.h"
+#include "sgd_optimizer.h"
+
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+ParameterOptimizer *ParameterOptimizer::Create(const std::string &config_proto,
+                                               Tensor *parameter) {
+  paddle::OptimizerConfig config;
+  CHECK(config.ParseFromString(config_proto))
+      << "failed to parse optimizer config";
+  auto select_lr_policy = [=](const OptimizerConfig &config) -> LrPolicy * {
+    if (config.lr_policy() == OptimizerConfig::Const)
+      return new ConstLr(config.const_lr().learning_rate());
+    if (config.lr_policy() == OptimizerConfig::Linear)
+      return new LinearLr(config.linear_lr().learning_rate(),
+                          config.linear_lr().lr_decay_a(),
+                          config.linear_lr().lr_decay_b());
+    // default
+    LOG(WARNING) << "no LrPolicy selected; falling back to ConstLr";
+    return new ConstLr(0.1);
+  };
+
+  LrPolicy *lr = select_lr_policy(config);
+  auto select_optimizer = [=](
+      Tensor *parameter,
+      const OptimizerConfig &config) -> ParameterOptimizer * {
+    if (config.optimizer() == OptimizerConfig::SGD) {
+      return new SGDOptimizer(parameter,
+                              lr,
+                              config.sgd().momentum(),
+                              config.sgd().decay(),
+                              config.sgd().nesterov());
+    }
+    if (config.optimizer() == OptimizerConfig::Adadelta) {
+      return new AdadeltaOptimizer(parameter,
+                                   lr,
+                                   config.adadelta().rho(),
+                                   config.adadelta().epsilon(),
+                                   config.adadelta().decay());
+    }
+    if (config.optimizer() == OptimizerConfig::Adagrad) {
+      return new AdagradOptimizer(
+          parameter, lr, config.adagrad().epsilon(), config.adagrad().decay());
+    }
+    if (config.optimizer() == OptimizerConfig::Adam) {
+      return new AdamOptimizer(parameter,
+                               lr,
+                               config.adam().beta_1(),
+                               config.adam().beta_2(),
+                               config.adam().epsilon(),
+                               config.adam().decay());
+    }
+    // default
+    LOG(WARNING)
+        << "no optimizer selected;
use SGDOptimizer in default"; + return new SGDOptimizer(parameter, lr, 0.0, 0.0, false); + }; + return select_optimizer(parameter, config); +} + +float *ParameterOptimizer::get_weight(int *param_size) const { + *param_size = (int)parameter_->size(); + return parameter_->get_buffer(); +} + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/parameter_optimizer.h b/paddle/optimizer/parameter_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..d89c9abb791f947172078d4dce5b1c366852591b --- /dev/null +++ b/paddle/optimizer/parameter_optimizer.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include "OptimizerConfig.pb.h" +#include "lr_policy.h" +#include "serialization.h" +#include "tensor.h" + +namespace paddle { +namespace optimizer { + +class ParameterOptimizer { +public: + /** + * @brief update hook for algorithm need to traverse parameter more than + * once. + */ + ParameterOptimizer(Tensor *parameter, LrPolicy *lr) + : parameter_(parameter), lr_policy_(lr), num_sample_passed_(0) {} + virtual ~ParameterOptimizer() { + delete parameter_; + delete lr_policy_; + } + + static ParameterOptimizer *Create(const std::string &config_proto, + Tensor *parameter); + virtual void Update(const Tensor *gradient) = 0; + virtual float *get_weight(int *param_size) const; + virtual const char *SerializeState(int *state_len) = 0; + virtual void DeserializeState(const std::string &state) = 0; + +protected: + Tensor *parameter_; + // learning rate policy + LrPolicy *lr_policy_; + uint64_t num_sample_passed_; +}; + +} // namespace optimizer +} // namespace paddle diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e6254d9e4dab48279b4a880695959526d30d70c --- /dev/null +++ b/paddle/optimizer/parameter_optimizer_test.cpp @@ -0,0 +1,107 @@ +#include "parameter_optimizer.h" +#include +#include +#include +#include "gtest/gtest.h" +#include "lr_policy.h" + +using namespace paddle; +using namespace paddle::optimizer; + +Tensor* FillTensor(size_t size) { + Tensor* param = new Tensor(size); + Tensor& p = *param; + for (size_t i = 0; i < p.size(); ++i) { + p[i] = (float)rand() / (float)RAND_MAX; + } + return param; +} + +Tensor* FixedTensor(size_t size) { + Tensor* param = new Tensor(size); + Tensor& p = *param; + for (size_t i = 0; i < p.size(); ++i) { + p[i] = i; + } + return param; +} + +class OptimizerTest : public testing::Test { +public: + // init tensor shape + const size_t kSize = 5; + + virtual void SetUp() { + CreateSGD(); + CreateAdam(); + } + virtual void TearDown() {} + + void CreateSGD() { + Tensor* parameter = FixedTensor(kSize); + config_.set_optimizer(OptimizerConfig::SGD); + config_.mutable_sgd()->set_momentum(0.0); + config_.mutable_sgd()->set_decay(0.0); + config_.mutable_sgd()->set_nesterov(false); + config_.set_lr_policy(OptimizerConfig::Const); + config_.mutable_const_lr()->set_learning_rate(0.1); + std::string str = config_.SerializeAsString(); + ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter); + opts_.push_back(opt); + } + + void CreateAdam() { + Tensor* parameter = FixedTensor(kSize); + config_.set_optimizer(OptimizerConfig::Adam); + config_.mutable_adam()->set_beta_1(0.9); + config_.mutable_adam()->set_beta_2(0.1); + config_.mutable_adam()->set_epsilon(1e-3); + config_.mutable_adam()->set_decay(0.0); + config_.set_lr_policy(OptimizerConfig::Const); + 
diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4e6254d9e4dab48279b4a880695959526d30d70c
--- /dev/null
+++ b/paddle/optimizer/parameter_optimizer_test.cpp
@@ -0,0 +1,107 @@
+#include "parameter_optimizer.h"
+#include <cstdlib>
+#include <string>
+#include <vector>
+#include "gtest/gtest.h"
+#include "lr_policy.h"
+
+using namespace paddle;
+using namespace paddle::optimizer;
+
+Tensor* FillTensor(size_t size) {
+  Tensor* param = new Tensor(size);
+  Tensor& p = *param;
+  for (size_t i = 0; i < p.size(); ++i) {
+    p[i] = (float)rand() / (float)RAND_MAX;
+  }
+  return param;
+}
+
+Tensor* FixedTensor(size_t size) {
+  Tensor* param = new Tensor(size);
+  Tensor& p = *param;
+  for (size_t i = 0; i < p.size(); ++i) {
+    p[i] = i;
+  }
+  return param;
+}
+
+class OptimizerTest : public testing::Test {
+public:
+  // init tensor shape
+  const size_t kSize = 5;
+
+  virtual void SetUp() {
+    CreateSGD();
+    CreateAdam();
+  }
+  virtual void TearDown() {}
+
+  void CreateSGD() {
+    Tensor* parameter = FixedTensor(kSize);
+    config_.set_optimizer(OptimizerConfig::SGD);
+    config_.mutable_sgd()->set_momentum(0.0);
+    config_.mutable_sgd()->set_decay(0.0);
+    config_.mutable_sgd()->set_nesterov(false);
+    config_.set_lr_policy(OptimizerConfig::Const);
+    config_.mutable_const_lr()->set_learning_rate(0.1);
+    std::string str = config_.SerializeAsString();
+    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
+    opts_.push_back(opt);
+  }
+
+  void CreateAdam() {
+    Tensor* parameter = FixedTensor(kSize);
+    config_.set_optimizer(OptimizerConfig::Adam);
+    config_.mutable_adam()->set_beta_1(0.9);
+    config_.mutable_adam()->set_beta_2(0.1);
+    config_.mutable_adam()->set_epsilon(1e-3);
+    config_.mutable_adam()->set_decay(0.0);
+    config_.set_lr_policy(OptimizerConfig::Const);
+    config_.mutable_const_lr()->set_learning_rate(0.1);
+    std::string str = config_.SerializeAsString();
+    ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter);
+    opts_.push_back(opt);
+  }
+
+  void TestGetWeight() {
+    Tensor* p = FixedTensor(kSize);
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      int s = 0;
+      float* newp = (float*)opts_[i]->get_weight(&s);
+      for (size_t j = 0; j < kSize; ++j) {
+        EXPECT_EQ(newp[j], (*p)[j]);
+      }
+    }
+  }
+
+  void TestUpdate() {
+    Tensor* g = FixedTensor(kSize);
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      opts_[i]->Update(g);
+    }
+  }
+
+  void TestCheckPoint() {
+    for (size_t i = 0; i < opts_.size(); ++i) {
+      int state_len = 0;
+      std::string state = opts_[i]->SerializeState(&state_len);
+      opts_[i]->DeserializeState(state);
+    }
+  }
+
+private:
+  std::vector<ParameterOptimizer*> opts_;
+  OptimizerConfig config_;
+};
+
+TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); }
+
+TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); }
+
+TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); }
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/optimizer/serialization.h b/paddle/optimizer/serialization.h
new file mode 100644
index 0000000000000000000000000000000000000000..92fbf65cc6b98d7f92841bafe4ab77001ca03b7c
--- /dev/null
+++ b/paddle/optimizer/serialization.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <sstream>
+#include <string>
+#include "OptimizerConfig.pb.h"
+#include "paddle/utils/Logging.h"
+#include "tensor.h"
+
+namespace paddle {
+namespace optimizer {
+
+static void TensorToProto(const Tensor& tensor, TensorProto* proto) {
+  proto->set_data_type(TensorProto::PADDLE_ELEMENT_TYPE_FLOAT32);
+  std::stringstream os;
+  for (size_t i = 0; i < tensor.size(); ++i) {
+    os << tensor[i];
+    proto->add_content(os.str());
+    os.str(std::string());
+  }
+}
+
+static void ProtoToTensor(const TensorProto& proto, Tensor* tensor) {
+  std::stringstream sin;
+  for (auto i = 0; i < proto.content_size(); ++i) {
+    sin << proto.content(i);
+    sin >> (*tensor)[i];
+    sin.str(std::string());
+    sin.clear();
+  }
+}
+
+}  // namespace optimizer
+}  // namespace paddle
diff --git a/paddle/optimizer/serialization_test.cpp b/paddle/optimizer/serialization_test.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d2454140dc243b40ed8348578360b30894213838
--- /dev/null
+++ b/paddle/optimizer/serialization_test.cpp
@@ -0,0 +1,25 @@
+#include "serialization.h"
+#include "gtest/gtest.h"
+
+using namespace paddle;
+using namespace paddle::optimizer;
+
+TEST(TensorToProto, Case1) {
+  Tensor t(3), t1(3);
+  for (size_t i = 0; i < t.size(); ++i) {
+    t[i] = i;
+    t1[i] = 0;
+  }
+
+  TensorProto proto;
+  TensorToProto(t, &proto);
+  ProtoToTensor(proto, &t1);
+  for (size_t i = 0; i < t1.size(); ++i) {
+    EXPECT_EQ(t1[i], t[i]);
+  }
+}
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/paddle/optimizer/sgd_optimizer.cc b/paddle/optimizer/sgd_optimizer.cc
new file mode 100644
index 0000000000000000000000000000000000000000..34e051003fa83f11b1f4a39c46856e0372836a1a
--- /dev/null
+++ b/paddle/optimizer/sgd_optimizer.cc
@@ -0,0 +1,49 @@
+#include "sgd_optimizer.h"
+#include "serialization.h"
+
+namespace paddle {
+namespace optimizer {
+
+void SGDOptimizer::Update(const Tensor *gradient) {
+  num_sample_passed_ += 1;
+  double learning_rate = lr_policy_->LearningRate(num_sample_passed_);
+  float velocity = 0.0;
+  Tensor &param = *parameter_;
+  const Tensor &grad = *gradient;
+  for (size_t i = 0; i < param.size(); ++i) {
+    if (momentum_ == 0.0) {
+      velocity = -learning_rate * grad[i] - learning_rate * decay_ * param[i];
+    } else {
+      // momentums_ is only allocated when momentum_ != 0.0
+      Tensor &m = *momentums_;
+      m[i] = momentum_ * m[i] - learning_rate * grad[i] -
+             learning_rate * decay_ * param[i];
+      velocity = m[i];
+    }
+    if (nesterov_) {
+      param[i] += momentum_ * velocity - learning_rate * grad[i];
+    } else {
+      param[i] += velocity;
+    }
+  }
+}
+
+const char *SGDOptimizer::SerializeState(int *state_len) {
+  SGDOptimizerState state;
+  state.set_num_sample_passed(num_sample_passed_);
+  TensorToProto(*parameter_, state.mutable_parameter());
+  if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
+  // keep the serialized bytes alive in a member so the returned pointer
+  // does not dangle
+  serialized_state_ = state.SerializeAsString();
+  *state_len = serialized_state_.size();
+  return serialized_state_.c_str();
+}
+
+void SGDOptimizer::DeserializeState(const std::string &str) {
+  SGDOptimizerState state;
+  state.ParseFromString(str);
+  num_sample_passed_ = state.num_sample_passed();
+  ProtoToTensor(state.parameter(), parameter_);
+  if (momentum_ != 0.0) ProtoToTensor(state.momentums(), momentums_);
+}
+
+}  // namespace optimizer
+}  // namespace paddle
diff --git a/paddle/optimizer/sgd_optimizer.h b/paddle/optimizer/sgd_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..b74a902e1aa40a7831b36ab826d72372a3588bcf
--- /dev/null
+++ b/paddle/optimizer/sgd_optimizer.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <string>
+#include "parameter_optimizer.h"
+
+namespace paddle {
+namespace optimizer {
+
+class SGDOptimizer : public ParameterOptimizer {
+public:
+  SGDOptimizer(Tensor* parameter, LrPolicy* lr, double m, double d, bool n)
+      : ParameterOptimizer(parameter, lr),
+        momentums_(nullptr),
+        momentum_(m),
+        decay_(d),
+        nesterov_(n) {
+    if (momentum_ != 0.0) {
+      size_t size = parameter->size();
+      // TODO: replace with an alignment-aware allocator bound to Tensor
+      momentums_ = new Tensor(size);
+    }
+  }
+  virtual ~SGDOptimizer() {
+    if (momentums_) delete momentums_;
+  }
+  void Update(const Tensor* gradient);
+  const char* SerializeState(int* state_len);
+  void DeserializeState(const std::string& state);
+
+private:
+  Tensor* momentums_;
+  double momentum_;
+  double decay_;
+  bool nesterov_;
+  // owns the most recent SerializeState() output
+  std::string serialized_state_;
+};
+
+}  // namespace optimizer
+}  // namespace paddle
diff --git a/paddle/optimizer/tensor.h b/paddle/optimizer/tensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..80a8c93081ea7758d3b5ba016a14d424954db913
--- /dev/null
+++ b/paddle/optimizer/tensor.h
@@ -0,0 +1,54 @@
+#pragma once
+/**
+ * @brief tensor used by optimizer
+ */
+
+#include <cstddef>
+#include <memory>
+#include "paddle/utils/Common.h"
+#include "paddle/utils/Logging.h"
+
+namespace paddle {
+namespace optimizer {
+
+template <class T>
+class TensorT {
+public:
+  TensorT(size_t size) : height_(1), width_(size) {
+    data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
+    data_ = data_ptr_.get();
+  }
+
+  TensorT(T* data, size_t size)
+      : height_(1), width_(size), data_ptr_(nullptr), data_(data) {}
+
+  TensorT(T* data, size_t h, size_t w)
+      : height_(h), width_(w), data_ptr_(nullptr), data_(data) {}
+
+  virtual ~TensorT() {}
+
+  T* get_buffer() { return this->data_; }
+
+  T& operator[](const size_t idx) {
+    CHECK(idx < this->width_) << "out of index range";
+    return data_[idx];
+  }
+  const T& operator[](const size_t idx) const {
+    CHECK(idx < this->width_) << "out of index range";
+    return data_[idx];
+  }
+  // TODO: replace with tensorshape
+  size_t size() const { return this->width_ * this->height_; }
+
+protected:
+  size_t height_;
+  size_t width_;
+  std::shared_ptr<T> data_ptr_;
+  T* data_;
+};
+
+// TODO(zhihong): design problem of dynamic datatype, need to fix it
+typedef TensorT<float> Tensor;
+
+}  // namespace optimizer
+}  // namespace paddle
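As a quick illustration, not part of the patch, of the two ownership modes the Tensor wrapper above supports (the include path is assumed from the file's location):

```cpp
#include "paddle/optimizer/tensor.h"  // assumed include path

void tensor_demo() {
  // Owning mode: allocates 4 floats and frees them via the shared_ptr.
  paddle::optimizer::Tensor owned(4);
  for (size_t i = 0; i < owned.size(); ++i) owned[i] = 0.5f * i;

  // Wrapping mode: refers to caller-owned memory and never frees it.
  float raw[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  paddle::optimizer::Tensor view(raw, 4);
  view[0] += owned[0];

  float* buf = owned.get_buffer();  // raw pointer into the owned storage
  (void)buf;
}
```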
diff --git a/paddle/parameter/Argument.h b/paddle/parameter/Argument.h
index 91aca98e186aef0ad6b345cf4791ef80c616e3fe..09bd633616730dc9475edc596128166f4f70b0cd 100644
--- a/paddle/parameter/Argument.h
+++ b/paddle/parameter/Argument.h
@@ -149,6 +149,7 @@ struct Argument {
                : getBatchSize();
   }
 
+  bool hasSeq() const { return sequenceStartPositions != nullptr; }
   bool hasSubseq() const { return subSequenceStartPositions != nullptr; }
 
   const int* getCpuStartPositions() const {
diff --git a/paddle/platform/.clang-format b/paddle/platform/.clang-format
new file mode 100644
index 0000000000000000000000000000000000000000..29282dc87e2c499988c17d90d47d44cd5cf7f115
--- /dev/null
+++ b/paddle/platform/.clang-format
@@ -0,0 +1,5 @@
+---
+Language: Cpp
+BasedOnStyle: Google
+Standard: Cpp11
+...
diff --git a/paddle/majel/CMakeLists.txt b/paddle/platform/CMakeLists.txt
similarity index 51%
rename from paddle/majel/CMakeLists.txt
rename to paddle/platform/CMakeLists.txt
index 93e5e2c22f0eb5797c635efd8ca34ffb74c03311..c7d7b14518ebb8415014a78fc1a3bafa8c386191 100644
--- a/paddle/majel/CMakeLists.txt
+++ b/paddle/platform/CMakeLists.txt
@@ -1,8 +1,4 @@
+nv_test(cuda_test SRCS cuda_test.cu)
+
 cc_library(place SRCS place.cc)
 cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
-
-cc_library(ddim SRCS ddim.cc)
-cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
-
-nv_test(cuda_test SRCS cuda_test.cu)
-nv_test(dim_test SRCS dim_test.cu DEPS ddim)
diff --git a/paddle/platform/assert.h b/paddle/platform/assert.h
new file mode 100644
index 0000000000000000000000000000000000000000..70d3bf75062c5471ab54ee2c4c7637c679d9a8a3
--- /dev/null
+++ b/paddle/platform/assert.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#define STRINGIFY(x) #x
+#define TOSTRING(x) STRINGIFY(x)
+
+#if defined(__APPLE__) && defined(__CUDA_ARCH__) && !defined(NDEBUG)
+#include <stdio.h>
+#define PADDLE_ASSERT(e)                                             \
+  do {                                                               \
+    if (!(e)) {                                                      \
+      printf("%s:%d Assertion `%s` failed.\n", __FILE__, __LINE__,   \
+             TOSTRING(e));                                           \
+      asm("trap;");                                                  \
+    }                                                                \
+  } while (0)
+
+#define PADDLE_ASSERT_MSG(e, m)                                          \
+  do {                                                                   \
+    if (!(e)) {                                                          \
+      printf("%s:%d Assertion `%s` failed (%s).\n", __FILE__, __LINE__,  \
+             TOSTRING(e), m);                                            \
+      asm("trap;");                                                      \
+    }                                                                    \
+  } while (0)
+#else
+#include <assert.h>
+#define PADDLE_ASSERT(e) assert(e)
+#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
+#endif
diff --git a/paddle/majel/cuda_test.cu b/paddle/platform/cuda_test.cu
similarity index 100%
rename from paddle/majel/cuda_test.cu
rename to paddle/platform/cuda_test.cu
diff --git a/paddle/majel/detail/hostdevice.h b/paddle/platform/hostdevice.h
similarity index 100%
rename from paddle/majel/detail/hostdevice.h
rename to paddle/platform/hostdevice.h
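For flavor, a small hedged sketch of the assertion macros above in use; `checked_sqrt` is an invented example. Everywhere except the Apple device-side debug branch, PADDLE_ASSERT degrades to the standard assert, while that branch prints the failing expression and traps.

```cpp
#include <cmath>
#include "paddle/platform/assert.h"

// Illustrative host-side use of the macros defined in assert.h.
float checked_sqrt(float x) {
  PADDLE_ASSERT_MSG(x >= 0.0f, "checked_sqrt expects a non-negative input");
  return std::sqrt(x);
}
```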
diff --git a/paddle/platform/place.cc b/paddle/platform/place.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1afd03c01169d395b086c1da458ce25c66a12a51
--- /dev/null
+++ b/paddle/platform/place.cc
@@ -0,0 +1,46 @@
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace platform {
+
+namespace detail {
+
+class PlacePrinter : public boost::static_visitor<> {
+ public:
+  PlacePrinter(std::ostream &os) : os_(os) {}
+  void operator()(const CpuPlace &) { os_ << "CpuPlace"; }
+  void operator()(const GpuPlace &p) { os_ << "GpuPlace(" << p.device << ")"; }
+
+ private:
+  std::ostream &os_;
+};
+
+}  // namespace detail
+
+static Place the_default_place;
+
+void set_place(const Place &place) { the_default_place = place; }
+const Place &get_place() { return the_default_place; }
+
+const GpuPlace default_gpu() { return GpuPlace(0); }
+const CpuPlace default_cpu() { return CpuPlace(); }
+
+bool is_gpu_place(const Place &p) {
+  return boost::apply_visitor(IsGpuPlace(), p);
+}
+bool is_cpu_place(const Place &p) {
+  return !boost::apply_visitor(IsGpuPlace(), p);
+}
+
+bool places_are_same_class(const Place &p1, const Place &p2) {
+  return is_gpu_place(p1) == is_gpu_place(p2);
+}
+
+std::ostream &operator<<(std::ostream &os, const Place &p) {
+  detail::PlacePrinter printer(os);
+  boost::apply_visitor(printer, p);
+  return os;
+}
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/platform/place.h b/paddle/platform/place.h
new file mode 100644
index 0000000000000000000000000000000000000000..489572c526e162500c8f747f0ec8df10da9d86a2
--- /dev/null
+++ b/paddle/platform/place.h
@@ -0,0 +1,49 @@
+#pragma once
+#include <boost/variant.hpp>
+#include <iostream>
+
+namespace paddle {
+namespace platform {
+
+struct CpuPlace {
+  // WORKAROUND: for some reason, omitting this constructor
+  // causes errors with boost 1.59 and OSX
+  CpuPlace() {}
+
+  // needed for variant equality comparison
+  inline bool operator==(const CpuPlace &) const { return true; }
+  inline bool operator!=(const CpuPlace &) const { return false; }
+};
+
+struct GpuPlace {
+  GpuPlace() : GpuPlace(0) {}
+  GpuPlace(int d) : device(d) {}
+
+  // needed for variant equality comparison
+  inline bool operator==(const GpuPlace &o) const { return device == o.device; }
+  inline bool operator!=(const GpuPlace &o) const { return !(*this == o); }
+
+  int device;
+};
+
+struct IsGpuPlace : public boost::static_visitor<bool> {
+  bool operator()(const CpuPlace &) const { return false; }
+  bool operator()(const GpuPlace &gpu) const { return true; }
+};
+
+// NOTE: GpuPlace is the first alternative, so a default-constructed Place
+// is a GPU place; the Default test below relies on this ordering.
+typedef boost::variant<GpuPlace, CpuPlace> Place;
+
+void set_place(const Place &);
+const Place &get_place();
+
+const GpuPlace default_gpu();
+const CpuPlace default_cpu();
+
+bool is_gpu_place(const Place &);
+bool is_cpu_place(const Place &);
+bool places_are_same_class(const Place &, const Place &);
+
+std::ostream &operator<<(std::ostream &, const Place &);
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/platform/place_test.cc b/paddle/platform/place_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..73fccceedf6918148a26100f64cf322305c3ac20
--- /dev/null
+++ b/paddle/platform/place_test.cc
@@ -0,0 +1,40 @@
+#include "paddle/platform/place.h"
+#include <sstream>
+#include "gtest/gtest.h"
+
+TEST(Place, Equality) {
+  paddle::platform::CpuPlace cpu;
+  paddle::platform::GpuPlace g0(0), g1(1), gg0(0);
+
+  EXPECT_EQ(cpu, cpu);
+  EXPECT_EQ(g0, g0);
+  EXPECT_EQ(g1, g1);
+  EXPECT_EQ(g0, gg0);
+
+  EXPECT_NE(g0, g1);
+
+  EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0));
+  EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu));
+}
+
+TEST(Place, Default) {
+  EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::get_place()));
+  EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu()));
+  EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu()));
+
+  paddle::platform::set_place(paddle::platform::CpuPlace());
+  EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place()));
+}
+
+TEST(Place, Print) {
+  {
+    std::stringstream ss;
+    ss << paddle::platform::GpuPlace(1);
+    EXPECT_EQ("GpuPlace(1)", ss.str());
+  }
+  {
+    std::stringstream ss;
+    ss <<
paddle::platform::CpuPlace(); + EXPECT_EQ("CpuPlace", ss.str()); + } +} diff --git a/paddle/trainer/CMakeLists.txt b/paddle/trainer/CMakeLists.txt index 9d246b6690134d96e9a262c6ac64d998536128a9..f34d53ae99f913a8aed8767b7271a538efce4778 100644 --- a/paddle/trainer/CMakeLists.txt +++ b/paddle/trainer/CMakeLists.txt @@ -26,6 +26,13 @@ set(TRAINER_HEADERS ThreadParameterUpdater.h TrainerConfigHelper.h) +if(NOT WITH_GOLANG) + list(REMOVE_ITEM TRAINER_SOURCES + NewRemoteParameterUpdater.cpp) + list(REMOVE_ITEM TRAINER_HEADERS + NewRemoteParameterUpdater.h) +endif() + add_library(paddle_trainer_lib STATIC ${TRAINER_SOURCES}) @@ -34,7 +41,7 @@ add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib ${TRAINER_HEADERS}) add_dependencies(paddle_trainer_lib - gen_proto_cpp paddle_pserver_cclient_lib) + gen_proto_cpp) macro(add_paddle_exe TARGET_NAME) add_executable(${TARGET_NAME} ${ARGN}) @@ -63,5 +70,8 @@ if(APPLE) set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") endif() -target_link_libraries(paddle_trainer ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a) -target_link_libraries(paddle_trainer_lib ${CMAKE_CURRENT_SOURCE_DIR}/libpaddle_pserver_cclient.a) +if(WITH_GOLANG) + add_dependencies(paddle_trainer_lib paddle_pserver_cclient) + target_link_libraries(paddle_trainer ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a) + target_link_libraries(paddle_trainer_lib ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a) +endif(WITH_GOLANG) diff --git a/paddle/trainer/tests/test_recurrent_machine_generation.cpp b/paddle/trainer/tests/test_recurrent_machine_generation.cpp index 03446b3b2f6d5ff42fbf0d735a24d88bd0429747..1322e77178a4f5674f41943f886a17be8337bd75 100644 --- a/paddle/trainer/tests/test_recurrent_machine_generation.cpp +++ b/paddle/trainer/tests/test_recurrent_machine_generation.cpp @@ -124,6 +124,8 @@ TEST(RecurrentGradientMachine, test_generation) { bool beam_search) { FLAGS_config_args = beam_search ? "beam_search=1" : "beam_search=0"; for (auto useGpu : useGpuConfs) { + LOG(INFO) << configFile << " useGpu=" << useGpu + << " beam_search=" << beam_search; testGeneration(configFile, useGpu, hasSubseq, expRetFile); } }; diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt index 62d5b9e38b21ee82d1e78c3bde5aa5df7e4a33ee..c942620990765832f21c887d30f85a2d211a5f32 100644 --- a/proto/CMakeLists.txt +++ b/proto/CMakeLists.txt @@ -5,6 +5,7 @@ set(proto_filenames ParameterConfig.proto ParameterService.proto TrainerConfig.proto + OptimizerConfig.proto ParameterServerConfig.proto) set(PROTO_GEN) @@ -35,10 +36,8 @@ foreach(filename ${proto_filenames}) DEPENDS ${filename} ${external_project_dependencies}) endforeach() -include_directories(${CMAKE_CURRENT_BINARY_DIR}/proto) - add_custom_target(gen_proto_cpp ALL DEPENDS ${PROTO_GEN}) add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) -add_library(paddle_proto STATIC - ${PROTO_GEN}) + +add_library(paddle_proto STATIC ${PROTO_GEN}) target_include_directories(paddle_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/proto/OptimizerConfig.proto b/proto/OptimizerConfig.proto new file mode 100644 index 0000000000000000000000000000000000000000..c698d3c2ddbf58a41ac6ee960af83a257325d1f9 --- /dev/null +++ b/proto/OptimizerConfig.proto @@ -0,0 +1,154 @@ +syntax = "proto2"; + +option optimize_for = LITE_RUNTIME; + +package paddle; + +message SGDConfig { + // SGD + // momentum: float >= 0. Parameter updates momentum. + // decay: float >= 0. 
Learning rate decay over each update.
+  // nesterov: boolean. Whether to apply Nesterov momentum.
+  optional double momentum = 21 [default = 0.0];
+  optional double decay = 23 [default = 0.0];
+  optional bool nesterov = 24 [default = false];
+
+}
+
+
+message AdadeltaConfig {
+  // Adadelta
+  // It is recommended to leave it at the default value.
+  // rho: float >= 0.
+  // epsilon: float >= 0. Fuzz factor.
+  // decay: float >= 0. Learning rate decay over each update.
+
+  // reference : [Adadelta - an adaptive learning rate method](http://arxiv.org/abs/1212.5701)
+  optional double rho = 33 [default = 0.90];
+  optional double epsilon = 31 [default = 1e-5];
+  optional double decay = 32 [default = 0.0];
+
+}
+
+message AdagradConfig {
+  // Adagrad
+  // epsilon: float >= 0.
+  // decay: float >= 0. Learning rate decay over each update.
+
+  // reference : [Adaptive Subgradient Methods for Online Learning and Stochastic Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+  optional double epsilon = 41 [default = 1e-5];
+  optional double decay = 42 [default = 0.0];
+}
+
+message AdamConfig {
+  // Adam
+  // beta_1: float, 0 < beta < 1. Generally close to 1.
+  // beta_2: float, 0 < beta < 1. Generally close to 1.
+  // epsilon: float >= 0. Fuzz factor.
+  // decay: float >= 0. Learning rate decay over each update.
+  // reference : [Adam - A Method for Stochastic Optimization](http://arxiv.org/abs/1412.6980v8)
+  optional double beta_1 = 41;
+  optional double beta_2 = 42;
+  optional double epsilon = 43;
+  optional double decay = 44;
+}
+
+message ConstLrConfig {
+  // learning rate policy
+  optional double learning_rate = 1 [default = 1.0];
+}
+
+message LinearLrConfig {
+  // learning rate policy
+  optional double learning_rate = 1 [default = 1.0];
+  optional double lr_decay_a = 2;
+  optional double lr_decay_b = 3;
+}
+
+message TensorProto {
+  enum DataType {
+    PADDLE_ELEMENT_TYPE_INT32 = 0;
+    PADDLE_ELEMENT_TYPE_UINT32 = 1;
+    PADDLE_ELEMENT_TYPE_INT64 = 2;
+    PADDLE_ELEMENT_TYPE_UINT64 = 3;
+    PADDLE_ELEMENT_TYPE_FLOAT32 = 4;
+    PADDLE_ELEMENT_TYPE_FLOAT64 = 5;
+  }
+  optional DataType data_type = 1;
+  repeated bytes content = 2;
+}
+
+message SGDOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto momentums = 2;
+}
+
+message AdadeltaOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto accum_gradient = 2;
+  optional TensorProto accum_delta = 3;
+  optional TensorProto update_delta = 4;
+}
+
+message AdagradOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto accum_gradient = 2;
+}
+
+message AdamOptimizerState {
+  // learning rate policy
+  optional double learning_rate = 101;
+  optional double lr_decay_a = 102;
+  optional double lr_decay_b = 103;
+  optional double num_sample_passed = 104;
+  // state
+  optional TensorProto parameter = 1;
+  optional TensorProto momentums = 2;
+  optional TensorProto velocitys = 3;
+}
+
+message OptimizerConfig {
+  enum Optimizer {
+    SGD = 1;
+    Adadelta = 2;
+    Adagrad = 3;
+    Adam = 4;
+  }
+  optional Optimizer optimizer = 1;
+  optional SGDConfig sgd = 3;
+  optional AdadeltaConfig adadelta = 4;
+  optional AdagradConfig adagrad = 5;
+  optional AdamConfig adam = 6;
+
+  enum LrPolicy {
+    Const = 0;
+    Linear = 1;
+  }
+  optional LrPolicy lr_policy = 11;
+  optional ConstLrConfig const_lr = 12;
+  optional LinearLrConfig linear_lr = 13;
+
+  // common config of optimizer
+  // clip the gradient when its L2 norm exceeds this value
+  optional double clip_norm = 101;
+  // clip the gradient when its L1 norm exceeds this value
+  optional double clip_value = 102;
+}
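As a usage sketch connecting this schema to the ParameterOptimizer::Create factory added above (not part of the patch; the 0.9/0.999/1e-8 values are merely common Adam defaults, not values this patch prescribes):

```cpp
#include <string>
#include "OptimizerConfig.pb.h"

// Build and serialize an Adam + constant-LR config; the resulting bytes
// are what ParameterOptimizer::Create() and the C API expect.
std::string make_adam_config() {
  paddle::OptimizerConfig config;
  config.set_optimizer(paddle::OptimizerConfig::Adam);
  config.mutable_adam()->set_beta_1(0.9);
  config.mutable_adam()->set_beta_2(0.999);
  config.mutable_adam()->set_epsilon(1e-8);
  config.mutable_adam()->set_decay(0.0);
  config.set_lr_policy(paddle::OptimizerConfig::Const);
  config.mutable_const_lr()->set_learning_rate(0.01);
  return config.SerializeAsString();
}
```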
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 0e17c42d34f147db190ac5e5ccd5339360cc35bb..3640dd3a75ea212a84255ea7f6369b63606482ab 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -18,7 +18,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
 add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
     COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
     COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp
-    DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies} paddle_master_shared)
+    DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies})
 
 add_custom_target(paddle_python ALL DEPENDS
     ${OUTPUT_DIR}/.timestamp)
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index fc2e3bbcde0e94b6325bd0ca1fd41e088df0b950..c11dc09a8b98bb8a3c8455f811b1435714e825d0 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -328,53 +328,33 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
     SubModelBegin(name)
     g_current_submodel.is_recurrent_layer_group = True
     g_current_submodel.reversed = seq_reversed
-    g_current_submodel.target_inlinkid = -1
     in_links_count = 0
     for linkid, link in enumerate(in_links):
         if isinstance(link, basestring):
             name = link
-            has_subseq = False
         else:
             name = link.link_name
-            has_subseq = link.has_subseq
-        # assign target_inlinkid according to target_inlinkname
-        if target_inlinkname == name:
-            g_current_submodel.target_inlinkid = linkid
-        if in_links_count == 0:
-            in_links_has_subseq = has_subseq
-        else:
-            config_assert(
-                in_links_has_subseq == has_subseq,
-                "The sequence type of in_links should be the same in RecurrentLayerGroup"
-            )
         in_links_count += 1
         layer_name = MakeLayerNameInParentSubmodel(name)
         layer = g_layer_map[layer_name]
-        if has_subseq:
-            SequenceScatterAgentLayer(name=name, size=layer.size)
-        else:
-            ScatterAgentLayer(name=name, size=layer.size)
+        ScatterAgentLayer(name=name, size=layer.size)
         pair = g_current_submodel.in_links.add()
         pair.layer_name = layer_name
         pair.link_name = MakeLayerNameInSubmodel(name)
-        pair.has_subseq = has_subseq
 
 
 @config_func
 def RecurrentLayerGroupSetOutLink(link):
     if isinstance(link, basestring):
         name = link
-        has_subseq = False
     else:
         name = link.link_name
-        has_subseq = link.has_subseq
     layer_name = MakeLayerNameInParentSubmodel(name)
     pair = g_current_submodel.out_links.add()
     pair.layer_name = MakeLayerNameInSubmodel(name)
     pair.link_name = layer_name
-    pair.has_subseq = has_subseq
 
 
 def RecurrentLayerGroupSetGenerator(generator=None):
@@ -389,8 +369,7 @@ def RecurrentLayerGroupBegin(name,
                              generator=None,
                              target_inlinkname="",
                              seq_reversed=False):
-    RecurrentLayerGroupWithoutOutLinksBegin(name, in_links, seq_reversed,
-                                            target_inlinkname)
+    RecurrentLayerGroupWithoutOutLinksBegin(name, in_links, seq_reversed)
     for link in out_links:
         RecurrentLayerGroupSetOutLink(link)
 
@@ -425,8 +404,6 @@ def
RecurrentLayerGroupEnd(name): agent_name = GetLayerBaseName(pair.link_name) if prev_submodel.HasField("generator"): DataLayer(name=agent_name, size=layer.size) - elif pair.has_subseq: - SequenceGatherAgentLayer(name=agent_name, size=layer.size) else: GatherAgentLayer(name=agent_name, size=layer.size) @@ -1651,8 +1628,14 @@ class SelectiveFCLayer(LayerBase): @config_layer('print') class PrintLayer(LayerBase): - def __init__(self, name, inputs): + def __init__(self, name, inputs, format=None): super(PrintLayer, self).__init__(name, 'print', 0, inputs) + if format is None: + format = "\n".join([ + "layer=" + input.input_layer_name + " %s" + for input in self.inputs + ]) + self.config.user_arg = format @config_layer('priorbox') @@ -1949,7 +1932,6 @@ class BatchNormLayer(LayerBase): def __init__(self, name, inputs, - active_type="linear", bias=True, use_global_stats=True, moving_average_fraction=0.9, @@ -1987,12 +1969,7 @@ class BatchNormLayer(LayerBase): cudnn_version >= 4007 self.layer_type = "cudnn_batch_norm" if use_cudnn else "batch_norm" super(BatchNormLayer, self).__init__( - name, - self.layer_type, - 0, - active_type=active_type, - inputs=inputs, - **xargs) + name, self.layer_type, 0, inputs=inputs, **xargs) if use_global_stats is not None: self.config.use_global_stats = use_global_stats @@ -2253,13 +2230,6 @@ class AgentLayer(LayerBase): name, 'agent', size, inputs=[], device=device) -@config_layer('sequence_agent') -class SequenceAgentLayer(LayerBase): - def __init__(self, name, size, device=None): - super(SequenceAgentLayer, self).__init__( - name, 'sequence_agent', size, inputs=[], device=device) - - @config_layer('gather_agent') class GatherAgentLayer(LayerBase): def __init__(self, name, size, device=None): @@ -2274,20 +2244,6 @@ class ScatterAgentLayer(LayerBase): name, 'scatter_agent', size, inputs=[], device=device) -@config_layer('sequence_gather_agent') -class SequenceGatherAgentLayer(LayerBase): - def __init__(self, name, size, device=None): - super(SequenceGatherAgentLayer, self).__init__( - name, 'sequence_gather_agent', size, inputs=[], device=device) - - -@config_layer('sequence_scatter_agent') -class SequenceScatterAgentLayer(LayerBase): - def __init__(self, name, size, device=None): - super(SequenceScatterAgentLayer, self).__init__( - name, 'sequence_scatter_agent', size, inputs=[], device=device) - - @config_layer('multiplex') class MultiplexLayer(LayerBase): def __init__(self, name, inputs, size, device=None): @@ -2303,12 +2259,12 @@ class MultiplexLayer(LayerBase): @config_func -def Link( - name, - has_subseq=False, ): +def Link(name, has_subseq=False): + """ + Still keeping has_subseq for backward compatibility + """ link_config = LinkConfig() link_config.link_name = name - link_config.has_subseq = has_subseq return link_config @@ -2341,20 +2297,13 @@ def Memory(name, config_assert(name is not None, "name needs cannot be None") memory_name = name + "+delay1" agent_name = memory_name - if is_sequence: - config_assert( - boot_layer is not None, - "there must be boot_layer in network when is_sequence = True") - agent_layer = SequenceAgentLayer(agent_name, size) - else: - agent_layer = AgentLayer(agent_name, size) + agent_layer = AgentLayer(agent_name, size) config_assert(g_current_submodel.is_recurrent_layer_group, 'Memory should be used in recurrent layer group only') memory = g_current_submodel.memories.add() if name is not None: memory.layer_name = MakeLayerNameInSubmodel(name) memory.link_name = MakeLayerNameInSubmodel(agent_name) - memory.is_sequence = 
is_sequence options = sum((boot_layer is not None, bool(boot_bias), boot_with_const_id is not None)) config_assert( @@ -2428,15 +2377,23 @@ class ExpandLayer(LayerBase): @config_layer('featmap_expand') class FeatMapExpandLayer(LayerBase): - def __init__(self, name, inputs, device=None, num_filters=None, bias=False): + def __init__(self, + name, + inputs, + num_filters=None, + as_row_vector=True, + bias=False, + **xargs): super(FeatMapExpandLayer, self).__init__( - name, 'featmap_expand', 0, inputs=inputs, device=device) + name, 'featmap_expand', 0, inputs=inputs, **xargs) config_assert( len(self.inputs) == 1, 'ExpandLayer takes 1 and only 1 inputs') if num_filters is not None: self.config.num_filters = num_filters else: logger.fatal("FeatMapExpandLayer must specify num_filters.") + if not as_row_vector: + self.config.user_arg = "as_col_vec" self.set_layer_size(self.get_input_layer(0).size * num_filters) @@ -2446,14 +2403,12 @@ class MaxLayer(LayerBase): name, inputs, trans_type='non-seq', - active_type='linear', bias=False, output_max_index=None, **xargs): super(MaxLayer, self).__init__(name, 'max', 0, inputs=inputs, **xargs) config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input') self.config.trans_type = trans_type - self.config.active_type = active_type for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) self.set_layer_size(input_layer.size) @@ -2495,18 +2450,12 @@ class SequenceLastInstanceLayer(LayerBase): def __init__(self, name, inputs, - active_type='linear', trans_type='non-seq', bias=False, stride=-1, **xargs): super(SequenceLastInstanceLayer, self).__init__( - name, - 'seqlastins', - 0, - inputs=inputs, - active_type=active_type, - **xargs) + name, 'seqlastins', 0, inputs=inputs, **xargs) config_assert( len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input') if trans_type == 'seq': @@ -2522,7 +2471,6 @@ class SequenceFirstInstanceLayer(SequenceLastInstanceLayer): def __init__(self, name, inputs, - active_type='linear', trans_type='non-seq', bias=False, stride=-1, @@ -2530,7 +2478,6 @@ class SequenceFirstInstanceLayer(SequenceLastInstanceLayer): super(SequenceFirstInstanceLayer, self).__init__( name, inputs=inputs, - active_type=active_type, trans_type=trans_type, bias=bias, stride=stride, @@ -2540,14 +2487,9 @@ class SequenceFirstInstanceLayer(SequenceLastInstanceLayer): @config_layer('seqconcat') class SequenceConcatLayer(LayerBase): - def __init__(self, name, inputs, active_type='linear', bias=False, **xargs): + def __init__(self, name, inputs, bias=False, **xargs): super(SequenceConcatLayer, self).__init__( - name, - 'seqconcat', - 0, - inputs=inputs, - active_type=active_type, - **xargs) + name, 'seqconcat', 0, inputs=inputs, **xargs) config_assert( len(inputs) == 2, 'SequenceConcatLayer must have 2 inputs') for input_index in xrange(len(self.inputs)): @@ -2558,20 +2500,9 @@ class SequenceConcatLayer(LayerBase): @config_layer('seqreshape') class SequenceReshapeLayer(LayerBase): - def __init__(self, - name, - size, - inputs, - active_type='linear', - bias=False, - **xargs): + def __init__(self, name, size, inputs, bias=False, **xargs): super(SequenceReshapeLayer, self).__init__( - name, - 'seqreshape', - size, - inputs=inputs, - active_type=active_type, - **xargs) + name, 'seqreshape', size, inputs=inputs, **xargs) config_assert( len(inputs) == 1, 'SequenceReshapeLayer must have 1 inputs') self.set_layer_size(size) @@ -2580,9 +2511,9 @@ class SequenceReshapeLayer(LayerBase): @config_layer('subseq') class 
SubSequenceLayer(LayerBase): - def __init__(self, name, inputs, active_type='linear', bias=False, **xargs): + def __init__(self, name, inputs, bias=False, **xargs): super(SubSequenceLayer, self).__init__( - name, 'subseq', 0, inputs=inputs, active_type=active_type, **xargs) + name, 'subseq', 0, inputs=inputs, **xargs) config_assert(len(inputs) == 3, 'SubSequenceLayer must have 3 inputs') input_layer0 = self.get_input_layer(0) size = input_layer0.size @@ -2738,11 +2669,10 @@ class AverageLayer(LayerBase): inputs, average_strategy='average', trans_type='non-seq', - active_type='linear', bias=False, **xargs): super(AverageLayer, self).__init__( - name, 'average', 0, inputs=inputs, active_type=active_type, **xargs) + name, 'average', 0, inputs=inputs, **xargs) self.config.average_strategy = average_strategy self.config.trans_type = trans_type config_assert(len(inputs) == 1, 'AverageLayer must have 1 input') diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 2d8ddbb9007b241eb1986887d8ea6c2de8235c29..b8ce0373c0e9524518e42ad911fd2cd728facec4 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -311,18 +311,6 @@ class LayerOutput(object): self.outputs = outputs self.reverse = reverse - def __repr__(self): - """ - Disable __repr__ for debug reason. Will be implemented when release - """ - assert False, "this method should not be invoked" - - def __str__(self): - """ - Disable __str__ for debug reason. Will be implemented when release - """ - assert False, "this method should not be invoked" - def set_input(self, input): """ Set the input for a memory layer. Can only be used for memory layer @@ -976,7 +964,7 @@ def fc_layer(input, @wrap_name_default("print") -def printer_layer(input, name=None): +def printer_layer(input, format=None, name=None): """ Print the output value of input layers. This layer is useful for debugging. @@ -994,6 +982,7 @@ def printer_layer(input, name=None): Layer( name=name, + format=format, type=LayerType.PRINT_LAYER, inputs=[l.name for l in input], ) # this layer don't return anything, can not be input of other layer. @@ -1565,14 +1554,24 @@ def expand_layer(input, @wrap_name_default() +@wrap_act_default(act=IdentityActivation()) @layer_support() -def repeat_layer(input, num_repeats, name=None, layer_attr=None): +def repeat_layer(input, + num_repeats, + as_row_vector=True, + act=None, + name=None, + layer_attr=None): """ - A layer for repeating the input for num_repeats times. This is equivalent - to apply concat_layer() with num_repeats same input. + A layer for repeating the input for num_repeats times. + If as_row_vector: .. math:: - y = [x, x, \cdots, x] + y = [x_1,\cdots, x_n, \cdots, x_1, \cdots, x_n] + If not as_row_vector: + .. math:: + y = [x_1,\cdots, x_1, \cdots, x_n, \cdots, x_n] + The example usage is: @@ -1585,6 +1584,14 @@ def repeat_layer(input, num_repeats, name=None, layer_attr=None): :param num_repeats: Repeat the input so many times :type num_repeats: int :param name: Layer name. + :param as_row_vector: True for treating input as row vector and repeating + in the column direction. This is equivalent to apply + concat_layer() with num_repeats same input. + False for treating input as column vector and repeating + in the row direction. + :type as_row_vector: bool + :param act: Activation type. + :type act: BaseActivation :type name: basestring :param layer_attr: extra layer attributes. :type layer_attr: ExtraLayerAttribute. 
@@ -1595,13 +1602,16 @@ def repeat_layer(input, num_repeats, name=None, layer_attr=None): l = Layer( inputs=[input.name], name=name, + active_type=act.name, num_filters=num_repeats, + as_row_vector=as_row_vector, type=LayerType.FEATURE_MAP_EXPAND_LAYER, **ExtraAttr.to_kwargs(layer_attr)) return LayerOutput( name=name, size=l.config.size, layer_type=LayerType.FEATURE_MAP_EXPAND_LAYER, + activation=act, parents=[input]) @@ -2846,11 +2856,13 @@ def seq_concat_layer(a, b, act=None, name=None, layer_attr=None, Concat sequence a with sequence b. Inputs: - - a = [a1, a2, ..., an] + - a = [a1, a2, ..., am] - b = [b1, b2, ..., bn] - - Note that the length of a and b should be the same. - Output: [a1, b1, a2, b2, ..., an, bn] + Output: [a1, ..., am, b1, ..., bn] + + Note that the above computation is for one sample. Multiple samples are + processed in one batch. The example usage is: @@ -2944,7 +2956,7 @@ def memory(name, :param memory_name: the name of the memory. It is ignored when name is provided. :type memory_name: basestring - :param is_seq: is sequence for boot_layer + :param is_seq: DEPRECATED. is sequence for boot_layer :type is_seq: bool :param boot_layer: boot layer of memory. :type boot_layer: LayerOutput|None @@ -2971,7 +2983,6 @@ def memory(name, memory_name = Memory( name, size, - is_sequence=is_seq, boot_layer=boot_layer.name if boot_layer is not None else None, boot_bias=boot_bias, boot_bias_active_type=boot_bias_active_type.name, @@ -3318,19 +3329,21 @@ class StaticInput(object): """ StaticInput is only used in recurrent_group which defines a read-only memory that can be a sequence or non-sequence. + :param size: DEPRECATED + :param is_seq: DEPRECATED """ def __init__(self, input, is_seq=False, size=None): assert isinstance(input, LayerOutput) self.input = input - self.is_seq = is_seq - assert input.size is not None or size is not None + assert input.size is not None if size is not None: - input.size = size + assert input.size == size -class SubsequenceInput(object): +def SubsequenceInput(input): """ + DEPRECATED. Input sequence has sub-sequence, used in recurrent_group. The example usage is: @@ -3339,11 +3352,7 @@ class SubsequenceInput(object): input = SubsequenceInput(layer) """ - - def __init__(self, input): - assert isinstance(input, LayerOutput) - assert input.size is not None - self.input = input + return input @wrap_name_default("recurrent_group") @@ -3407,7 +3416,8 @@ def recurrent_group(step, input sequence in a reverse order. :type reverse: bool - :param targetInlink: the input layer which share info with layer group's output + :param targetInlink: DEPRECATED. + The input layer which share info with layer group's output Param input specifies multiple input layers. 
For SubsequenceInput inputs, config should assign one input @@ -3429,46 +3439,21 @@ def recurrent_group(step, model_type('recurrent_nn') def is_single_input(x): - return isinstance(x, LayerOutput) or isinstance(x, StaticInput) \ - or isinstance(x, SubsequenceInput) + return isinstance(x, LayerOutput) or isinstance(x, StaticInput) if is_single_input(input): input = [input] assert isinstance(input, collections.Sequence) def is_in_links(x): - return isinstance(x, LayerOutput) or isinstance(x, SubsequenceInput) + return isinstance(x, LayerOutput) in_links = filter(is_in_links, input) - def targetInlink_in_inlinks(): - for inlink in in_links: - if isinstance(inlink, SubsequenceInput): - if targetInlink == inlink.input: - return True - elif targetInlink == inlink: - return True - return False - - assert (targetInlink == None or targetInlink_in_inlinks()) - targetInlinkName = None if targetInlink == None \ - else targetInlink.name if isinstance(targetInlink, LayerOutput) \ - else targetInlink.input.name - - contains_sub_seq = [False] - - def map_in_links(x): - if isinstance(x, SubsequenceInput): - contains_sub_seq[0] = True - return Link(name=x.input.name, has_subseq=True) - else: - return x.name - RecurrentLayerGroupWithoutOutLinksBegin( name=name, - in_links=map(map_in_links, in_links), - seq_reversed=reverse, - target_inlinkname=targetInlinkName) + in_links=map(lambda x: x.name, in_links), + seq_reversed=reverse) in_args = [] has_LayerOutput = False for each_input in input: @@ -3476,21 +3461,13 @@ def recurrent_group(step, if isinstance(each_input, LayerOutput): in_args.append(each_input) has_LayerOutput = True - elif isinstance(each_input, SubsequenceInput): - in_args.append(each_input.input) - has_LayerOutput = True - else: + else: # StaticInput mem_name = "__%s_memory__" % each_input.input.name mem = memory( - name=mem_name, - is_seq=each_input.is_seq, + name=None, size=each_input.input.size, boot_layer=each_input.input) - with mixed_layer( - name=mem_name, - size=each_input.input.size, - act=IdentityActivation()) as mix: - mix += identity_projection(mem) + mem.set_input(mem) in_args.append(mem) assert (is_generating != has_LayerOutput) @@ -3503,10 +3480,7 @@ def recurrent_group(step, for ot in layer_outs: assert isinstance(ot, LayerOutput) ot.reverse = reverse - if contains_sub_seq[0]: - RecurrentLayerGroupSetOutLink(Link(ot.name, has_subseq=True)) - else: - RecurrentLayerGroupSetOutLink(ot.name) + RecurrentLayerGroupSetOutLink(ot.name) RecurrentLayerGroupEnd(name=name) @@ -5608,13 +5582,13 @@ def row_conv_layer(input, to deploy in an online and low-latency setting. The lookahead convolution incorporates information from future subsequences in a computationally efficient manner to improve unidirectional recurrent neural networks. - + The connection of row convolution is different form the 1D sequence convolution. Assumed that, the future context-length is k, that is to say, it can get the output at timestep t by using the the input feature from t-th timestep to (t+k+1)-th timestep. Assumed that the hidden dim of input activations are d, the activations r_t for the new layer at time-step t are: - + .. 
math:: r_{t,r} = \sum_{j=1}^{k + 1} {w_{i,j}h_{t+j-1, i}} diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index c24102255f5bbed0f551b2dbfec20be7daf5f5b4..c0e87d6de372dfdd9c7e694af71df8f3b011d43a 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -1,5 +1,5 @@ #!/bin/bash -export configs=(test_fc layer_activations projections test_print_layer +export configs=(test_repeat_layer test_fc layer_activations projections test_print_layer test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/last_first_seq.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/last_first_seq.protostr index 12b2255f3a41119792d0f993ce2e03ce9ee3e994..fee0f8e462bfd211e6aa7698ebfeaf0a19428a62 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/last_first_seq.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/last_first_seq.protostr @@ -9,7 +9,7 @@ layers { name: "__first_seq_0__" type: "seqlastins" size: 30 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data" } @@ -21,7 +21,7 @@ layers { name: "__first_seq_1__" type: "seqlastins" size: 30 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data" } @@ -33,7 +33,7 @@ layers { name: "__last_seq_0__" type: "seqlastins" size: 30 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data" } @@ -44,7 +44,7 @@ layers { name: "__last_seq_1__" type: "seqlastins" size: 30 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data" } @@ -55,7 +55,7 @@ layers { name: "__first_seq_2__" type: "seqlastins" size: 30 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data" } @@ -67,7 +67,7 @@ layers { name: "__last_seq_2__" type: "seqlastins" size: 30 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data" } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_gru.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_gru.protostr index 64530146a1458933d4ba0edffc1b1b7e60a21187..7254deb368963914fd1fff7925b6aeedbed59318 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_gru.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_gru.protostr @@ -123,7 +123,7 @@ layers { name: "__last_seq_0__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__simple_gru_0__" } @@ -134,7 +134,7 @@ layers { name: "__last_seq_1__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__simple_gru_1__" } @@ -256,19 +256,15 @@ sub_models { memories { layer_name: "__simple_gru_0__@__simple_gru_0___recurrent_group" link_name: "__simple_gru_0__+delay1@__simple_gru_0___recurrent_group" - is_sequence: false } in_links { layer_name: "__simple_gru_0___transform" link_name: "__simple_gru_0___transform@__simple_gru_0___recurrent_group" - has_subseq: false } out_links { layer_name: "__simple_gru_0__@__simple_gru_0___recurrent_group" link_name: "__simple_gru_0__" - has_subseq: false } - target_inlinkid: -1 } sub_models { name: 
"__simple_gru_1___recurrent_group" @@ -280,18 +276,14 @@ sub_models { memories { layer_name: "__simple_gru_1__@__simple_gru_1___recurrent_group" link_name: "__simple_gru_1__+delay1@__simple_gru_1___recurrent_group" - is_sequence: false } in_links { layer_name: "__simple_gru_1___transform" link_name: "__simple_gru_1___transform@__simple_gru_1___recurrent_group" - has_subseq: false } out_links { layer_name: "__simple_gru_1__@__simple_gru_1___recurrent_group" link_name: "__simple_gru_1__" - has_subseq: false } - target_inlinkid: -1 } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_lstm.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_lstm.protostr index 79fa4c74f081aebadd258e06333de9eafe6a5ee3..7f2aa5a0fea1f4628e4effca5ce9af896f6e6c2c 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_lstm.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/shared_lstm.protostr @@ -205,7 +205,7 @@ layers { name: "__last_seq_0__" type: "seqlastins" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__lstm_group_0__" } @@ -216,7 +216,7 @@ layers { name: "__last_seq_1__" type: "seqlastins" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__lstm_group_1__" } @@ -341,24 +341,19 @@ sub_models { memories { layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group" link_name: "__lstm_group_0__+delay1@__lstm_group_0___recurrent_group" - is_sequence: false } memories { layer_name: "__lstm_group_0___state@__lstm_group_0___recurrent_group" link_name: "__lstm_group_0___state+delay1@__lstm_group_0___recurrent_group" - is_sequence: false } in_links { layer_name: "__mixed_0__" link_name: "__mixed_0__@__lstm_group_0___recurrent_group" - has_subseq: false } out_links { layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group" link_name: "__lstm_group_0__" - has_subseq: false } - target_inlinkid: -1 } sub_models { name: "__lstm_group_1___recurrent_group" @@ -373,23 +368,18 @@ sub_models { memories { layer_name: "__lstm_group_1__@__lstm_group_1___recurrent_group" link_name: "__lstm_group_1__+delay1@__lstm_group_1___recurrent_group" - is_sequence: false } memories { layer_name: "__lstm_group_1___state@__lstm_group_1___recurrent_group" link_name: "__lstm_group_1___state+delay1@__lstm_group_1___recurrent_group" - is_sequence: false } in_links { layer_name: "__mixed_1__" link_name: "__mixed_1__@__lstm_group_1___recurrent_group" - has_subseq: false } out_links { layer_name: "__lstm_group_1__@__lstm_group_1___recurrent_group" link_name: "__lstm_group_1__" - has_subseq: false } - target_inlinkid: -1 } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/simple_rnn_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/simple_rnn_layers.protostr index 68fa881b4f1408b8cd20f2417062ce035c0fda54..0d51f70ee01b913051f7d20547f68a22663200a0 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/simple_rnn_layers.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/simple_rnn_layers.protostr @@ -138,7 +138,7 @@ layers { name: "__last_seq_0__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__recurrent_layer_0__" } @@ -149,7 +149,7 @@ layers { name: "__first_seq_0__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__recurrent_layer_1__" } @@ -161,7 +161,7 @@ layers { name: 
"__last_seq_1__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__lstmemory_0__" } @@ -172,7 +172,7 @@ layers { name: "__first_seq_1__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__lstmemory_1__" } @@ -184,7 +184,7 @@ layers { name: "__last_seq_2__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__gru_0__" } @@ -195,7 +195,7 @@ layers { name: "__first_seq_2__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__gru_1__" } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_print_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_print_layer.protostr index c402aff174ab7c7d7f63234960d4a24d84622dd4..f4cc492dfb9b5a8c04f6f41cfab017fc613e2a66 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_print_layer.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_print_layer.protostr @@ -12,6 +12,7 @@ layers { inputs { input_layer_name: "input" } + user_arg: "layer=input %s" } input_layer_names: "input" output_layer_names: "input" diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_repeat_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_repeat_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..e012386ff9515947d40ddddb6804de08207e1154 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_repeat_layer.protostr @@ -0,0 +1,42 @@ +type: "nn" +layers { + name: "data" + type: "data" + size: 30 + active_type: "" +} +layers { + name: "__repeat_layer_0__" + type: "featmap_expand" + size: 300 + active_type: "" + inputs { + input_layer_name: "data" + } + num_filters: 10 +} +layers { + name: "__repeat_layer_1__" + type: "featmap_expand" + size: 300 + active_type: "tanh" + inputs { + input_layer_name: "data" + } + num_filters: 10 + user_arg: "as_col_vec" +} +input_layer_names: "data" +output_layer_names: "__repeat_layer_0__" +output_layer_names: "__repeat_layer_1__" +sub_models { + name: "root" + layer_names: "data" + layer_names: "__repeat_layer_0__" + layer_names: "__repeat_layer_1__" + input_layer_names: "data" + output_layer_names: "__repeat_layer_0__" + output_layer_names: "__repeat_layer_1__" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_rnn_group.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_rnn_group.protostr index 77b447aa9db2a6c323fd3c322e7e9ca1ed19a6dd..af1b63c5dfbf0984a20eda02d608f76a454613c6 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_rnn_group.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_rnn_group.protostr @@ -91,7 +91,7 @@ layers { name: "__last_seq_0__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "rnn_forward" } @@ -140,7 +140,7 @@ layers { name: "__first_seq_0__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "rnn_back" } @@ -155,7 +155,7 @@ layers { } layers { name: "sub_seq_input@__recurrent_group_2__" - type: "sequence_scatter_agent" + type: "scatter_agent" size: 100 active_type: "" } @@ -182,7 +182,7 @@ layers { } layers { name: "rnn_subseq_forward" - type: "sequence_gather_agent" + type: 
"gather_agent" size: 200 active_type: "" } @@ -190,7 +190,7 @@ layers { name: "__last_seq_1__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "rnn_subseq_forward" } @@ -280,7 +280,7 @@ layers { name: "__last_seq_2__" type: "seqlastins" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__lstm_group_0__" } @@ -329,7 +329,7 @@ layers { name: "__last_seq_3__" type: "seqlastins" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__gru_group_0__" } @@ -378,7 +378,7 @@ layers { name: "__last_seq_4__" type: "seqlastins" size: 200 - active_type: "linear" + active_type: "" inputs { input_layer_name: "__fc_layer_0__" } @@ -618,19 +618,15 @@ sub_models { memories { layer_name: "rnn_forward@__recurrent_group_0__" link_name: "rnn_forward+delay1@__recurrent_group_0__" - is_sequence: false } in_links { layer_name: "seq_input" link_name: "seq_input@__recurrent_group_0__" - has_subseq: false } out_links { layer_name: "rnn_forward@__recurrent_group_0__" link_name: "rnn_forward" - has_subseq: false } - target_inlinkid: -1 } sub_models { name: "__recurrent_group_1__" @@ -642,19 +638,15 @@ sub_models { memories { layer_name: "rnn_back@__recurrent_group_1__" link_name: "rnn_back+delay1@__recurrent_group_1__" - is_sequence: false } in_links { layer_name: "seq_input" link_name: "seq_input@__recurrent_group_1__" - has_subseq: false } out_links { layer_name: "rnn_back@__recurrent_group_1__" link_name: "rnn_back" - has_subseq: false } - target_inlinkid: -1 } sub_models { name: "__recurrent_group_2__" @@ -666,19 +658,15 @@ sub_models { memories { layer_name: "rnn_subseq_forward@__recurrent_group_2__" link_name: "rnn_subseq_forward+delay1@__recurrent_group_2__" - is_sequence: false } in_links { layer_name: "sub_seq_input" link_name: "sub_seq_input@__recurrent_group_2__" - has_subseq: true } out_links { layer_name: "rnn_subseq_forward@__recurrent_group_2__" link_name: "rnn_subseq_forward" - has_subseq: true } - target_inlinkid: -1 } sub_models { name: "__lstm_group_0___recurrent_group" @@ -693,24 +681,19 @@ sub_models { memories { layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group" link_name: "__lstm_group_0__+delay1@__lstm_group_0___recurrent_group" - is_sequence: false } memories { layer_name: "__lstm_group_0___state@__lstm_group_0___recurrent_group" link_name: "__lstm_group_0___state+delay1@__lstm_group_0___recurrent_group" - is_sequence: false } in_links { layer_name: "__mixed_0__" link_name: "__mixed_0__@__lstm_group_0___recurrent_group" - has_subseq: false } out_links { layer_name: "__lstm_group_0__@__lstm_group_0___recurrent_group" link_name: "__lstm_group_0__" - has_subseq: false } - target_inlinkid: -1 } sub_models { name: "__gru_group_0___recurrent_group" @@ -722,19 +705,15 @@ sub_models { memories { layer_name: "__gru_group_0__@__gru_group_0___recurrent_group" link_name: "__gru_group_0__+delay1@__gru_group_0___recurrent_group" - is_sequence: false } in_links { layer_name: "__mixed_1__" link_name: "__mixed_1__@__gru_group_0___recurrent_group" - has_subseq: false } out_links { layer_name: "__gru_group_0__@__gru_group_0___recurrent_group" link_name: "__gru_group_0__" - has_subseq: false } - target_inlinkid: -1 } sub_models { name: "__recurrent_group_3__" @@ -746,18 +725,14 @@ sub_models { memories { layer_name: "__fc_layer_0__@__recurrent_group_3__" link_name: "__memory_6__@__recurrent_group_3__" - is_sequence: false } in_links { layer_name: "seq_input" link_name: 
"seq_input@__recurrent_group_3__" - has_subseq: false } out_links { layer_name: "__fc_layer_0__@__recurrent_group_3__" link_name: "__fc_layer_0__" - has_subseq: false } - target_inlinkid: -1 } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr index 91284b4fb32fcfdbf6b9e7384ffe080574b78821..9d1b41c9d5586235984771d610f5df40a8754522 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_seq_concat_reshape.protostr @@ -27,7 +27,7 @@ layers { name: "__seqreshape_0__" type: "seqreshape" size: 5 - active_type: "linear" + active_type: "" inputs { input_layer_name: "data1" } diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_sequence_pooling.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_sequence_pooling.protostr index 1999c006d237eb449d59c8e8a2a83c1e4fab9d0e..5a217f5544a8a3b4704b158dfeb92f747b7bd94b 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_sequence_pooling.protostr +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_sequence_pooling.protostr @@ -9,7 +9,7 @@ layers { name: "__seq_pooling_0__" type: "max" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } @@ -19,7 +19,7 @@ layers { name: "__seq_pooling_1__" type: "max" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } @@ -29,7 +29,7 @@ layers { name: "__seq_pooling_2__" type: "average" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } @@ -40,7 +40,7 @@ layers { name: "__seq_pooling_3__" type: "average" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } @@ -51,7 +51,7 @@ layers { name: "__seq_pooling_4__" type: "average" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } @@ -62,7 +62,7 @@ layers { name: "__seq_pooling_5__" type: "average" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } @@ -73,7 +73,7 @@ layers { name: "__seq_pooling_6__" type: "max" size: 100 - active_type: "linear" + active_type: "" inputs { input_layer_name: "dat_in" } diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_repeat_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_repeat_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..004e2a5dd4efa9feab7619643673b37fe28146c5 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_repeat_layer.py @@ -0,0 +1,11 @@ +from paddle.trainer_config_helpers import * + +settings(batch_size=1000, learning_rate=1e-5) + +din = data_layer(name='data', size=30) + +outputs( + repeat_layer( + input=din, num_repeats=10, as_row_vector=True), + repeat_layer( + input=din, num_repeats=10, act=TanhActivation(), as_row_vector=False)) diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 102331c0bb6477cbeb618f015aad76a0414723ba..6a1e23a343d6a8de9dbec573f257efb4fb658e92 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -26,7 +26,6 @@ import evaluator from . import dataset from . import reader from . import plot -from . 
import master import attr import op import pooling @@ -57,7 +56,6 @@ __all__ = [ 'plot', 'evaluator', 'image', - 'master', ] diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 9c614914b5e372e8e5e3c3c072b18b83edf51e87..e09ac1a7a0fe70dbf58a04f51cdf6916485e9be1 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -15,6 +15,7 @@ import requests import hashlib import os +import errno import shutil import sys import importlib @@ -27,7 +28,12 @@ __all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader'] DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') if not os.path.exists(DATA_HOME): - os.makedirs(DATA_HOME) + try: + os.makedirs(DATA_HOME) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass def md5file(fname): diff --git a/python/paddle/v2/layer.py b/python/paddle/v2/layer.py index aeed9ebd7d4d64efa5d0bf1638742a485c0fa44a..bbb9c3ea8c1b389f0ec9fd5ec7be52bd0449f52d 100644 --- a/python/paddle/v2/layer.py +++ b/python/paddle/v2/layer.py @@ -260,7 +260,7 @@ def parse_network(output_layers, extra_layers=None): else: extra_layers = [] - layer_names = __get_used_layers__(output_layers + extra_layers) + layer_names = __get_used_layers__(list(output_layers) + list(extra_layers)) submodel_names = __get_used_submodels__(layer_names) submodel_names.add('root') evaluator_names = __get_used_evaluators__(layer_names) diff --git a/python/setup.py.in b/python/setup.py.in index 8fe1cfd8b338b9b2e47edcec6d66bbcdd38b5198..2e22f640cb55677b6814b7f26a71457a96449de7 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -8,8 +8,7 @@ packages=['paddle', 'paddle.v2', 'paddle.v2.dataset', 'paddle.v2.reader', - 'paddle.v2.plot', - 'paddle.v2.master'] + 'paddle.v2.plot'] setup_requires=["requests", "numpy", @@ -25,7 +24,6 @@ setup(name='paddle', description='Parallel Distributed Deep Learning', install_requires=setup_requires, packages=packages, - package_data={'paddle.v2.master': ['libpaddle_master.so'], }, package_dir={ '': '${CMAKE_CURRENT_SOURCE_DIR}' },