diff --git a/CMakeLists.txt b/CMakeLists.txt index 6aa2e1715b92d73aa4e5e97d5e52ffac51451d80..7a7b5860a122a853fc9ce1da6494fc039b38bc10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_INFERENCE "Compile fluid inference library" ON) +option(ON_INFER "Turn on inference optimization." OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) @@ -179,6 +180,7 @@ include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 include(external/cares) include(external/cub) +include(external/xxhash) # download xxhash if (NOT WIN32) # there is no official support of snappystream, warpctc, nccl, cupti in windows @@ -301,3 +303,11 @@ if(WITH_DOC) find_python_module(recommonmark REQUIRED) add_subdirectory(doc) endif() + +if (ON_INFER) + message(STATUS "On inference mode, will take place some specific optimization.") + add_definitions(-DPADDLE_ON_INFERENCE) +else() + #TODO(luotao), combine this warning with `make inference_lib_dist` command. + message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.") +endif() diff --git a/README.md b/README.md index 8ee67f66423df8bce27f70015be8752457cd9784..56d6c10c642787836abb55cb2974bda0b8d22da4 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) -[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.0/getstarted/index_en.html) -[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) +[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) +[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) @@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. -### Latest PaddlePaddle Release: [Fluid 1.0.1](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0) +### Latest PaddlePaddle Release: [Fluid 1.1.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.1) ### Install Latest Stable Release: ``` # Linux CPU @@ -27,9 +27,9 @@ pip install paddlepaddle # Linux GPU cuda9cudnn7 pip install paddlepaddle-gpu # Linux GPU cuda8cudnn7 -pip install paddlepaddle-gpu==1.0.1.post87 +pip install paddlepaddle-gpu==1.1.0.post87 # Linux GPU cuda8cudnn5 -pip install paddlepaddle-gpu==1.0.1.post85 +pip install paddlepaddle-gpu==1.1.0.post85 # For installation on other platform, refer to http://paddlepaddle.org/ ``` @@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==1.0.1.post85 ## Installation -It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) on our website. +It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) on our website. ## Documentation -We provide [English](http://paddlepaddle.org/documentation/docs/en/1.0.0/getstarted/index_en.html) and -[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) documentation. +We provide [English](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) and +[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) documentation. - [Deep Learning 101](https://github.com/PaddlePaddle/book) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.0/user_guides/howto/training/cluster_howto.html) +- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.1/user_guides/howto/training/cluster_howto.html) You can run distributed training jobs on MPI clusters. -- [Python API](http://paddlepaddle.org/documentation/api/zh/1.0/fluid.html) +- [Python API](http://paddlepaddle.org/documentation/api/zh/1.1/fluid.html) Our new API enables much shorter programs. -- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.0/advanced_usage/development/contribute_to_paddle.html) +- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.1/advanced_usage/development/contribute_to_paddle.html) We appreciate your contributions! diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py index 9540900b112f54594bbfdbc8d7cd3b6e1f5269dd..ff616ddbb2cb1cb7f348d6d164815823b08b7629 100644 --- a/benchmark/fluid/args.py +++ b/benchmark/fluid/args.py @@ -142,5 +142,10 @@ def parse_args(): choices=['reduce', 'all_reduce'], default='all_reduce', help='Specify the reduce strategy, can be reduce, all_reduce') + parser.add_argument( + '--fuse_broadcast_op', + action='store_true', + help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.' + ) args = parser.parse_args() return args diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py index ddd9fe809853a830ca676cc98f1819f683866def..5f3ce300acc44ad8d2898c27296b866c403f3cc8 100644 --- a/benchmark/fluid/fluid_benchmark.py +++ b/benchmark/fluid/fluid_benchmark.py @@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, else: build_strategy.reduce_strategy = fluid.BuildStrategy( ).ReduceStrategy.AllReduce + build_strategy.fuse_broadcast_op = args.fuse_broadcast_op avg_loss = train_args[0] @@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, if args.use_fake_data or args.use_reader_op: try: - fetch_ret = exe.run(fetch_list) except fluid.core.EOFException as eof: break diff --git a/cmake/configure.cmake b/cmake/configure.cmake index e9852f00b1835adec31373f58ac538f9685251e0..7f5771e561f6cc419fc9b3094174645ac432546e 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -50,11 +50,7 @@ if(NOT WITH_PROFILER) endif(NOT WITH_PROFILER) if(NOT CMAKE_CROSSCOMPILING) - if(WITH_AVX AND AVX512F_FOUND) - set(SIMD_FLAG ${AVX512F_FLAG}) - elseif(WITH_AVX AND AVX2_FOUND) - set(SIMD_FLAG ${AVX2_FLAG}) - elseif(WITH_AVX AND AVX_FOUND) + if(WITH_AVX AND AVX_FOUND) set(SIMD_FLAG ${AVX_FLAG}) elseif(SSE3_FOUND) set(SIMD_FLAG ${SSE3_FLAG}) diff --git a/cmake/external/xxhash.cmake b/cmake/external/xxhash.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c227e09719bd5f0e825f81fb96f78105aa10c79b --- /dev/null +++ b/cmake/external/xxhash.cmake @@ -0,0 +1,50 @@ +INCLUDE(ExternalProject) + +set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash) +set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash) +set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include") + +IF(WITH_STATIC_LIB) + SET(BUILD_CMD make lib) +ELSE() + IF(APPLE) + SET(BUILD_CMD sed -i \"\" "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) + ELSE(APPLE) + SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) + ENDIF(APPLE) +ENDIF() + +ExternalProject_Add( + extern_xxhash + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" + GIT_TAG "v0.6.5" + PREFIX ${XXHASH_SOURCE_DIR} + DOWNLOAD_NAME "xxhash" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + PATCH_COMMAND + BUILD_COMMAND ${BUILD_CMD} + INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install + TEST_COMMAND "" +) + +set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a") +INCLUDE_DIRECTORIES(${XXHASH_INCLUDE_DIR}) + +add_library(xxhash STATIC IMPORTED GLOBAL) +set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES}) +include_directories(${XXHASH_INCLUDE_DIR}) +add_dependencies(xxhash extern_xxhash) + +LIST(APPEND external_project_dependencies xxhash) + +IF(WITH_C_API) + INSTALL(DIRECTORY ${XXHASH_INCLUDE_DIR} DESTINATION third_party/xxhash) + IF(ANDROID) + INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib/${ANDROID_ABI}) + ELSE() + INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib) + ENDIF() +ENDIF() diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 67cca09b64c1ed7a503a886e78347d786eae0de7..efdb093a7b28e19f3b2a774dd54f2e7f042e9ca7 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -31,7 +31,7 @@ function(copy TARGET) foreach(index RANGE ${len}) list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_DSTS ${index} dst) - add_custom_command(TARGET ${TARGET} PRE_BUILD + add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}" COMMAND cp -r "${src}" "${dst}" COMMENT "copying ${src} -> ${dst}") @@ -67,6 +67,13 @@ copy(boost_lib DEPS boost ) +set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/xxhash") +copy(xxhash_lib + SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES} + DSTS ${dst_dir} ${dst_dir}/lib + DEPS xxhash +) + if(NOT PROTOBUF_FOUND) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf") copy(protobuf_lib @@ -186,7 +193,7 @@ copy(cmake_cache DSTS ${FLUID_INSTALL_DIR}) # This command generates a complete fluid library for both train and inference -add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) +add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) # Following commands generate a inference-only fluid library # third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR} diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 6653244507742b33d9524a7a0e4a5b2b575d358a..6b665a9effba4bef083d007c0c74f2f4c79e647e 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -24,6 +24,7 @@ if(NOT WITH_FLUID_ONLY) endif() add_subdirectory(testing) +set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") if(NOT MOBILE_INFERENCE AND NOT RPI AND NOT WITH_C_API) add_subdirectory(fluid) endif() diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 587632ef435ed58dce1bfec141d7dd93e794810d..2b8b82e74fc49d454b5331460acbffd0e9404fb5 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -176,6 +176,7 @@ paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label' paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) +paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) @@ -354,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None) +paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None)) +paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)) paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)) diff --git a/paddle/fluid/framework/attribute.cc b/paddle/fluid/framework/attribute.cc index 0dcecb62dba971b48c4f11c0ef47494be40eeea0..fabf2abfc803b8838edb48aa01ab8896799c97ac 100644 --- a/paddle/fluid/framework/attribute.cc +++ b/paddle/fluid/framework/attribute.cc @@ -64,6 +64,13 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { case proto::AttrType::LONG: { return attr_desc.l(); } + case proto::AttrType::LONGS: { + std::vector val(attr_desc.longs_size()); + for (int i = 0; i < attr_desc.longs_size(); ++i) { + val[i] = attr_desc.longs(i); + } + return val; + } default: PADDLE_THROW("Unsupport attr type %d", attr_desc.type()); } diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 14ca3e96209ed17f12e87fda8506806514698977..d9c76881b7e98d0b7cd29024b98c8f7720398c66 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -26,6 +26,113 @@ limitations under the License. */ namespace paddle { namespace framework { + +template +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + T* operator()(Attribute& attr) const { + T* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", + attr_name_, paddle::platform::demangle(typeid(T).name()), + paddle::platform::demangle(attr.type().name())); + } + return attr_value; + } + + const std::string& attr_name_; +}; + +// special handle bool +// FIXME(yuyang18): Currently we cast bool into int in python binding. It is +// hard to change the logic there. In another way, we should correct handle +// if the user set `some_flag=1`. +// +// FIX ME anytime if there is a better solution. +template <> +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + bool* operator()(Attribute& attr) const { + if (attr.type() == typeid(int)) { // NOLINT + int val = boost::get(attr); + attr = static_cast(val); + } else if (attr.type() == typeid(float)) { // NOLINT + float val = boost::get(attr); + attr = static_cast(val); + } + bool* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", + attr_name_, paddle::platform::demangle(attr.type().name())); + } + return attr_value; + } + + const std::string& attr_name_; +}; + +template <> +struct ExtractAttribute { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + int64_t* operator()(Attribute& attr) const { + if (attr.type() == typeid(int)) { // NOLINT + int val = boost::get(attr); + attr = static_cast(val); + } else if (attr.type() == typeid(float)) { // NOLINT + int val = boost::get(attr); + attr = static_cast(val); + } + int64_t* attr_value = nullptr; + try { + attr_value = &boost::get(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", + attr_name_, paddle::platform::demangle(attr.type().name())); + } + return attr_value; + } + + const std::string& attr_name_; +}; + +template <> +struct ExtractAttribute> { + explicit ExtractAttribute(const std::string& attr_name) + : attr_name_(attr_name) {} + + std::vector* operator()(Attribute& attr) const { + if (attr.type() == typeid(std::vector)) { // NOLINT + std::vector val = boost::get>(attr); + std::vector vec(val.begin(), val.end()); + attr = vec; + } else if (attr.type() == typeid(std::vector)) { // NOLINT + std::vector val = boost::get>(attr); + std::vector vec(val.begin(), val.end()); + attr = vec; + } + std::vector* attr_value = nullptr; + try { + attr_value = &boost::get>(attr); + } catch (boost::bad_get& bad_get) { + PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", + attr_name_, paddle::platform::demangle(attr.type().name())); + } + return attr_value; + } + + const std::string& attr_name_; +}; + template inline proto::AttrType AttrTypeID() { Attribute tmp = T(); @@ -42,7 +149,11 @@ class AttrReader { inline const T& Get(const std::string& name) const { PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", name); - return boost::get(attrs_.at(name)); + + Attribute& attr = const_cast(attrs_.at(name)); + ExtractAttribute extract_attr(name); + T* attr_value = extract_attr(attr); + return *attr_value; } private: @@ -82,7 +193,7 @@ class DefaultValueSetter { public: explicit DefaultValueSetter(T default_value) : default_value_(default_value) {} - void operator()(T& value) const { value = default_value_; } + void operator()(T& value) const { value = default_value_; } // NOLINT private: T default_value_; @@ -117,84 +228,6 @@ class EnumInContainer { std::unordered_set container_; }; -template -struct ExtractAttribute { - explicit ExtractAttribute(const std::string& attr_name) - : attr_name_(attr_name) {} - - T* operator()(Attribute& attr) const { - T* attr_value = nullptr; - try { - attr_value = &boost::get(attr); - } catch (boost::bad_get& bad_get) { - PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s", - attr_name_, paddle::platform::demangle(typeid(T).name()), - paddle::platform::demangle(attr.type().name())); - } - return attr_value; - } - - const std::string& attr_name_; -}; - -// special handle bool -// FIXME(yuyang18): Currently we cast bool into int in python binding. It is -// hard to change the logic there. In another way, we should correct handle -// if the user set `some_flag=1`. -// -// FIX ME anytime if there is a better solution. -template <> -struct ExtractAttribute { - explicit ExtractAttribute(const std::string& attr_name) - : attr_name_(attr_name) {} - - bool* operator()(Attribute& attr) const { - if (attr.type() == typeid(int)) { // NOLINT - int val = boost::get(attr); - attr = static_cast(val); - } else if (attr.type() == typeid(float)) { // NOLINT - float val = boost::get(attr); - attr = static_cast(val); - } - bool* attr_value = nullptr; - try { - attr_value = &boost::get(attr); - } catch (boost::bad_get& bad_get) { - PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s", - attr_name_, paddle::platform::demangle(attr.type().name())); - } - return attr_value; - } - - const std::string& attr_name_; -}; - -template <> -struct ExtractAttribute { - explicit ExtractAttribute(const std::string& attr_name) - : attr_name_(attr_name) {} - - int64_t* operator()(Attribute& attr) const { - if (attr.type() == typeid(int)) { // NOLINT - int val = boost::get(attr); - attr = static_cast(val); - } else if (attr.type() == typeid(float)) { // NOLINT - int val = boost::get(attr); - attr = static_cast(val); - } - int64_t* attr_value = nullptr; - try { - attr_value = &boost::get(attr); - } catch (boost::bad_get& bad_get) { - PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s", - attr_name_, paddle::platform::demangle(attr.type().name())); - } - return attr_value; - } - - const std::string& attr_name_; -}; - // check whether a certain attribute fit its limits // an attribute can have more than one limits template @@ -235,7 +268,7 @@ class TypedAttrChecker { return *this; } - void operator()(AttributeMap& attr_map) const { + void operator()(AttributeMap& attr_map) const { // NOLINT if (!attr_map.count(attr_name_)) { // user do not set this attr PADDLE_ENFORCE(!default_value_setter_.empty(), @@ -271,7 +304,7 @@ class OpAttrChecker { return *(checker.target>()); } - void Check(AttributeMap& attr_map) const { + void Check(AttributeMap& attr_map) const { // NOLINT for (const auto& checker : attr_checkers_) { checker(attr_map); } diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index e0a3ef5a9c6c53c42ebea1a41cac0d18a77781b2..17188ac5f301102ae79c6ace676b84ee66e28801 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -16,12 +16,14 @@ if(WITH_GPU) dynload_cuda variable_visitor) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) + nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) + cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) endif() cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) @@ -34,7 +36,7 @@ if(WITH_GPU) endif() cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) if(WITH_GPU) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) @@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo cc_library(build_strategy SRCS build_strategy.cc DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass - fuse_elewise_add_act_pass) + fuse_elewise_add_act_pass multi_batch_merge_pass) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 4fdab5cd94358d08eac7f8b041bf16d09042f0bd..7f0d06c892541a2697a4ed083f6f4c0fc774a2a4 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -48,16 +48,27 @@ void BroadcastOpHandle::RunImpl() { var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); } + BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes); +} + +void BroadcastOpHandle::BroadcastOneVar( + const VarHandle &in_var_handle, + const std::vector &out_var_handles, + const std::vector &var_scopes) { auto *in_var = - var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); + var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); PADDLE_ENFORCE_NOT_NULL(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); + if (UNLIKELY(!in_tensor.IsInitialized())) { + VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!"; + return; + } - InitOutputValue(*in_var_handle, out_var_handles); + InitOutputValue(in_var_handle, out_var_handles); if (platform::is_cpu_place(in_tensor.place())) { for (auto *out_var_handle : out_var_handles) { - if (out_var_handle->IsTheSameVar(*in_var_handle)) { + if (out_var_handle->IsTheSameVar(in_var_handle)) { continue; } auto &out_p = out_var_handle->place_; @@ -114,12 +125,12 @@ void BroadcastOpHandle::RunImpl() { } } - if (!out_handle->IsTheSameVar(*in_var_handle)) { - auto out_var = var_scopes.at(in_var_handle->scope_idx_) + if (!out_handle->IsTheSameVar(in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle.scope_idx_) ->FindVar(out_var_handles[0]->name_); paddle::framework::TensorCopy( - in_tensor, in_var_handle->place_, - *(dev_ctxes_.at(in_var_handle->place_)), + in_tensor, in_var_handle.place_, + *(dev_ctxes_.at(in_var_handle.place_)), &VariableVisitor::GetMutableTensor(out_var)); } }); diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index fe4e733e43417977df324fde808f52b228a27d19..020d351e891c7afab37c59c0ff8d8e5e7ba184f2 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - private: + void BroadcastOneVar(const VarHandle &in_var_handle, + const std::vector &out_var_handles, + const std::vector &var_scopes); + std::vector local_scopes_; std::vector places_; #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 6a6b497fa897e3882995688bf36704b1d77ea962..fefd27fc86fb8dce3311fa580d90f518906dd862 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -121,6 +121,7 @@ std::unique_ptr BuildStrategy::Apply( USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); +USE_PASS(multi_batch_merge_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 02c4bea16916d58a6d0fce8918f8fceb9ff9356e..f3ffaf6ecd7c4dd99c40fe58ba88c0cbdc14bde7 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -69,6 +69,8 @@ struct BuildStrategy { bool enable_data_balance_{false}; + bool fuse_broadcast_op_{false}; + // User normally doesn't need to call this API. // The PassBuilder allows for more customized insert, remove of passes // from python side. diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..51dfa2d0711f49aaefab0af3549283dbf77eee4a --- /dev/null +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace framework { +namespace details { + +void FusedBroadcastOpHandle::RunImpl() { + platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); + + if (places_.size() == 1UL) return; + + auto in_var_handles = DynamicCast(inputs_); + auto out_var_handles = DynamicCast(outputs_); + + WaitInputVarGenerated(); + + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + + size_t place_num = places_.size(); + PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size()); + + for (size_t i = 0; i < in_var_handles.size(); ++i) { + BroadcastOneVar( + *in_var_handles[i], + std::vector(out_var_handles.begin() + i * place_num, + out_var_handles.begin() + (i + 1) * place_num), + var_scopes); + } +} + +std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; } + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.h b/paddle/fluid/framework/details/fused_broadcast_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..e37259526a5f6f57d51a0ca8bca96a18211a4790 --- /dev/null +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.h @@ -0,0 +1,57 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/details/broadcast_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/device_context.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +struct FusedBroadcastOpHandle : public BroadcastOpHandle { + public: +#ifdef PADDLE_WITH_CUDA + FusedBroadcastOpHandle(ir::Node *node, + const std::vector local_scopes, + const std::vector &places, + const platform::NCCLContextMap *nccl_ctx) + : BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {} +#else + FusedBroadcastOpHandle(ir::Node* node, const std::vector local_scopes, + const std::vector& places) + : BroadcastOpHandle(node, local_scopes, places) {} +#endif + std::string Name() const override; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index ebd1d644bcce0554904ee0931827ef43a6d52b11..f3819887a196a7c8bf35897467bb9d68b428094e 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" +#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" @@ -347,7 +348,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( BuildStrategy::GradientScaleStrategy::kCustomized) { // TODO(paddle-dev): Why is there no input for this op_handle? auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; - CreateScaleLossGradOp(&result, loss_grad_name); + CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]); } // This assumes the backward generating code will ensure IsScaleLossOp // is true only for the op that scale the final scalar loss. @@ -436,10 +437,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( if ((use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) || is_dist_train) { - for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set[dev_id]; - for (auto &bcast_name : to_bcast_set) { - CreateBroadcastOp(&result, bcast_name, dev_id); + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(&result, bcast_var_name_set); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(&result, bcast_name, dev_id); + } } } } @@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, } } +void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( + ir::Graph *result, + const std::vector> &bcast_varnames) const { +#ifdef PADDLE_WITH_CUDA + auto *op_handle = new FusedBroadcastOpHandle( + result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_); +#else + auto *op_handle = new FusedBroadcastOpHandle( + result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), + local_scopes_, places_); +#endif + result->Get(kGraphOps).emplace_back(op_handle); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + SetCommunicationContext(op_handle, p); + } + + for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) { + for (auto &p_name : bcast_varnames[dev_id]) { + auto *in = + result->Get(kGraphVars).at(dev_id).at(p_name).back().get(); + op_handle->AddInput(in); + for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) { + auto &p = places_[out_dev_id]; + auto &vars = + result->Get(kGraphVars).at(out_dev_id).at(p_name); + auto *out_var = new VarHandle( + result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), + vars.size(), out_dev_id, p_name, p); + vars.emplace_back(out_var); + op_handle->AddOutput(out_var); + } + } + } +} + void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const { @@ -602,7 +645,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( - ir::Graph *result, const std::string &loss_grad_name) const { + ir::Graph *result, const std::string &loss_grad_name, + ir::Node *out_var_node) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); @@ -617,10 +661,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - CreateOpOutput( - result, op_handle, - result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable), - places_[i], i); + CreateOpOutput(result, op_handle, + result->CreateVarNode(out_var_node->Var()), places_[i], i); } } @@ -680,7 +722,8 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, } if (node->Op()->Type() == "split_byref" || - node->Op()->Type() == "split_selected_rows") { + node->Op()->Type() == "split_selected_rows" || + node->Op()->Type() == "split_ids") { // TODO(paddle-dev): getting the first var is not safe. op_dev_id = GetVarDeviceID(*result, input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index cdf9f13cde608b546d17a1e53e0f6acea9e12566..03b2de2f04da4bac8d342a76c80fd12beaeba4b7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -61,7 +61,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t num_places) const; void CreateScaleLossGradOp(ir::Graph *result, - const std::string &loss_grad_name) const; + const std::string &loss_grad_name, + ir::Node *out_var_node) const; VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const; @@ -78,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass { void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const; + void CreateFusedBroadcastOp( + ir::Graph *result, + const std::vector> &bcast_varnames) const; + bool IsSparseGradient(const std::string &og) const; size_t GetAppropriateDeviceID( diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index c99406799ba5f664c4b9f80e0567b293e4ffea51..efdabffb9b33ddf007c13008d0f3afb7a3961eda 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -35,6 +35,7 @@ enum AttrType { BLOCK = 8; LONG = 9; BLOCKS = 10; + LONGS = 11; } // OpDesc describes an instance of a C++ framework::OperatorBase @@ -55,6 +56,7 @@ message OpDesc { optional int32 block_idx = 12; optional int64 l = 13; repeated int32 blocks_idx = 14; + repeated int64 longs = 15; }; message Var { diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a145b2fafe64f8c80ac7808583d6670ca0218c06..ce006b7a3fbc16f3c9149933390969b14a46b484 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -36,6 +36,7 @@ pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) +pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) if(WITH_MKLDNN) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 398f7095968e62f92d610f560d7574b27706d13e..87926156e410b46ffccb90e1ab0ec553af289dd3 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -89,14 +89,20 @@ bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, Graph::Graph(const ProgramDesc &program) : program_(program) { // Make the nodes id start from 0. Node::ResetId(); + auto var_nodes = InitFromProgram(program_); + ResolveHazard(var_nodes); +} +std::map> Graph::InitFromProgram( + const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); std::unordered_map all_vars; + // var nodes for each var name, will have multiple versions in SSA + std::map> var_nodes; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } - std::map> var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); // For input args, reuse the same var name if it was created before. @@ -134,7 +140,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var->inputs.push_back(node); } } + return std::move(var_nodes); +} +void Graph::ResolveHazard( + const std::map> &var_nodes) { /** * We should handle write after read(WAR) and write after write(WAW) here. * Because some of the operators of the program can be executed parallelly. @@ -153,6 +163,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { auto it_old = versions.rbegin(); ++it_old; for (; it_old != versions.rend(); it_new = it_old, ++it_old) { + VLOG(3) << "deal with var: " << (*it_new)->Name(); ir::Node *write_op = (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0]; const auto &read_ops = (*it_old)->outputs; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ab687e760a761d4e445726bd5149966adc2403d0..9d7aa5d32deb274fbf29481b0d4754c05d1e21b5 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -160,6 +160,12 @@ class Graph { return nullptr; } + std::map> InitFromProgram( + const ProgramDesc &program); + + void ResolveHazard( + const std::map> &var_nodes); + private: // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index c54766d95a61ac1a4b61566c6de62cbc86685a1d..01e878089171e4620f32b57a65d92d1c86d307db 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) { std::deque q_nodes; std::vector> graph_nodes; std::unordered_set g_nodes; + // q_set used to record records in the queue. + std::unordered_set q_set; size_t graph_count = 0; - auto traverse_nodes = [&visited_nodes, - &q_nodes](const std::vector &nodes) { - std::copy_if( - nodes.begin(), nodes.end(), std::back_inserter(q_nodes), - [&visited_nodes](Node *node) { return !visited_nodes.count(node); }); + auto traverse_nodes = [&visited_nodes, &q_nodes, + &q_set](const std::vector &nodes) { + for (auto n : nodes) { + if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) { + q_nodes.push_back(n); + q_set.insert(n); + } + } }; while (visited_nodes.size() != nodes.size()) { if (!q_nodes.empty()) { auto cur_node = q_nodes.front(); q_nodes.pop_front(); + q_set.erase(cur_node); visited_nodes.insert(cur_node); g_nodes.insert(cur_node); traverse_nodes(cur_node->inputs); @@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) { for (auto &n : nodes) { if (visited_nodes.count(n) == 0) { q_nodes.push_back(n); + q_set.insert(n); break; } } diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd5b76426eb55cebdabfccd700439a4c418a10f0 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -0,0 +1,315 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/multi_batch_merge_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +static const char kNumRepeats[] = "num_repeats"; +typedef std::unordered_map> SSAVarList; + +ir::Node* SameNameVar(std::unordered_set all, ir::Node* target) { + for (auto n : all) { + if (target->IsVar() && target->Name() == n->Name()) { + return n; + } + } + return nullptr; +} + +VarDesc CopyVarDesc(VarDesc* var_desc) { + VarDesc repeated_var(var_desc->Name()); + // copy other variable attributes + if (var_desc->GetType() != proto::VarType::READER) { + repeated_var.SetType(var_desc->GetType()); + repeated_var.SetShape(var_desc->GetShape()); + repeated_var.SetDataType(var_desc->GetDataType()); + repeated_var.SetLoDLevel(var_desc->GetLoDLevel()); + repeated_var.SetPersistable(var_desc->Persistable()); + } else { + // TODO(typhoonzero): copy reader var + } + return repeated_var; +} + +VarDesc UpdateGradVarDesc( + VarDesc* var_desc, int repeat, + const std::unordered_set& grad_names, + const std::unordered_set& bn_vars_need_rename) { + if (grad_names.find(var_desc->Name()) != grad_names.end() || + bn_vars_need_rename.find(var_desc->Name()) != bn_vars_need_rename.end()) { + std::string new_gname = + string::Sprintf("%s.repeat.%d", var_desc->Name(), repeat); + VarDesc repeated_var = CopyVarDesc(var_desc); + repeated_var.SetName(new_gname); + VLOG(3) << "update " << var_desc->Name() << " to repeat " << repeat; + return repeated_var; + } + return *var_desc; +} + +std::unique_ptr BatchMergePass::ApplyImpl( + std::unique_ptr graph) const { + int num_repeats = Get(kNumRepeats); + std::vector forward_backward_ops; + std::vector optimize_ops; + std::vector lr_ops; // ops other than forward/backward/optimize + std::unordered_set grad_names; + + std::vector nodes = TopologySortOperations(*graph); + auto origin_nodes = graph->ReleaseNodes(); + VLOG(3) << "origin nodes count: " << origin_nodes.size(); + ir::Graph& result = *graph; + + // 1. record op nodes of different roles + for (auto node : nodes) { + if (node->IsVar()) continue; + int op_role = boost::get(node->Op()->GetAttr( + framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + if ((op_role == static_cast(framework::OpRole::kForward)) || + (op_role & static_cast(framework::OpRole::kBackward)) || + (op_role & static_cast(framework::OpRole::kLoss))) { + forward_backward_ops.push_back(node); + } else if ((op_role & static_cast(framework::OpRole::kOptimize)) || + (op_role & static_cast(framework::OpRole::kDist)) || + (op_role & static_cast(framework::OpRole::kRPC))) { + optimize_ops.push_back(node); + auto op_role_var = node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName()); + auto op_role_vars = boost::get>(op_role_var); + for (size_t i = 0; i < op_role_vars.size(); i += 2) { + grad_names.insert(op_role_vars[i + 1]); + } + } else if (op_role & static_cast(framework::OpRole::kLRSched)) { + lr_ops.push_back(node); + } else { // NOLINT + PADDLE_THROW("Invalid op_role: %d", static_cast(op_role)); + } + } + + // 2. copy forward backward + ir::Node* prev_repeat_last_op_node = nullptr; + // record origin_grad -> repeated grad list map. + std::map> grad_repeated_map; + std::map> created; + std::unordered_set bn_vars_need_rename; + for (int i = 0; i < num_repeats; ++i) { + std::unordered_set copied; + for (size_t node_idx = 0; node_idx < forward_backward_ops.size(); + ++node_idx) { + auto node = forward_backward_ops[node_idx]; + OpDesc repeated_op(*(node->Op()), node->Op()->Block()); + // 3. rename grad outputs to current repeat. + for (auto outname : repeated_op.OutputArgumentNames()) { + if (grad_names.find(outname) != grad_names.end()) { + std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i); + repeated_op.RenameOutput(outname, new_gname); + } + } + // 3.5 let batch_norm ops use independent vars, note batch_norm_grad do + // not need this update + if (node->Name() == "batch_norm") { + // NOTE: assume bn op created by layers use save var as output mean and + // variance + std::string new_mean_name = + string::Sprintf("%s.repeat.%d", repeated_op.Input("Mean")[0], i); + std::string new_var_name = string::Sprintf( + "%s.repeat.%d", repeated_op.Input("Variance")[0], i); + bn_vars_need_rename.insert(repeated_op.Input("Mean")[0]); + bn_vars_need_rename.insert(repeated_op.Input("Variance")[0]); + VLOG(3) << "renaming " << repeated_op.Input("Mean")[0] << " to " + << new_mean_name; + repeated_op.RenameInput(repeated_op.Input("Mean")[0], new_mean_name); + repeated_op.RenameInput(repeated_op.Input("Variance")[0], new_var_name); + repeated_op.RenameOutput(repeated_op.Output("MeanOut")[0], + new_mean_name); + repeated_op.RenameOutput(repeated_op.Output("VarianceOut")[0], + new_var_name); + } + + // 3.9 do copy + auto repeated_node = result.CreateOpNode(&repeated_op); + copied.insert(node); + + // 4. add deps between repeats + if (node_idx == forward_backward_ops.size() - 1) { + prev_repeat_last_op_node = repeated_node; + } + if (node_idx == 0 && prev_repeat_last_op_node) { + auto* depvar = result.CreateControlDepVar(); + prev_repeat_last_op_node->outputs.push_back(depvar); + depvar->inputs.push_back(prev_repeat_last_op_node); + repeated_node->inputs.push_back(depvar); + depvar->outputs.push_back(repeated_node); + } + + for (auto in_node : node->inputs) { + if (in_node->IsCtrlVar()) { + continue; + } + ir::Node* var = nullptr; + auto updated_var = UpdateGradVarDesc(in_node->Var(), i, grad_names, + bn_vars_need_rename); + // should be initialized by startup, how to initilize tensor in the + // scope? + if (node->Name() == "batch_norm" && + bn_vars_need_rename.find(in_node->Name()) != + bn_vars_need_rename.end()) { + // Create bn mean/variance for each repeat + var = result.CreateVarNode(&updated_var); + created[updated_var.Name()].push_back(var); + copied.insert(in_node); + repeated_node->inputs.push_back(var); + var->outputs.push_back(repeated_node); + continue; + } + + // for other ops + if (in_node->inputs.empty() && i > 0) { + // do not copy head vars (inputs, params) in repeats > 0 + var = created.at(in_node->Name()).back(); + } else { + if (copied.find(in_node) == copied.end()) { + var = result.CreateVarNode(&updated_var); + if (grad_names.find(in_node->Var()->Name()) != grad_names.end()) { + grad_repeated_map[in_node].push_back(var); + } + copied.insert(in_node); + created[updated_var.Name()].push_back(var); + } else { + var = created.at(updated_var.Name()).back(); + } + } + repeated_node->inputs.push_back(var); + var->outputs.push_back(repeated_node); + } + for (auto out_node : node->outputs) { + if (out_node->IsCtrlVar()) { + continue; + } + ir::Node* var = nullptr; + auto updated_var = UpdateGradVarDesc(out_node->Var(), i, grad_names, + bn_vars_need_rename); + if (copied.find(out_node) == copied.end()) { + var = result.CreateVarNode(&updated_var); + if (grad_names.find(out_node->Var()->Name()) != grad_names.end()) { + grad_repeated_map[out_node].push_back(var); + } + copied.insert(out_node); + created[updated_var.Name()].push_back(var); + } else { + var = created.at(updated_var.Name()).back(); + } + repeated_node->outputs.push_back(var); + var->inputs.push_back(repeated_node); + } + } + } + + // 5. create GRAD merge op node + for (auto kv : grad_repeated_map) { + OpDesc sum_op; + sum_op.SetType("sum"); + std::vector repeated_grad_names; + for (auto r : kv.second) { + repeated_grad_names.push_back(r->Var()->Name()); + } + sum_op.SetInput("X", repeated_grad_names); + sum_op.SetOutput("Out", {kv.first->Var()->Name()}); + sum_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kBackward)); + auto sum_op_node = result.CreateOpNode(&sum_op); + for (auto r : kv.second) { + sum_op_node->inputs.push_back(r); + r->outputs.push_back(sum_op_node); + } + auto sum_out_var_node = result.CreateVarNode(kv.first->Var()); + sum_op_node->outputs.push_back(sum_out_var_node); + sum_out_var_node->inputs.push_back(sum_op_node); + created[sum_out_var_node->Name()].push_back(sum_out_var_node); + + OpDesc scale_op; + scale_op.SetType("scale"); + scale_op.SetInput("X", {sum_out_var_node->Var()->Name()}); + // NOTE: inplace scale. + scale_op.SetOutput("Out", {sum_out_var_node->Var()->Name()}); + scale_op.SetAttr("scale", static_cast(1.0f / num_repeats)); + scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kBackward)); + auto scale_op_node = result.CreateOpNode(&scale_op); + scale_op_node->inputs.push_back(sum_out_var_node); + sum_out_var_node->outputs.push_back(scale_op_node); + auto scale_out_var_node = result.CreateVarNode(sum_out_var_node->Var()); + scale_op_node->outputs.push_back(scale_out_var_node); + scale_out_var_node->inputs.push_back(scale_op_node); + created[scale_out_var_node->Name()].push_back(scale_out_var_node); + } + // 6. add optimize ops + { + auto copy_node = [&result, &created](ir::Node* node) { + auto op_node = result.CreateOpNode(node->Op()); + // copy op ins/outs + // NOTE: for send/recv ops, the OpDesc uses ctrldepvar to describe + // dependencies, so create those depvars if OpDesc have in/outs. + for (auto in_node : node->inputs) { + if (in_node->IsCtrlVar() && !in_node->Var()) { + continue; + } + ir::Node* var = nullptr; + if (created.find(in_node->Name()) == created.end()) { + var = result.CreateVarNode(in_node->Var()); + created[in_node->Name()].push_back(var); + } else { + var = created.at(in_node->Name()).back(); + } + op_node->inputs.push_back(var); + var->outputs.push_back(op_node); + } + for (auto out_node : node->outputs) { + if (out_node->IsCtrlVar() && !out_node->Var()) { + continue; + } + auto var = result.CreateVarNode(out_node->Var()); + created[out_node->Name()].push_back(var); + op_node->outputs.push_back(var); + var->inputs.push_back(op_node); + } + }; + for (auto node : lr_ops) { + copy_node(node); + } + for (auto node : optimize_ops) { + copy_node(node); + } + } + + result.ResolveHazard(created); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(multi_batch_merge_pass, paddle::framework::ir::BatchMergePass) + .RequirePassAttr(paddle::framework::ir::kNumRepeats); diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.h b/paddle/fluid/framework/ir/multi_batch_merge_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..c1e5aef20dbc60c18ed03038818bfd8ab217bf28 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +// BatchMergePass is used to copy forward and backward ops for several +// times to run several batches to simulate large batch size training +// as if we have more than 1 GPUs. +// User can define how many batches to run, gradients will be merged +// through those repeats, and then do optimization using merged gradients. +// This pass is extremely useful when doing large batch-size distributed +// sync training, we can simulate even large batch size as if we have more +// GPUs. + +class BatchMergePass : public Pass { + public: + virtual ~BatchMergePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 1e7da9a69c7cbf8c13306656599a759515802b76..669d08c70c9b7453264806b346a6c9eb211cfd4a 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -418,7 +418,7 @@ void LoDTensor::MergeLoDTensor( PADDLE_ENFORCE_EQ(new_lod.size(), lod.size()); for (size_t j = 0; j < lod.size(); ++j) { auto &sub_lod = new_lod[j]; - auto &offset = sub_lod.back(); + size_t offset = sub_lod.back(); for (size_t k = 1; k < lod[j].size(); ++k) { sub_lod.push_back(lod[j][k] + offset); } diff --git a/paddle/fluid/framework/lod_tensor_array.h b/paddle/fluid/framework/lod_tensor_array.h index 6d7b6a4ada8729e3698dab5d2b1861aac632be79..36a5c3c5d601390beedaf37ceb98ee2c63ecf5a6 100644 --- a/paddle/fluid/framework/lod_tensor_array.h +++ b/paddle/fluid/framework/lod_tensor_array.h @@ -18,6 +18,8 @@ limitations under the License. */ namespace paddle { namespace framework { + using LoDTensorArray = std::vector; -} + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index c293cf92b4f3d530109c76850df184af9cad7399..8ece618f3f72552fedcffab3e03ebb30476b7cab 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -419,8 +419,15 @@ struct SetAttrDescVisitor : public boost::static_visitor { } VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx()); } + void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); } + void operator()(int64_t v) const { attr_->set_l(v); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_longs()); + } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } }; diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 5527783faad6f7ea6e0b9e776cbde3c743277f16..4c59c73d8779eceb267ad532aabccabbd54b0df2 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -28,12 +28,12 @@ enum class OpRole { kBackward = 0x0001, kOptimize = 0x0002, // RPC role is for send/recv releated op - kRPC = 0x0003, + kRPC = 0x0004, // Dist role is for split_byref/split_selected_rows/concat // used for distributed training. - kDist = 0x0004, + kDist = 0x0008, // Tag all learning rate scheduler operators. - kLRSched = 0x0005, + kLRSched = 0x0010, kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 14fcde2fe3b1c3acfc0994e9cd37a784c57826d7..d8251c7255320f89bca15f860b1b7559aa54ffd6 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -358,11 +358,11 @@ static bool VarIsTensor(const Variable* var) { return var->IsType() || var->IsType(); } -static const Tensor* GetTensorFromVar(Variable* var) { +const Tensor* GetTensorFromVar(const Variable* var) { if (var->IsType()) { - return var->GetMutable(); + return static_cast(&(var->Get())); } else if (var->IsType()) { - return var->GetMutable()->mutable_value(); + return &(var->Get().value()); } else { PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", var->Type().name()); @@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const { template <> const Tensor* ExecutionContext::Input(const std::string& name) const { auto* var = InputVar(name); - return var == nullptr ? nullptr - : GetTensorFromVar(const_cast(var)); + return var == nullptr ? nullptr : GetTensorFromVar(var); } template <> diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 626b50edfd39424473be33e9f8baec5970471477..be3f06360d66453ae18fdc9abbf6ea4b29491248 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -63,6 +63,7 @@ inline std::string GradVarName(const std::string& var_name) { } proto::VarType::Type GetDataTypeOfVar(const Variable* var); +const Tensor* GetTensorFromVar(const Variable* var); class OperatorBase; class ExecutionContext; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 7dad872dd04754d2866c361b74ab1236ff743da5..662a29d41937521810473cf29fe9494b1b4a9d9e 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -109,18 +109,9 @@ ParallelExecutor::ParallelExecutor( if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { BCastParamsToDevices(bcast_vars); } - // Startup Program has been run. All local scopes has correct parameters. +// Startup Program has been run. All local scopes has correct parameters. - // Step 2. Create vars in each scope; - std::vector var_infos; - for (auto *var : main_program.Block(0).AllVars()) { - var_infos.emplace_back(); - var_infos.back().name_ = var->Name(); - var_infos.back().type_ = var->GetType(); - var_infos.back().persistable_ = var->Persistable(); - } - -// Step 3. Convert main_program to SSA form and dependency graph. Also, insert +// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp #ifdef PADDLE_WITH_CUDA std::unique_ptr graph = build_strategy.Apply( @@ -156,6 +147,26 @@ ParallelExecutor::ParallelExecutor( params, member_->local_scopes_, member_->use_cuda_); #endif + // Step 3. Create vars in each scope. Passes may also create new vars. + // skip control vars and empty vars + std::vector var_infos; + for (auto &node : graph->Nodes()) { + if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { + var_infos.emplace_back(); + var_infos.back().name_ = node->Var()->Name(); + var_infos.back().type_ = node->Var()->GetType(); + var_infos.back().persistable_ = node->Var()->Persistable(); + } + } + + if (VLOG_IS_ON(5)) { + // If the loss_var_name is given, the number of graph should be only one. + if (loss_var_name.size()) { + PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, + "The number of graph should be only one"); + } + } + if (exec_strategy.type_ == ExecutionStrategy::kDefault) { member_->executor_.reset(new details::ThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, places, std::move(graph))); @@ -179,6 +190,10 @@ void ParallelExecutor::BCastParamsToDevices( } auto &main_tensor = main_var->Get(); + if (!main_tensor.IsInitialized()) { + VLOG(3) << "one in var not inited, return!"; + continue; + } auto &dims = main_tensor.dims(); if (paddle::platform::is_gpu_place(main_tensor.place())) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 14f9f36812d690fc4a7440f2e7e6a85e9993a535..9462620e829ec815e1553f6378a67463ea3b8aa3 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -78,6 +78,8 @@ class Scope { /// Drop all kids scopes belonged to this scope. void DropKids(); + std::list& kids() const { return kids_; } + /// Find if a scope exists in the kid scopes bool HasKid(const Scope* scope) const; diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index e099e40f121ff13657e563eb608feecbca0551be..2de6233a9e0d320ec9a06d547db3575eb61925c0 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -36,7 +36,7 @@ using Attribute = boost::variant, std::vector, std::vector, bool, std::vector, BlockDesc*, int64_t, - std::vector>; + std::vector, std::vector>; using AttributeMap = std::unordered_map; diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 9794a193bcfaae19552b1f6fbdf2dab2898033d5..dbbe8bcba69a1d87e21c8eae18834fb708e8b1e4 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -30,7 +30,7 @@ if (WITH_GPU AND TENSORRT_FOUND) endif() # Create static library -cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor) +cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array) if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. @@ -40,7 +40,7 @@ endif() # Create shared library cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} - DEPS ${fluid_modules} paddle_fluid_api) + DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) if(NOT APPLE) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 2e79d495d5ff00000000029ac0f6eb486aaea94a..ef4142f334e503380dc7ccd74c348404ffe52ee6 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) { passes.push_back("mkldnn_placement_pass"); } #endif + // infer_clean_graph_pass should be the first default pass + // after mkldnn_placement_pass. + passes.push_back("infer_clean_graph_pass"); for (auto& pass : ir_passes_) { if (!disabled_ir_passes_.count(pass)) { passes.push_back(pass); diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index c51a4fdb2f6b27e54637481c23bf6f1f6ec97718..7114f5222c5904bb4422b9e67ad035b85bbb770c 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // "attention_lstm_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // "embedding_fc_lstm_fuse_pass", // diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 0ddd5d53f836131fe37d412fc867cb38f11ee2b5..a55426f74f988176aeb180e48d1af8632ed3b5c7 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -18,7 +18,8 @@ if(APPLE) endif(APPLE) -set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB}) +set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB} + ) if(WITH_GPU AND TENSORRT_FOUND) set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine analysis_predictor) @@ -31,10 +32,17 @@ function(inference_api_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cc_test(${TARGET_NAME} - SRCS ${inference_test_SRC} - DEPS "${inference_deps}" - ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + if (WITH_GPU) + cc_test(${TARGET_NAME} + SRCS ${inference_test_SRC} + DEPS "${inference_deps}" + ARGS --dirname=${PYTHON_TESTS_DIR}/book/ --fraction_of_gpu_memory_to_use=0.15) + else() + cc_test(${TARGET_NAME} + SRCS ${inference_test_SRC} + DEPS "${inference_deps}" + ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + endif() if(inference_test_ARGS) set_tests_properties(${TARGET_NAME} PROPERTIES DEPENDS "${inference_test_ARGS}") @@ -42,7 +50,8 @@ function(inference_api_test TARGET_NAME) endif(WITH_TESTING) endfunction(inference_api_test) -cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope) +cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) +cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api) @@ -52,8 +61,6 @@ cc_test(test_paddle_inference_api inference_api_test(test_api_impl SRC api_impl_tester.cc ARGS test_word2vec test_image_classification) - -set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) cc_test(test_analysis_predictor SRCS analysis_predictor_tester.cc DEPS analysis_predictor ${inference_deps} paddle_inference_api ARGS --dirname=${PYTHON_TESTS_DIR}/book) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index eec665767164dc6e79738890947c54d7f7217037..54c37fe64590aa82d7100c93c4c5c4ee36491421 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -82,6 +82,7 @@ bool AnalysisPredictor::Init( // Get the feed_target_names and fetch_target_names PrepareFeedFetch(); + return true; } @@ -109,6 +110,10 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } VLOG(3) << "predict cost: " << timer.toc() << "ms"; + + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } @@ -322,6 +327,9 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( bool AnalysisPredictor::ZeroCopyRun() { executor_->Run(); + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 5a9f4d36959d4ee7ca16dec769d9d1283b8787cb..b7dc2067332278c1c38df4beefb5059efe76417f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/string/printf.h" @@ -88,6 +89,7 @@ class AnalysisPredictor : public PaddlePredictor { // Memory buffer for feed inputs. The temporary LoDTensor will cause serious // concurrency problems, so cache them. std::vector feed_tensors_; + details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 7cda9c5d8a8366bd097491f37f5352a10e4fb16c..d06ab8f8c8e3c0adf4a4074eb1450012126e83ea 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -157,6 +158,10 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, return false; } VLOG(3) << "predict cost: " << timer.toc() << "ms"; + + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h index 7882f6a53c7ce9a2486158ea9b50c018d1814091..4e4ab47ca9c5e37f2714ebd48d250c23c7e9b117 100644 --- a/paddle/fluid/inference/api/api_impl.h +++ b/paddle/fluid/inference/api/api_impl.h @@ -26,11 +26,11 @@ limitations under the License. */ #include #include -#include "paddle/fluid/inference/api/paddle_inference_api.h" - #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/init.h" @@ -77,6 +77,7 @@ class NativePaddlePredictor : public PaddlePredictor { std::vector fetchs_; // Do not use unique_ptr, use parent scope to delete framework::Scope *sub_scope_{nullptr}; + details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 03f0f726eb61c2619c7719a865383090f86b5b7f..49683eab07a2f5bc008272038a27bdb277396284 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -52,6 +52,7 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") if (NOT WIN32) include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") @@ -61,8 +62,8 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") -if (NOT WIN32) - if (USE_TENSORRT AND WITH_GPU) +if (NOT WIN32) + if (USE_TENSORRT AND WITH_GPU) include_directories("${TENSORRT_INCLUDE_DIR}") link_directories("${TENSORRT_LIB_DIR}") endif() @@ -77,13 +78,14 @@ endif(NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/paddle/lib") add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) include_directories("${PADDLE_LIB}/third_party/install/mklml/include") - set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") if(EXISTS ${MKLDNN_PATH}) @@ -107,7 +109,7 @@ if (NOT WIN32) set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf snappystream snappy z + glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) else() set(DEPS ${DEPS} @@ -120,7 +122,7 @@ endif(NOT WIN32) if(WITH_GPU) if(NOT WIN32) - if (USE_TENSORRT) + if (USE_TENSORRT) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 6e682b69583e00ab1bbe1c0d22e21ae114a61a76..1ac655bdbbf7c45bfdde2c5fa8026fab2c891903 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -16,7 +16,7 @@ if [ $2 == ON ]; then fi if [ $3 == ON ]; then use_gpu_list='true false' -else +else use_gpu_list='false' fi @@ -83,7 +83,7 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_STATIC_LIB=$WITH_STATIC_LIB make -j for use_gpu in $use_gpu_list; do - for vis_demo_name in $vis_demo_list; do + for vis_demo_name in $vis_demo_list; do ./vis_demo \ --modeldir=$DATA_DIR/$vis_demo_name/model \ --data=$DATA_DIR/$vis_demo_name/data.txt \ @@ -95,7 +95,7 @@ for WITH_STATIC_LIB in ON OFF; do fi done done - + # --------tensorrt mobilenet------ if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then rm -rf * @@ -107,7 +107,7 @@ for WITH_STATIC_LIB in ON OFF; do -DUSE_TENSORRT=$USE_TENSORRT \ -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR - make -j + make -j ./trt_mobilenet_demo \ --modeldir=$DATA_DIR/mobilenet/model \ --data=$DATA_DIR/mobilenet/data.txt \ diff --git a/paddle/fluid/inference/api/details/reset_tensor_array.cc b/paddle/fluid/inference/api/details/reset_tensor_array.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ae6c6dc9f44650c1c62f5be5448864d817513b1 --- /dev/null +++ b/paddle/fluid/inference/api/details/reset_tensor_array.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" + +namespace paddle { +namespace details { + +// Should be called after the parameters are loaded. +void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) { + if (flag_) { + for (auto &var_name : scope->LocalVarNames()) { + auto *var = scope->FindVar(var_name); + // TODO(Superjomn) should avoid the case when a TensorArray is a + // parameter. + if (var_name == "feed" || var_name == "fetch") continue; + if (var->Type() == typeid(framework::LoDTensorArray)) { + VLOG(4) << "collect " << var_name; + arrays_.push_back(var->GetMutable()); + } + } + for (auto *kid : scope->kids()) { + CollectTensorArrays(kid); + } + + VLOG(3) << "Collect " << arrays_.size() << " arrays"; + flag_ = false; + } +} + +// Should be called when `Run` finished. +void TensorArrayBatchCleaner::ResetTensorArray() { + for (auto *arr : arrays_) { + arr->clear(); + } +} + +} // namespace details +} // namespace paddle diff --git a/paddle/fluid/inference/api/details/reset_tensor_array.h b/paddle/fluid/inference/api/details/reset_tensor_array.h new file mode 100644 index 0000000000000000000000000000000000000000..a39449ff0e67786815dfb8d2d30d79dcdba757d7 --- /dev/null +++ b/paddle/fluid/inference/api/details/reset_tensor_array.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace details { + +// Clean the TensorArray each batch to make the behavior the same with the +// training phase. +struct TensorArrayBatchCleaner { + // Fix the tensor array not clear in the inference scenarios. + void CollectTensorArrays(framework::Scope *scope); + void ResetTensorArray(); + + private: + bool flag_{true}; + std::vector arrays_; +}; + +} // namespace details +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 07ee6e72d1053d2271b8f8d69ce38003f5e038a0..a755ccb93bdee018dfeaf91157e7971b4d4cd832 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -124,7 +124,7 @@ class ZeroCopyTensor { std::vector> lod() const; protected: - ZeroCopyTensor(void* scope) : scope_{scope} {} + explicit ZeroCopyTensor(void* scope) : scope_{scope} {} void SetName(const std::string& name) { name_ = name; } void* FindTensor() const; @@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig { kExclude // Specify the disabled passes in `ir_passes`. }; - void SetIncludeMode() { - ir_mode = IrPassMode::kInclude; - // this pass has to be run at the beginning of all fuse passes - ir_passes = {"infer_clean_graph_pass"}; - } - // Determine whether to perform graph optimization. bool enable_ir_optim = true; // Manually determine the IR passes to run. diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc index 6399476680c0af83a6d26aea952c58543bdce9ae..e0416ff953b61f56a2ca1a45cb382d40a6cffa4a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc @@ -228,6 +228,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_rnn1, profile) { contrib::AnalysisConfig cfg; SetConfig(&cfg); + cfg.use_gpu = false; std::vector outputs; std::vector> input_slots_all; diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 78ef6f207eadea6799864fe22889103b468d1780..0d51cb92618170cb422cb49ba63ba54ae6608ef4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() +op_library(hash_op DEPS xxhash) op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index b6cb935814e25b31d4104f9ce24fe952680cb491..0d32cae0e1e5ff274793df50e854283d8e2f7bf8 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -79,6 +79,9 @@ struct BeamSearchDecodeFunctor { bool tensor_on_gpu_; size_t beam_size_; int end_id_; + // TODO(Superjomn) Here might result serious performance issue in the + // concurrency + // scenarios. const LoDTensorArray& step_ids_origin_; const LoDTensorArray& step_scores_origin_; LoDTensorArray step_ids_ = LoDTensorArray(); diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index a69d9c9a529f26b3981ca8d1ba226fda71b8820a..709c2dfc4b7c67d7d04074c58ce6da85b6e790fe 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, selected_indices.push_back(idx); ++selected_num; } - sorted_indices.erase(sorted_indices.end()); + sorted_indices.erase(sorted_indices.end() - 1); if (flag && eta < 1 && adaptive_threshold > 0.5) { adaptive_threshold *= eta; } diff --git a/paddle/fluid/operators/fake_init_op.cc b/paddle/fluid/operators/fake_init_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..28ebdcb03ea83f3ec701106111a7cc5f0f7ed7dc --- /dev/null +++ b/paddle/fluid/operators/fake_init_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +class FakeInitInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FakeInitOp should not be null."); + auto &shape = ctx->Attrs().Get>("shape"); + ctx->SetOutputDim("Out", framework::make_ddim(shape)); + } +}; + +class FakeInitOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + framework::Tensor *tensor = nullptr; + + auto &out_var = *scope.FindVar(Output("Out")); + + if (out_var.IsType()) { + tensor = out_var.GetMutable(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else if (out_var.IsType()) { + tensor = out_var.GetMutable()->mutable_value(); + tensor->Resize(framework::make_ddim(Attr>("shape"))); + } else { + PADDLE_THROW( + "fake init op's output only" + "supports SelectedRows and LoDTensor"); + } + } +}; + +class FakeInitOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; + +class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddAttr>("shape", + "(vector) The shape of the output"); + AddOutput("Out", + "(Tensor) Tensor of specified shape will be filled " + "with the specified value"); + AddComment(R"DOC( +FakeInit Operator. + +Init an variable but not alloc memory for it, it is used for init the +table parameter at trainer side in distributed lookup table. + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fake_init, ops::FakeInitOp, ops::FakeInitInferShape, + ops::FakeInitOpMaker, paddle::framework::EmptyGradOpMaker, + ops::FakeInitOpVarTypeInference); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index e04a68717b351ddb0be5a7e70aa9297e5eb0125f..252f313440296bd9e5eebf26f67b08bbe7decce8 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -24,7 +24,7 @@ class FillConstantInferShape : public framework::InferShapeBase { void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FillConstantOp should not be null."); - auto &shape = ctx->Attrs().Get>("shape"); + auto &shape = ctx->Attrs().Get>("shape"); ctx->SetOutputDim("Out", framework::make_ddim(shape)); } }; @@ -47,10 +47,10 @@ class FillConstantOp : public framework::OperatorBase { if (out_var.IsType()) { tensor = out_var.GetMutable(); - tensor->Resize(framework::make_ddim(Attr>("shape"))); + tensor->Resize(framework::make_ddim(Attr>("shape"))); } else if (out_var.IsType()) { tensor = out_var.GetMutable()->mutable_value(); - tensor->Resize(framework::make_ddim(Attr>("shape"))); + tensor->Resize(framework::make_ddim(Attr>("shape"))); } else { PADDLE_THROW( "fill constant op's output only" @@ -83,7 +83,8 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { "(int, default 5 (FP32)) " "Output data type") .SetDefault(framework::proto::VarType::FP32); - AddAttr>("shape", "(vector) The shape of the output"); + AddAttr>("shape", + "(vector) The shape of the output"); AddAttr("value", "(float, default 0) The value to be filled") .SetDefault(0.0f); AddAttr("force_cpu", diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 1488aab1926b5b4ba7bceed582700f5a11fc6c93..c70d5b8bc7569c38cbc003aca7c62dc503df11cf 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -52,7 +52,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of GaussianRandomOp should not be null."); - auto shape = ctx->Attrs().Get>("shape"); + auto shape = ctx->Attrs().Get>("shape"); std::vector temp; temp.reserve(shape.size()); for (auto dim : shape) { @@ -88,9 +88,9 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddOutput("Out", "Output matrix of gaussian random op"); - AddAttr>("shape", - "(vector) " - "The dimension of random tensor."); + AddAttr>("shape", + "(vector) " + "The dimension of random tensor."); AddAttr("mean", "(float, default 0.0) " "mean of random tensor.") diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9ebe71a3d7ae270a10a45f4805652415078b363 --- /dev/null +++ b/paddle/fluid/operators/hash_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/hash_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class HashOp : public framework::OperatorWithKernel { + public: + HashOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of HashOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of HashOp should not be null."); + + auto dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(dims.size(), 2UL, + "The input of hash_op's dimensions must be 2"); + std::vector out_dims; + out_dims.reserve(dims.size() + 1); + // copy all dims except the last one + for (size_t i = 0u; i != dims.size() - 1; ++i) { + out_dims.emplace_back(dims[i]); + } + int num_hash = ctx->Attrs().Get("num_hash"); + out_dims.emplace_back(num_hash); + // keep the last dim to 1 + out_dims.emplace_back(1); + + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class HashOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input tensor of scale operator."); + AddOutput("Out", "(Tensor) Output tensor of scale operator."); + AddComment(R"DOC( +**Hash Operator** +$$Out = scale * X$$ +)DOC"); + AddAttr("num_hash", "").SetDefault(1); + AddAttr("mod_by", "").SetDefault(100000); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker); +REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel, ops::HashKerel); diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9781bb0f453642cefb3eb59a05389c339a7de39d --- /dev/null +++ b/paddle/fluid/operators/hash_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +extern "C" { +#include +} +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +// template +template +class HashKerel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + auto* out_t = context.Output("Out"); + auto* in_t = context.Input("X"); + int mod_by = context.Attr("mod_by"); + int num_hash = context.Attr("num_hash"); + auto* output = out_t->mutable_data(context.GetPlace()); + + auto in_dims = in_t->dims(); + auto in_lod = in_t->lod(); + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + auto seq_length = in_dims[0]; + auto last_dim = in_dims[in_dims.size() - 1]; + auto* input = in_t->data(); + for (int idx = 0; idx < seq_length; ++idx) { + for (int ihash = 0; ihash != num_hash; ++ihash) { + output[idx * num_hash + ihash] = + XXH64(input, sizeof(int) * last_dim, ihash) % mod_by; + } + input += last_dim; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lars_momentum_op.cc b/paddle/fluid/operators/lars_momentum_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8dda93902448fa1bd21b719ffd9c9b500caf755 --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/lars_momentum_op.h" +#include "paddle/fluid/operators/momentum_op.h" + +namespace paddle { +namespace operators { + +class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", + "(LoDTensor, default LoDTensor) " + "Input parameter that has to be updated"); + AddInput("Grad", + "(LoDTensor, default LoDTensor) " + "Input gradient of the parameter"); + AddInput("Velocity", + "(LoDTensor, default LoDTensor) " + "Input velocity (corresponding to the parameter) " + "that has to be updated"); + AddInput("LearningRate", + "(LoDTensor, default LoDTensor) " + "Input learning rate"); + + AddOutput("ParamOut", + "(LoDTensor) This output is updated parameter. " + "It shared memory with Input(Param)."); + AddOutput("VelocityOut", + "(LoDTensor) This output is updated velocity. " + "It shared memory with Input(Velocity)."); + + AddAttr("mu", "(float) Momentum coefficient"); + AddAttr("lars_coeff", "(float, default 0.001) LARS coefficient.") + .SetDefault(0.001); + AddAttr("lars_weight_decay", + "(float, default 0.0005) LARS weight decay") + .SetDefault(0.0005); + + AddComment(R"DOC( +Lars Momentum Optimizer. + +This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each +weight using a local learning rate: + +$$ +local\_lr = \eta * + \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\ +velocity = mu * velocity + + local\_lr * (grad + \beta * param) \\ +param = param - velocity. \\ +$$ + +Note that we use lars_weight_decay here to decay weights, you may need not to +use L2 regularizers in case of using LARS. + +)DOC"); + } +}; + +class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::LarsMomentumOpVarTypeInference); +REGISTER_OP_CPU_KERNEL(lars_momentum, ops::LarsMomentumOpKernel, + ops::LarsMomentumOpKernel); diff --git a/paddle/fluid/operators/lars_momentum_op.cu b/paddle/fluid/operators/lars_momentum_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..eb346851a2f690fa05422c84ddcb08307539048f --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.cu @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/lars_momentum_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, + const int64_t num, const T lars_coeff, + const T lars_weight_decay, const T* p_norm, + const T* g_norm, T* p_out, T* v_out) { + T lr = learning_rate[0]; + T local_lr = learning_rate[0]; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; + i += blockDim.x * gridDim.x) { + if (p_norm[0] > 0 && g_norm[0] > 0) { + local_lr = lr * lars_coeff * p_norm[0] / + (g_norm[0] + lars_weight_decay * p_norm[0]); + } + T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]); + v_out[i] = v_new; + p_out[i] = p[i] - v_new; + } +} + +template +class LarsMomentumOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto velocity_out = ctx.Output("VelocityOut"); + auto param = ctx.Input("Param"); + auto velocity = ctx.Input("Velocity"); + auto grad = ctx.Input("Grad"); + auto learning_rate = ctx.Input("LearningRate"); + + T* p_out = param_out->mutable_data(ctx.GetPlace()); + T* v_out = velocity_out->mutable_data(ctx.GetPlace()); + + T mu = static_cast(ctx.Attr("mu")); + T lars_coeff = ctx.Attr("lars_coeff"); + T lars_weight_decay = ctx.Attr("lars_weight_decay"); + + auto* p = param->data(); + auto* v = velocity->data(); + auto* g = grad->data(); + auto* lr = learning_rate->data(); + + int block = 512; + int grid = (param->numel() + block - 1) / block; + + auto eigen_p = framework::EigenVector::Flatten(*param); + auto eigen_g = framework::EigenVector::Flatten(*grad); + // calculate norms using eigein and launch the kernel. + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + auto* p_norm_data = p_norm_t.mutable_data(ctx.GetPlace()); + auto* g_norm_data = g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + + auto* place = ctx.template device_context().eigen_device(); + ep_norm.device(*place) = eigen_p.square().sum().sqrt(); + eg_norm.device(*place) = eigen_g.square().sum().sqrt(); + MomentumLarsKernel<<>>( + p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay, + p_norm_data, g_norm_data, p_out, v_out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + lars_momentum, + ops::LarsMomentumOpCUDAKernel, + ops::LarsMomentumOpCUDAKernel); diff --git a/paddle/fluid/operators/lars_momentum_op.h b/paddle/fluid/operators/lars_momentum_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e85be99fc42522e461a7915847d82144d8195a96 --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LarsMomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto velocity_out = ctx.Output("VelocityOut"); + auto param = ctx.Input("Param"); + auto velocity = ctx.Input("Velocity"); + auto learning_rate = ctx.Input("LearningRate"); + auto* grad_var = ctx.InputVar("Grad"); + // only support dense for now. + PADDLE_ENFORCE(grad_var->IsType()); + auto grad = ctx.Input("Grad"); + + param_out->mutable_data(ctx.GetPlace()); + velocity_out->mutable_data(ctx.GetPlace()); + + T mu = static_cast(ctx.Attr("mu")); + T lars_coeff = ctx.Attr("lars_coeff"); + T lars_weight_decay = ctx.Attr("lars_weight_decay"); + + auto p_out = framework::EigenVector::Flatten(*param_out); + auto v_out = framework::EigenVector::Flatten(*velocity_out); + + auto p = framework::EigenVector::Flatten(*param); + auto v = framework::EigenVector::Flatten(*velocity); + auto g = framework::EigenVector::Flatten(*grad); + auto* lr = learning_rate->data(); + + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + p_norm_t.mutable_data(ctx.GetPlace()); + g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + + ep_norm = p.square().sum().sqrt(); + eg_norm = g.square().sum().sqrt(); + T local_lr = lr[0]; + if (ep_norm(0) > 0 && eg_norm(0) > 0) { + local_lr = lr[0] * lars_coeff * ep_norm(0) / + (eg_norm(0) + lars_weight_decay * ep_norm(0)); + } + v_out = v * mu + local_lr * (g + lars_weight_decay * p); + p_out = p - v_out; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 26f09c46c2224a4a46d302dff4b2ec594f0be103..a038bad701ba8ede3065af9f352f1f21784a50b7 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -27,6 +27,10 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/listen_and_serv_op.h" +DEFINE_int32(rpc_send_thread_num, 5, "number of threads for rpc send"); +DEFINE_int32(rpc_get_thread_num, 5, "number of threads for rpc get"); +DEFINE_int32(rpc_prefetch_thread_num, 5, "number of threads for rpc prefetch"); + namespace paddle { namespace operators { @@ -332,11 +336,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, sync_mode, checkpoint_block_id)); rpc_service_->RegisterRPC(distributed::kRequestSend, - request_send_handler_.get()); + request_send_handler_.get(), + FLAGS_rpc_send_thread_num); rpc_service_->RegisterRPC(distributed::kRequestGet, - request_get_handler_.get()); + request_get_handler_.get(), + FLAGS_rpc_get_thread_num); rpc_service_->RegisterRPC(distributed::kRequestPrefetch, - request_prefetch_handler_.get()); + request_prefetch_handler_.get(), + FLAGS_rpc_prefetch_thread_num); rpc_service_->RegisterRPC(distributed::kRequestCheckpoint, request_checkpoint_handler_.get()); diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index b9ac54e446811889b647397ae1fbb11c28f46777..3226a727b1f5f6de9e97ce2068381be7c9b69ff3 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -81,6 +81,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { "Otherwise the given value indicates padding the output " "with zeros whenever lookup encounters it in Ids.") .SetDefault(kNoPadding); + // NOTE(minqiyang): grad_inplace is an temporal attribute, + // please do NOT set this attribute in python layer. + AddAttr("grad_inplace", + "(boolean, default false) " + "If the grad op reuse the input's variable.") + .SetDefault(false); AddComment(R"DOC( Lookup Table Operator. @@ -115,7 +121,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 58463dc4d6fd7cc3454de766814a947fee161070..e504c4f0cd5c0feaef4a251fad57b389a10a2ce7 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { @@ -68,6 +69,7 @@ class LookupTableKernel : public framework::OpKernel { const auto *table = table_t.value().data(); auto *output = output_t->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); for (int64_t i = 0; i < ids_numel; ++i) { if (padding_idx != kNoPadding && ids[i] == padding_idx) { memset(output + i * row_width, 0, row_width * sizeof(T)); @@ -75,8 +77,8 @@ class LookupTableKernel : public framework::OpKernel { PADDLE_ENFORCE_GE(ids[i], 0); auto id_index = table_t.Index(ids[i]); PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); - memcpy(output + i * row_width, table + id_index * row_width, - row_width * sizeof(T)); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); } } } @@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel { auto *ids_data = ids->data(); int64_t ids_num = ids->numel(); - framework::Vector new_rows; - new_rows.reserve(ids_num); - for (int64_t i = 0; i < ids_num; i++) { - new_rows.push_back(ids_data[i]); - } + std::vector new_rows; + new_rows.resize(ids_num); + std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); d_table_value->Resize({ids_num, table_dim[1]}); - d_table_value->mutable_data(context.GetPlace()); - - d_table->set_height(table_dim[0]); - - auto *d_output_data = d_output->data(); - auto *d_table_data = d_table_value->data(); - - auto d_output_dims = d_output->dims(); - PADDLE_ENFORCE_EQ( - d_table_value->dims(), - framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); - memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + // FIXME(minqiyang): + // memory optimization will NOT reuse Tensor with SelectedRows + // so we could just share the tensor here directly. + // However, the InferVarType method will infer the output SelectedRows + // to Tensor sometimes, which is a bug, so we will add an attribute + // here to indicate the inplace and remove this attribute after + // the InferVarType's bug was fixed + bool grad_inplace = context.Attr("grad_inplace"); + if (grad_inplace) { + d_table_value->ShareDataWith(*d_output); + } else { + d_table_value->mutable_data(context.GetPlace()); + + d_table->set_height(table_dim[0]); + + auto *d_output_data = d_output->data(); + auto *d_table_data = d_table_value->data(); + + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } } else { auto *ids = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 08f57dd45ad76946cbcafb98a3414003ed9d67a9..75946740375d74043960b68e94eb048b3bab4b79 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include -#include +#include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" @@ -230,8 +229,24 @@ template struct SelectedRowsAddToTensor; // add or mul. namespace scatter { -size_t FindPos(const std::vector& rows, int64_t value) { - return std::find(rows.begin(), rows.end(), value) - rows.begin(); +template +typename std::enable_if< + std::is_floating_point::value && + std::is_same::value>::type +elementwise_add_to(const DeviceContext& ctx, BlasT* blas, + size_t data_len, const T* in, T* out) { + blas->AXPY(data_len, 1., in, out); +} + +template +typename std::enable_if< + !std::is_floating_point::value && + std::is_same::value>::type +elementwise_add_to(const DeviceContext& ctx, BlasT* blas, + size_t data_len, const T* in, T* out) { + for (int64_t i = 0; i < data_len; i++) { + out[i] += in[i]; + } } template @@ -246,48 +261,84 @@ struct MergeAdd { void operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input, framework::SelectedRows* output) { - framework::SelectedRows& out = *output; - std::vector input_rows(input.rows()); + std::vector inputs; + inputs.push_back(&input); + (*this)(context, inputs, output); + } - std::map> merge_row_map; - for (size_t i = 0; i < input_rows.size(); ++i) { - merge_row_map[input_rows[i]].push_back(i); + void operator()(const platform::CPUDeviceContext& context, + const std::vector& inputs, + framework::SelectedRows* output) { + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; } - - std::vector merge_rows(merge_row_map.size()); - size_t idx = 0; - int64_t input_width = input.value().dims()[1]; - out.set_height(input.height()); - - T* out_data = out.mutable_value()->mutable_data( + const framework::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (in->rows().size() > 0) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); + framework::SelectedRows& out = *output; + std::set merged_row_set; + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], + "all input should have same " + "dimension except for the first one"); + PADDLE_ENFORCE_EQ(input_height, input->height(), + "all input should have same height"); + merged_row_set.insert(input->rows().begin(), input->rows().end()); + } + std::vector merge_rows(merged_row_set.begin(), + merged_row_set.end()); + std::unordered_map rows_to_id; + for (size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; + } + out.set_rows(merge_rows); + out.set_height(input_height); + out.mutable_value()->mutable_data( framework::make_ddim( {static_cast(merge_rows.size()), input_width}), context.GetPlace()); - const T* in_data = input.value().data(); - - for (auto& row_pair : merge_row_map) { - auto* out_ptr = out_data + idx * input_width; - auto& rows = row_pair.second; - merge_rows[idx] = row_pair.first; - ++idx; - // rows.size() is always larger than 0 - std::memcpy(out_ptr, in_data + rows[0] * input_width, - sizeof(T) * input_width); - - for (size_t i = 1; i < rows.size(); ++i) { - auto* in_ptr = in_data + rows[i] * input_width; - for (int64_t j = 0; j < input_width; ++j) { - out_ptr[j] += in_ptr[j]; - } + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + + auto blas = math::GetBlas(context); + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_to_id[input_rows[i]]; + elementwise_add_to( + context, &blas, static_cast(input_width), + &input_data[i * input_width], &out_data[out_i * input_width]); } } - - out.set_rows(merge_rows); } }; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; template struct UpdateToTensor { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index ba8eccf82042b679f69a32f9d053f05ac8fb9a99..10f39822b9c904ce236a1a2a3806d70693bd2e63 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -267,10 +267,15 @@ struct MergeAdd { void operator()(const platform::CUDADeviceContext& context, const framework::SelectedRows& input, framework::SelectedRows* output) { - framework::SelectedRows& out = *output; framework::Vector input_rows(input.rows()); + if (input_rows.size() == 0) { + return; + } + + framework::SelectedRows& out = *output; std::set row_set(input_rows.begin(), input_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); + std::vector merge_rows_cpu(row_set.begin(), row_set.end()); + framework::Vector merge_rows(merge_rows_cpu); auto input_width = input.value().dims()[1]; @@ -296,6 +301,73 @@ struct MergeAdd { out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.rows().size(), input_width); } + + void operator()(const platform::CUDADeviceContext& context, + const std::vector& inputs, + framework::SelectedRows* output) { + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; + } + const framework::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (in->rows().size() > 0) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); + framework::SelectedRows& out = *output; + std::set merged_row_set; + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], + "all input should have same " + "dimension except for the first one"); + PADDLE_ENFORCE_EQ(input_height, input->height(), + "all input should have same height"); + merged_row_set.insert(input->rows().begin(), input->rows().end()); + } + std::vector merge_rows_cpu(merged_row_set.begin(), + merged_row_set.end()); + framework::Vector merge_rows(merge_rows_cpu); + + out.set_rows(merge_rows); + out.set_height(input_height); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); + dim3 grid1(input_rows.size(), 1); + + MergeAddKernel<<>>( + input_data, input_rows.CUDAData(context.GetPlace()), out_data, + out.mutable_rows()->CUDAMutableData(context.GetPlace()), + out.rows().size(), input_width); + } + } }; template struct MergeAdd; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index 900be86f91c6658a5265189a6745316c6471209e..521c53dd0d71707c13c4364c5ee59943a03d4a2d 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -83,104 +83,9 @@ struct MergeAdd { void operator()(const DeviceContext& context, const framework::SelectedRows& input, framework::SelectedRows* output); -}; - -template <> -struct MergeAdd { - framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { - framework::SelectedRows out; - (*this)(context, input, &out); - return out; - } - - void operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input, - framework::SelectedRows* output) { - framework::SelectedRows& out = *output; - std::vector input_rows(input.rows()); - - std::map> merge_row_map; - for (size_t i = 0; i < input_rows.size(); ++i) { - merge_row_map[input_rows[i]].push_back(i); - } - - std::vector merge_rows(merge_row_map.size()); - size_t idx = 0; - int64_t input_width = input.value().dims()[1]; - out.set_height(input.height()); - - auto* out_data = out.mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), input_width}), - context.GetPlace()); - auto* in_data = input.value().data(); - - auto blas = GetBlas(context); - for (auto& row_pair : merge_row_map) { - auto* out_ptr = out_data + idx * input_width; - auto& rows = row_pair.second; - merge_rows[idx] = row_pair.first; - ++idx; - // rows.size() is always larger than 0 - blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr); - - for (size_t i = 1; i < rows.size(); ++i) { - blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr); - } - } - - out.set_rows(merge_rows); - } -}; - -template <> -struct MergeAdd { - framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { - framework::SelectedRows out; - (*this)(context, input, &out); - return out; - } - - void operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input, - framework::SelectedRows* output) { - framework::SelectedRows& out = *output; - std::vector input_rows(input.rows()); - - std::map> merge_row_map; - for (size_t i = 0; i < input_rows.size(); ++i) { - merge_row_map[input_rows[i]].push_back(i); - } - - std::vector merge_rows(merge_row_map.size()); - size_t idx = 0; - int64_t input_width = input.value().dims()[1]; - out.set_height(input.height()); - - auto* out_data = out.mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), input_width}), - context.GetPlace()); - auto* in_data = input.value().data(); - - auto blas = GetBlas(context); - for (auto& row_pair : merge_row_map) { - auto* out_ptr = out_data + idx * input_width; - auto& rows = row_pair.second; - merge_rows[idx] = row_pair.first; - ++idx; - // rows.size() is always larger than 0 - blas.VCOPY(input_width, in_data + rows[0] * input_width, out_ptr); - - for (size_t i = 1; i < rows.size(); ++i) { - blas.AXPY(input_width, 1., in_data + rows[i] * input_width, out_ptr); - } - } - - out.set_rows(merge_rows); - } + void operator()(const DeviceContext& context, + const std::vector& inputs, + framework::SelectedRows* output); }; template diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index 835589356042b44c9fa5988aed726434fd66910a..f15b37a1e3f0ae9c7612c4f74470472393ff4ad6 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -302,6 +302,64 @@ TEST(selected_rows_functor, cpu_merge_add_int) { EXPECT_EQ(out_data[1 * row_numel], 2); EXPECT_EQ(out_data[2 * row_numel], 1); } + +TEST(selected_rows_functor, cpu_merge_add_multi) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + set_const; + + int64_t height = 10; + int64_t row_numel = 8; + + std::vector rows1{5, 2, 5, 3, 5}; + std::unique_ptr selected_rows1{ + new paddle::framework::SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows1.size()), row_numel}), + cpu_place); + set_const(ctx, in1_value, 1.0); + + std::vector rows2{2, 5, 3, 5, 3}; + std::unique_ptr selected_rows2{ + new paddle::framework::SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows2.size()), row_numel}), + cpu_place); + set_const(ctx, in2_value, 1.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + output->set_height(height); + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + + std::vector inputs; + inputs.push_back(selected_rows1.get()); + inputs.push_back(selected_rows2.get()); + merge_add_functor(ctx, inputs, output.get()); + + EXPECT_EQ(output->height(), height); + EXPECT_EQ(output->value().dims(), + paddle::framework::make_ddim({3, row_numel})); + + std::vector ret_rows{2, 3, 5}; + EXPECT_EQ(output->rows(), ret_rows); + + auto* out_data = output->value().data(); + for (size_t i = 0; i < ret_rows.size(); ++i) { + for (size_t j = 0; j < row_numel; ++j) { + EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); + } + } +} + TEST(selected_rows_functor, cpu_sum_to) { paddle::platform::CPUPlace cpu_place; paddle::platform::CPUDeviceContext ctx(cpu_place); @@ -318,6 +376,7 @@ TEST(selected_rows_functor, cpu_sum_to) { paddle::framework::make_ddim( {static_cast(rows1.size()), row_numel}), cpu_place); + functor(ctx, in1_value, 1.0); std::vector rows2{0, 5, 7, 9}; std::unique_ptr selected_rows2{ @@ -327,6 +386,7 @@ TEST(selected_rows_functor, cpu_sum_to) { paddle::framework::make_ddim( {static_cast(rows2.size()), row_numel}), cpu_place); + functor(ctx, in2_value, 2.0); std::unique_ptr output{ new paddle::framework::SelectedRows()}; diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu index 5fc50aba25d8e69480a17f0f80877b0d03e17276..17af3e3999ca688c584f636f4c00386f886f9bbf 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cu +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu @@ -241,3 +241,67 @@ TEST(selected_rows_functor, gpu_add_to) { // row9: 2.0 + 3.0 EXPECT_EQ(tensor1_cpu_data[9 * row_numel + 6], 5.0); } + +TEST(selected_rows_functor, gpu_merge_add) { + paddle::platform::CUDAPlace gpu_place(0); + paddle::platform::CPUPlace cpu_place; + paddle::platform::CUDADeviceContext& ctx = + *reinterpret_cast( + paddle::platform::DeviceContextPool::Instance().Get(gpu_place)); + paddle::operators::math::SetConstant + set_const; + + int64_t height = 10; + int64_t row_numel = 8; + + std::vector rows1{5, 2, 5, 3, 5}; + std::unique_ptr selected_rows1{ + new paddle::framework::SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows1.size()), row_numel}), + gpu_place); + set_const(ctx, in1_value, 1.0); + + std::vector rows2{2, 5, 3, 5, 3}; + std::unique_ptr selected_rows2{ + new paddle::framework::SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows2.size()), row_numel}), + gpu_place); + set_const(ctx, in2_value, 1.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + output->set_height(height); + paddle::operators::math::scatter::MergeAdd< + paddle::platform::CUDADeviceContext, float> + merge_add_functor; + + std::vector inputs; + inputs.push_back(selected_rows1.get()); + inputs.push_back(selected_rows2.get()); + merge_add_functor(ctx, inputs, output.get()); + + paddle::framework::Tensor output_cpu; + paddle::framework::TensorCopy(output->value(), cpu_place, ctx, &output_cpu); + ctx.Wait(); + + EXPECT_EQ(output->height(), height); + EXPECT_EQ(output->value().dims(), + paddle::framework::make_ddim({3, row_numel})); + + std::vector ret_rows{2, 3, 5}; + EXPECT_EQ(output->rows(), ret_rows); + + auto* out_data = output_cpu.data(); + for (size_t i = 0; i < ret_rows.size(); ++i) { + for (size_t j = 0; j < row_numel; ++j) { + EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); + } + } +} diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc index c6ec4ab047d5e91625e646fd26108d2e477cdce5..6e0e13698097ade36449f2e8ff6ab981a1b24311 100644 --- a/paddle/fluid/operators/merge_ids_op.cc +++ b/paddle/fluid/operators/merge_ids_op.cc @@ -20,13 +20,16 @@ namespace operators { class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); - AddInput( - "X", - "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the " - "size of embedding table") + AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}") + .AsDuplicable(); + AddInput("Rows", "(LoDTensor) the input ids with shape{row_size, 1}, ") + .AsDuplicable(); + AddInput("X", + "(LoDTensors) multi input tensor with shape{Rows, N}, N is the " + "size of embedding table") + .AsDuplicable(); + AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.") .AsDuplicable(); - AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors."); AddComment(R"DOC( Merge multi LoDTensor's into one according to Ids's shard num. @@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids."); - PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out."); + PADDLE_ENFORCE(ctx->HasInputs("Ids"), + "MergeIdsOp must has multi input Ids."); + PADDLE_ENFORCE(ctx->HasInputs("Rows"), + "MergeIdsOp must has multi input Rows."); + PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has multi input X."); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), + "MergeIdsOp must has multi output Out."); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - auto ids_dims = ctx->GetInputDim("Ids"); + auto ids_dims = ctx->GetInputsDim("Ids"); if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[0][1], 1); } auto x_var_type = ctx->GetInputsVarType("X"); for (auto &var_type : x_var_type) { diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h index 83712a8519c6817151e1922c606c0fdd4682a2db..fef9e023d02f45e21ec409ad398ba7d9bdd36880 100644 --- a/paddle/fluid/operators/merge_ids_op.h +++ b/paddle/fluid/operators/merge_ids_op.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include +#include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" @@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel { if (!platform::is_cpu_place(place)) { PADDLE_THROW("MergeIds do not support GPU kernel"); } - VLOG(3) << "run in MergeIdsOpKernel"; - const auto *ids_var = ctx.InputVar("Ids"); - PADDLE_ENFORCE(ids_var->IsType(), - "only support to merge Ids of LoDTensor"); + const auto ids = ctx.MultiInput("Ids"); + const auto row_ids = ctx.MultiInput("Rows"); + const auto x_tensors = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); - const auto &ids_tensor = ids_var->Get(); - const auto &ids_dims = ids_tensor.dims(); - const int64_t *ids = ids_tensor.data(); + PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(), + "the number of Rows and X should be the same"); + PADDLE_ENFORCE_EQ(ids.size(), outs.size(), + "the number of Ids and Out should be the same"); - auto x_tensors = ctx.MultiInput("X"); + int row_ids_size = 0; + int row_size = 0; + int embedding_size = 0; - auto *out = ctx.Output("Out"); + for (int i = 0; i < x_tensors.size(); ++i) { + const auto *x_tensor = x_tensors[i]; + const auto *row_id = row_ids[i]; - int batch_size = 0; - int embedding_size = 0; - for (auto &input : x_tensors) { - if (framework::product(input->dims()) != 0) { - if (embedding_size == 0) { - embedding_size = input->dims()[1]; - } - PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], - "embedding size of all input should be the same"); - batch_size += input->dims()[0]; + if (embedding_size == 0) { + embedding_size = x_tensor->dims()[1]; } + PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1], + "embedding size of all input should be the same"); + row_size += x_tensor->dims()[0]; + row_ids_size += row_id->dims()[0]; } + PADDLE_ENFORCE_EQ( - batch_size, ids_dims[0], - "the batch size of ids and merged embedding value should be the same"); + row_size, row_ids_size, + "the merged X dim[0] and merged Rows dim[0] should be the same"); + + std::unordered_map> + selected_rows_idx_map; + for (int i = 0; i < x_tensors.size(); ++i) { + const auto *row_id = row_ids[i]; + + for (int j = 0; j < row_id->numel(); ++j) { + int64_t key = row_id->data()[j]; + std::tuple val = std::make_tuple(i, j); + selected_rows_idx_map.insert(std::make_pair(key, val)); + } + } + PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(), + "the rows and tensor map size should be the same"); + + for (int i = 0; i < outs.size(); ++i) { + auto *out_ids = ids[i]; + auto *out = outs[i]; - const size_t shard_num = x_tensors.size(); + out->set_lod(out_ids->lod()); - if (shard_num == 1) { - VLOG(3) << "only one shard, we can copy the data directly"; - TensorCopy(*x_tensors[0], place, out); - } else { - std::vector in_indexs(shard_num, 0); + int nums = static_cast(out_ids->dims()[0]); auto *out_data = out->mutable_data( - framework::make_ddim({batch_size, embedding_size}), place); - // copy data from ins[shard_num] to out. - for (int i = 0; i < ids_dims[0]; ++i) { - int64_t id = ids[i]; - size_t shard_id = static_cast(id) % shard_num; - int index = in_indexs[shard_id]; - memcpy(out_data + embedding_size * i, - x_tensors[shard_id]->data() + index * embedding_size, + framework::make_ddim({nums, embedding_size}), place); + for (int j = 0; j < nums; ++j) { + int id = out_ids->data()[j]; + auto row_tuple = selected_rows_idx_map[id]; + int64_t row_idx = std::get<1>(row_tuple); + const auto *x_tensor = x_tensors[std::get<0>(row_tuple)]; + + memcpy(out_data + embedding_size * j, + x_tensor->data() + row_idx * embedding_size, sizeof(T) * embedding_size); - in_indexs[shard_id] += 1; - } - - for (size_t i = 0; i < shard_num; ++i) { - PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0], - "after merge, all data in x_tensor should be used"); } } } diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index 12b916fcebd425bd4a03d920f947829098a924a1..7f0b51580aa2591ac7338ad7c29ee4756d909925 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -19,54 +19,6 @@ namespace operators { using Tensor = framework::Tensor; -class MomentumOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(param) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(grad) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Velocity"), - "Input(velocity) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of Momentum should not be null."); - PADDLE_ENFORCE( - ctx->GetInputsVarType("Param").front() == - framework::proto::VarType::LOD_TENSOR, - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); - - PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), - "Output(ParamOut) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"), - "Output(VelocityOut) of Momentum should not be null."); - - auto param_dim = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - "Param and Grad input of MomentumOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Velocity"), - "Param and Velocity of MomentumOp should have the same dimension."); - } - PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, - "Learning_rate should be a scalar"); - - ctx->SetOutputDim("ParamOut", param_dim); - ctx->SetOutputDim("VelocityOut", param_dim); - } - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); - } -}; - class MomentumOpInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDesc& op_desc, diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 6b4d00f56ca06c402c07ecf770a390e88ae3edf1..71f079e4d97f5259359ee6572f584894551452ca 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -28,6 +28,54 @@ using framework::SelectedRows; struct NoNesterov; struct UseNesterov; +class MomentumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(param) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(grad) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Velocity"), + "Input(velocity) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of Momentum should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"), + "Output(VelocityOut) of Momentum should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad input of MomentumOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Velocity"), + "Param and Velocity of MomentumOp should have the same dimension."); + } + PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, + "Learning_rate should be a scalar"); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("VelocityOut", param_dim); + } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + template class CPUDenseMomentumFunctor { private: diff --git a/paddle/fluid/operators/split_ids_op.cc b/paddle/fluid/operators/split_ids_op.cc index c867c46873ae7ddbdbda280351e4ab28235bcc08..243f81e296fb95a2c7e9f717950b8a958ad98852 100644 --- a/paddle/fluid/operators/split_ids_op.cc +++ b/paddle/fluid/operators/split_ids_op.cc @@ -20,20 +20,27 @@ namespace operators { class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); - AddOutput("Out", "(LoDTensor) The outputs of the input Ids.") + AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}") + .AsDuplicable(); + + AddOutput("Out", "(LoDTensors) The outputs of the input Ids.") .AsDuplicable(); AddComment(R"DOC( Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number Example: Input: - X = [1,2,3,4,5,6] + X = [[1,2,3,4,5,6],[2,3]] Out(3 output): - out0 = [3, 6] - out1 = [1, 4] - out2 = [2, 5] + if compress is True: + out0 = [3, 3, 6] + out1 = [1, 4] + out2 = [2, 2, 5] + else: + out0 = [3, 6] + out1 = [1, 4] + out2 = [2, 5] )DOC"); } }; @@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids."); + PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must has input Ids."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out."); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - auto ids_dims = ctx->GetInputDim("Ids"); + auto ids_dims = ctx->GetInputsDim("Ids"); if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2); } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.MultiInput("Ids").front()->type()), + ctx.GetPlace()); + } }; class SplitIdsOpInferVarType : public framework::VarTypeInference { @@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference { } }; +class SplitIdsOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto grad = new framework::OpDesc(); + grad->SetType("concat"); + grad->SetInput("X", OutputGrad("Out")); + grad->SetOutput("Out", InputGrad("Ids")); + grad->SetAttr("axis", 0); + return std::unique_ptr(grad); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker, - ops::SplitIdsOpInferVarType); + ops::SplitIdsOpGradMaker, ops::SplitIdsOpInferVarType); + REGISTER_OP_CPU_KERNEL( split_ids, ops::SplitIdsOpKernel, ops::SplitIdsOpKernel); diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h index c4af5a65fc5f81c1af7c1fdcca637ca37c940637..69ac6c5a6b9a8b318520eb9a3ff89a3a6be48339 100644 --- a/paddle/fluid/operators/split_ids_op.h +++ b/paddle/fluid/operators/split_ids_op.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include +#include #include #include #include "paddle/fluid/framework/op_registry.h" @@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel { PADDLE_THROW("SplitIds do not support GPU kernel"); } - const auto *ids_var = ctx.InputVar("Ids"); + const auto ids_vars = ctx.MultiInputVar("Ids"); + + PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0"); + auto *ids_var = ids_vars[0]; + if (ids_var->IsType()) { - const auto &ids_dims = ctx.Input("Ids")->dims(); - const T *ids = ctx.Input("Ids")->data(); + int batch_size = 0; + const auto ids_tensors = ctx.MultiInput("Ids"); + for (size_t i = 0; i < ids_tensors.size(); ++i) { + batch_size += ids_tensors[i]->dims()[0]; + } + VLOG(4) << "Get Total BatchSize is: " << batch_size; + + std::vector all_ids(batch_size); + int offset = 0; + for (size_t i = 0; i < ids_tensors.size(); ++i) { + const auto *ids = ids_tensors[i]; + std::memcpy(all_ids.data() + offset, ids->data(), + ids->numel() * sizeof(T)); + offset += ids->numel(); + } + + std::set st(all_ids.begin(), all_ids.end()); + all_ids.assign(st.begin(), st.end()); + auto outs = ctx.MultiOutput("Out"); const size_t shard_num = outs.size(); - std::vector> out_ids; out_ids.resize(outs.size()); // split id by their shard_num. - for (int i = 0; i < ids_dims[0]; ++i) { - T id = ids[i]; + for (int i = 0; i < all_ids.size(); ++i) { + T id = all_ids[i]; size_t shard_id = static_cast(id) % shard_num; out_ids[shard_id].push_back(id); } @@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(ids_dims[0], static_cast(ids_selected_rows->rows().size()), ""); - const T *ids = ids_selected_rows->value().data(); + const T *ids_data = ids_selected_rows->value().data(); const auto &ids_rows = ids_selected_rows->rows(); auto outs = ctx.MultiOutput("Out"); const size_t shard_num = outs.size(); @@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel { T *output = out->mutable_value()->mutable_data(ddim, place); for (int64_t i = 0; i < ddim[0]; ++i) { memcpy(output + i * row_width, - ids + id_to_index[out->rows()[i]] * row_width, + ids_data + id_to_index[out->rows()[i]] * row_width, row_width * sizeof(T)); } } diff --git a/paddle/fluid/operators/split_selected_rows_op.cc b/paddle/fluid/operators/split_selected_rows_op.cc index 76615a9405d7a8e3fa9dba8d01a956209e02ae8f..0e7b1463d1ba81aed53e0e3f3a90d2a1fbf0ffbc 100644 --- a/paddle/fluid/operators/split_selected_rows_op.cc +++ b/paddle/fluid/operators/split_selected_rows_op.cc @@ -22,9 +22,9 @@ class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "The input SelectedRows."); AddOutput("Out", "The outputs of the input SelectedRows.").AsDuplicable(); - AddAttr>("height_sections", - "Height for each output SelectedRows.") - .SetDefault(std::vector({})); + AddAttr>("height_sections", + "Height for each output SelectedRows.") + .SetDefault(std::vector({})); AddComment(R"DOC( Split a SelectedRows with a specified rows section. diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index 0e9ce165b98845f4745ee70b028513ea31cc6657..af64607fafc6544047714e731846a2440be219b8 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { -static int FindOutIdx(int row, const std::vector& abs_sections) { +static int FindOutIdx(int row, const std::vector& abs_sections) { for (size_t i = 1; i < abs_sections.size(); ++i) { if (row < abs_sections[i]) { return i - 1; @@ -30,9 +30,9 @@ static int FindOutIdx(int row, const std::vector& abs_sections) { return abs_sections.size() - 1; } -static std::vector ToAbsoluteSection( - const std::vector& height_sections) { - std::vector abs_sections; +static std::vector ToAbsoluteSection( + const std::vector& height_sections) { + std::vector abs_sections; abs_sections.resize(height_sections.size()); abs_sections[0] = 0; for (size_t i = 1; i < height_sections.size(); ++i) { @@ -47,7 +47,7 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); auto outs = ctx.MultiOutput("Out"); - auto height_sections = ctx.Attr>("height_sections"); + auto height_sections = ctx.Attr>("height_sections"); auto abs_sections = ToAbsoluteSection(height_sections); diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 34dbac2ab8dcc9bd2b91e2daa2f42806057f5f56..416da3b54aa58e92f9e9fc0112680f07d11159ff 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -82,14 +82,15 @@ class SumOp : public framework::OperatorWithKernel { if (x_vars[0]->IsType()) { int dtype = -1; for (auto& x_var : x_vars) { - auto& lod_tensor = x_var->Get(); - if (lod_tensor.numel() == 0) { + // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. + auto tensor = framework::GetTensorFromVar(x_var); + if (tensor->numel() == 0) { continue; } if (dtype == -1) { - dtype = framework::ToDataType(lod_tensor.type()); + dtype = framework::ToDataType(tensor->type()); } else { - PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(lod_tensor.type())); + PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type())); } } PADDLE_ENFORCE_NE(dtype, -1, diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 11987c61aebaad00f8a71f1b909c83c44ddc8b0e..f6e12dfc76c6ce73f10e707387f6a9cedacde3c8 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -83,79 +83,54 @@ class SumKernel : public framework::OpKernel { } } } else if (out_var->IsType()) { - std::unique_ptr in0; - if (in_place) { - // If is in_place, we store the input[0] to in0 - auto &in_sel0 = in_vars[0]->Get(); - auto &rows = in_sel0.rows(); -#ifdef PADDLE_WITH_CUDA - std::vector rows_in_cpu; - rows_in_cpu.reserve(rows.size()); - for (auto item : rows) { - rows_in_cpu.push_back(item); - } - in0.reset(new framework::SelectedRows(rows_in_cpu, in_sel0.height())); -#else - in0.reset(new framework::SelectedRows(rows, in_sel0.height())); -#endif - in0->mutable_value()->ShareDataWith(in_sel0.value()); + if (in_place && in_vars.size() < 2) { + return; } - auto get_selected_row = [&](size_t i) -> const SelectedRows & { - if (i == 0 && in0) { - return *in0.get(); - } else { - return in_vars[i]->Get(); + std::vector inputs; + SelectedRows temp_in0; + + if (in_place) { + auto &in0 = in_vars[0]->Get(); + temp_in0.set_height(in0.height()); + temp_in0.set_rows(in0.rows()); + framework::TensorCopy(in0.value(), in0.place(), + context.device_context(), + temp_in0.mutable_value()); + inputs.push_back(&temp_in0); + for (size_t i = 1; i < in_vars.size(); ++i) { + auto &in = in_vars[i]->Get(); + if (in.rows().size() > 0) { + inputs.push_back(&in); + } + } + } else { + for (auto &in_var : in_vars) { + auto &in = in_var->Get(); + if (in.rows().size() > 0) { + inputs.push_back(&in_var->Get()); + } } - }; + } auto *out = context.Output("Out"); out->mutable_rows()->clear(); - auto *out_value = out->mutable_value(); - - // Runtime InferShape - size_t first_dim = 0; - for (size_t i = 0; i < in_num; i++) { - auto &sel_row = get_selected_row(i); - first_dim += sel_row.rows().size(); - } - std::vector in_dim; - for (size_t i = 0; i < in_num; i++) { - auto &sel_row = get_selected_row(i); - if (sel_row.rows().size() > 0) { - in_dim = framework::vectorize(sel_row.value().dims()); + bool has_data = false; + for (auto &in : inputs) { + if (in->rows().size() > 0) { + has_data = true; break; } } - if (in_dim.empty()) { - VLOG(3) << "WARNING: all the inputs are empty"; - in_dim = - framework::vectorize(get_selected_row(in_num - 1).value().dims()); + if (has_data) { + math::scatter::MergeAdd merge_add; + merge_add(context.template device_context(), inputs, + out); } else { - in_dim[0] = static_cast(first_dim); - } - - out_value->Resize(framework::make_ddim(in_dim)); - out_value->mutable_data(context.GetPlace()); - // if all the input sparse vars are empty, no need to - // merge these vars. - if (first_dim == 0UL) { - return; - } - - math::SelectedRowsAddTo functor; - - int64_t offset = 0; - for (size_t i = 0; i < in_num; i++) { - auto &sel_row = get_selected_row(i); - if (sel_row.rows().size() == 0) { - continue; - } - PADDLE_ENFORCE_EQ(out->height(), sel_row.height()); - functor(context.template device_context(), sel_row, - offset, out); - offset += sel_row.value().numel(); + // no data, just set a empty out tensor. + out->mutable_value()->mutable_data(framework::make_ddim({0}), + context.GetPlace()); } } else if (out_var->IsType()) { auto &out_array = *out_var->GetMutable(); diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index aa907595cb7cf165974caa69fe8eb0370471732d..e3132ae76f624f3338d749e4fcebbd0ecd7ffe79 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -29,7 +29,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { if (out_var->IsType()) { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { - auto shape = ctx.Attr>("shape"); + auto shape = ctx.Attr>("shape"); auto *selected_rows = out_var->GetMutable(); tensor = selected_rows->mutable_value(); tensor->Resize(framework::make_ddim(shape)); @@ -67,7 +67,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto &shape = ctx->Attrs().Get>("shape"); + auto &shape = ctx->Attrs().Get>("shape"); std::vector temp; temp.reserve(shape.size()); for (auto dim : shape) { @@ -94,7 +94,7 @@ This operator initializes a tensor with random values sampled from a uniform distribution. The random result is in set [min, max]. )DOC"); - AddAttr>("shape", "The shape of the output tensor"); + AddAttr>("shape", "The shape of the output tensor"); AddAttr("min", "Minimum value of uniform random. [default -1.0].") .SetDefault(-1.0f); AddAttr("max", "Maximun value of uniform random. [default 1.0].") diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index bbb692b0ddfc18e8a62c0d2a6bac88f9932f6704..2bb0ecc139f7096d1b61150e0a2d4fb095338749 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -48,7 +48,7 @@ class GPUUniformRandomKernel : public framework::OpKernel { if (out_var->IsType()) { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { - auto shape = context.Attr>("shape"); + auto shape = context.Attr>("shape"); tensor = out_var->GetMutable()->mutable_value(); tensor->Resize(framework::make_ddim(shape)); } else { diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index ab91ca5345047f3053eb8771e6a265d2a3011f85..2211e5504373b4a30e5fda0db22a41bdcd9f2421 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -116,21 +116,51 @@ void InitDevices(bool init_p2p, const std::vector devices) { platform::SetNumThreads(FLAGS_paddle_num_threads); #endif - if (platform::jit::MayIUse(platform::jit::avx512f)) { -#ifndef __AVX512F__ - LOG(WARNING) << "AVX512F is available, Please re-compile on local machine"; +#if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__) + if (platform::jit::MayIUse(platform::jit::avx)) { +#ifndef __AVX__ + LOG(WARNING) << "AVX is available, Please re-compile on local machine"; #endif } - if (platform::jit::MayIUse(platform::jit::avx2)) { -#ifndef __AVX2__ - LOG(WARNING) << "AVX2 is available, Please re-compile on local machine"; + +// Throw some informations when CPU instructions mismatch. +#define AVX_GUIDE(compiletime, runtime) \ + LOG(FATAL) \ + << "This version is compiled on higher instruction(" #compiletime \ + ") system, you may encounter illegal instruction error running on" \ + " your local CPU machine. Please reinstall the " #runtime \ + " version or compile from source code." + +#ifdef __AVX512F__ + if (!platform::jit::MayIUse(platform::jit::avx512f)) { + if (platform::jit::MayIUse(platform::jit::avx2)) { + AVX_GUIDE(AVX512, AVX2); + } else if (platform::jit::MayIUse(platform::jit::avx)) { + AVX_GUIDE(AVX512, AVX); + } else { + AVX_GUIDE(AVX512, NonAVX); + } + } #endif + +#ifdef __AVX2__ + if (!platform::jit::MayIUse(platform::jit::avx2)) { + if (platform::jit::MayIUse(platform::jit::avx)) { + AVX_GUIDE(AVX2, AVX); + } else { + AVX_GUIDE(AVX2, NonAVX); + } } - if (platform::jit::MayIUse(platform::jit::avx)) { -#ifndef __AVX__ - LOG(WARNING) << "AVX is available, Please re-compile on local machine"; #endif + +#ifdef __AVX__ + if (!platform::jit::MayIUse(platform::jit::avx)) { + AVX_GUIDE(AVX, NonAVX); } +#endif +#undef AVX_GUIDE + +#endif } void InitGLOG(const std::string &prog_name) { diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 3b22718a8c6f994dbc2dc3e7aaa19a7163f716ba..d3b0d4a22954c1d67dc9551b997dcffa0625cbeb 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -57,6 +57,18 @@ struct variant_caster> { auto caster = make_caster(); if (!load_success_ && caster.load(src, convert)) { load_success_ = true; + + if (std::is_same>::value) { + auto caster_ints = make_caster>(); + if (caster_ints.load(src, convert)) { + VLOG(4) << "This value are floats and int64_ts satisfy " + "simultaneously, will set it's type to " + "std::vector"; + value = cast_op>(caster_ints); + return true; + } + } + value = cast_op(caster); return true; } @@ -259,6 +271,8 @@ void BindOpDesc(pybind11::module *m) { pybind11::enum_(*m, "AttrType", "") .value("INT", pd::proto::AttrType::INT) .value("INTS", pd::proto::AttrType::INTS) + .value("LONG", pd::proto::AttrType::LONG) + .value("LONGS", pd::proto::AttrType::LONGS) .value("FLOAT", pd::proto::AttrType::FLOAT) .value("FLOATS", pd::proto::AttrType::FLOATS) .value("STRING", pd::proto::AttrType::STRING) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 339a7c98c6a2bba2cd46790cecc169ef447c63ce..5f15a29f4c3e9b1412912fe4723642d1ede60346 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -645,9 +645,13 @@ All parameter, weight, gradient are variables in Paddle. py::class_> pass(m, "Pass"); pass.def(py::init()) - .def("set_str", [](ir::Pass &self, const std::string &name, - const std::string &attr) { - self.Set(name, new std::string(attr)); + .def( + "set_str", + [](ir::Pass &self, const std::string &name, const std::string &attr) { + self.Set(name, new std::string(attr)); + }) + .def("set_int", [](ir::Pass &self, const std::string &name, int val) { + self.Set(name, new int(val)); }); py::class_> pb( diff --git a/paddle/fluid/train/demo/CMakeLists.txt b/paddle/fluid/train/demo/CMakeLists.txt index 78d6e5ff554b9cd9facae85be166a697e0b75337..eabb51d370aff709e289e1fc727aa2dbb551d82e 100644 --- a/paddle/fluid/train/demo/CMakeLists.txt +++ b/paddle/fluid/train/demo/CMakeLists.txt @@ -15,6 +15,7 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/zlib/include") @@ -27,6 +28,7 @@ link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") add_executable(demo_trainer demo_trainer.cc) @@ -62,5 +64,5 @@ target_link_libraries(demo_trainer ${ARCHIVE_END} ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf snappystream snappy z + glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 85493c10549c290330ed09b9f28accb7a980de6a..5a71382fb14b64989502c34d8ac0aa13c62bc7d0 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -95,9 +95,9 @@ function cmake_gen() { exit 1 fi fi - else + else if [ "$1" != "" ]; then - echo "using python abi: $1" + echo "using python abi: $1" if [ "$1" == "cp27-cp27m" ]; then export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} export PATH=/opt/python/cp27-cp27m/bin/:${PATH} @@ -119,7 +119,7 @@ function cmake_gen() { fi fi fi - + if [ "$SYSTEM" == "Darwin" ]; then WITH_DISTRIBUTE=${WITH_DISTRIBUTE:-ON} WITH_AVX=${WITH_AVX:-ON} @@ -127,7 +127,7 @@ function cmake_gen() { else INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} fi - + cat < 0.0 and tot_neg > 0.0 else 0.0 + + +class DetectionMAP(object): + """ + Calculate the detection mean average precision (mAP). + + The general steps are as follows: + 1. calculate the true positive and false positive according to the input + of detection and labels. + 2. calculate mAP value, support two versions: '11 point' and 'integral'. + + Please get more information from the following articles: + https://sanchom.wordpress.com/tag/average-precision/ + https://arxiv.org/abs/1512.02325 + + Args: + input (Variable): The detection results, which is a LoDTensor with shape + [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax]. + gt_label (Variable): The ground truth label index, which is a LoDTensor + with shape [N, 1]. + gt_box (Variable): The ground truth bounding box (bbox), which is a + LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax]. + gt_difficult (Variable|None): Whether this ground truth is a difficult + bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, + it means all the ground truth labels are not difficult bbox. + class_num (int): The class number. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all categories will be + considered, 0 by defalut. + overlap_threshold (float): The threshold for deciding true/false + positive, 0.5 by defalut. + evaluate_difficult (bool): Whether to consider difficult ground truth + for evaluation, True by defalut. This argument does not work when + gt_difficult is None. + ap_version (string): The average precision calculation ways, it must be + 'integral' or '11point'. Please check + https://sanchom.wordpress.com/tag/average-precision/ for details. + - 11point: the 11-point interpolated average precision. + - integral: the natural integral of the precision-recall curve. + + Examples: + .. code-block:: python + + exe = fluid.Executor(place) + map_evaluator = fluid.Evaluator.DetectionMAP(input, + gt_label, gt_box, gt_difficult) + cur_map, accum_map = map_evaluator.get_map_var() + fetch = [cost, cur_map, accum_map] + for epoch in PASS_NUM: + map_evaluator.reset(exe) + for data in batches: + loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch) + + In the above example: + + 'cur_map_v' is the mAP of current mini-batch. + 'accum_map_v' is the accumulative mAP of one pass. + """ + + def __init__(self, + input, + gt_label, + gt_box, + gt_difficult=None, + class_num=None, + background_label=0, + overlap_threshold=0.5, + evaluate_difficult=True, + ap_version='integral'): + + self.helper = LayerHelper('map_eval') + gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) + if gt_difficult: + gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype) + label = layers.concat([gt_label, gt_difficult, gt_box], axis=1) + else: + label = layers.concat([gt_label, gt_box], axis=1) + + # calculate mean average precision (mAP) of current mini-batch + map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + ap_version=ap_version) + + states = [] + states.append( + self._create_state( + dtype='int32', shape=None, suffix='accum_pos_count')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_true_pos')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_false_pos')) + var = self._create_state(dtype='int32', shape=[1], suffix='has_state') + self.helper.set_variable_initializer( + var, initializer=Constant(value=int(0))) + self.has_state = var + + # calculate accumulative mAP + accum_map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + has_state=self.has_state, + input_states=states, + out_states=states, + ap_version=ap_version) + + layers.fill_constant( + shape=self.has_state.shape, + value=1, + dtype=self.has_state.dtype, + out=self.has_state) + + self.cur_map = map + self.accum_map = accum_map + + def _create_state(self, suffix, dtype, shape): + """ + Create state variable. + Args: + suffix(str): the state suffix. + dtype(str|core.VarDesc.VarType): the state data type + shape(tuple|list): the shape of state + Returns: State variable + """ + state = self.helper.create_variable( + name="_".join([unique_name.generate(self.helper.name), suffix]), + persistable=True, + dtype=dtype, + shape=shape) + return state + + def get_map_var(self): + """ + Returns: mAP variable of current mini-batch and + accumulative mAP variable cross mini-batches. + """ + return self.cur_map, self.accum_map + + def reset(self, executor, reset_program=None): + """ + Reset metric states at the begin of each pass/user specified batch. + + Args: + executor(Executor): a executor for executing + the reset_program. + reset_program(Program|None): a single Program for reset process. + If None, will create a Program. + """ + + def _clone_var_(block, var): + assert isinstance(var, Variable) + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=var.persistable) + + if reset_program is None: + reset_program = Program() + with program_guard(main_program=reset_program): + var = _clone_var_(reset_program.current_block(), self.has_state) + layers.fill_constant( + shape=var.shape, value=0, dtype=var.dtype, out=var) + executor.run(reset_program) diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index 667db10d3ebdd24ddd9efbe2310ebb331e268ee2..4e1d1450dea85fe4eb3e68713250836e4beac992 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -120,6 +120,8 @@ class OpDescCreationMethod(object): new_attr.strings.extend(user_defined_attr) elif attr.type == framework_pb2.BOOLEANS: new_attr.bools.extend(user_defined_attr) + elif attr.type == framework_pb2.LONGS: + new_attr.longs.extend(user_defined_attr) elif attr.type == framework_pb2.INT_PAIRS: for p in user_defined_attr: pair = new_attr.int_pairs.add() diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 6ea280c733996436a6218cc6f759f2cf7c652ac9..7e2364a5a872cdd8cf590438cc081ab070db767d 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -14,6 +14,7 @@ from __future__ import print_function import re +import sys from collections import defaultdict from paddle.fluid.framework import Program, Variable, name_scope, default_main_program from . import framework @@ -32,7 +33,8 @@ __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', - 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer' + 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', + 'LarsMomentumOptimizer' ] @@ -105,7 +107,6 @@ class Optimizer(object): param = param_and_grad[0] param_lr = param.optimize_attr['learning_rate'] if type(param_lr) == Variable: - print("returns updated param lr ", param_lr) return param_lr else: if param_lr == 1.0: @@ -400,6 +401,91 @@ class MomentumOptimizer(Optimizer): return momentum_op +class LarsMomentumOptimizer(Optimizer): + """ + Momentum optimizer with LARS support + + The update equations are as follows: + + .. math:: + + & local\_learning\_rate = learning\_rate * lars\_coeff * \\ + \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||} + + & velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param) + + & param = param - velocity + + Args: + learning_rate (float|Variable): the learning rate used to update parameters. \ + Can be a float value or a Variable with one float value as data element. + momentum (float): momentum factor + lars_coeff (float): defines how much we trust the layer to change its weights. + lars_weight_decay (float): weight decay coefficient for decaying using LARS. + regularization: A Regularizer, such as + fluid.regularizer.L2DecayRegularizer. + name: A optional name prefix. + + + Examples: + .. code-block:: python + + optimizer = fluid.optimizer.LarsMomentum(learning_rate=0.2, momentum=0.1, lars_weight_decay=0.001) + optimizer.minimize(cost) + """ + _velocity_acc_str = "velocity" + + def __init__(self, + learning_rate, + momentum, + lars_coeff=0.001, + lars_weight_decay=0.0005, + regularization=None, + name=None): + assert learning_rate is not None + assert momentum is not None + super(LarsMomentumOptimizer, self).__init__( + learning_rate=learning_rate, + regularization=regularization, + name=name) + self.type = "lars_momentum" + self._momentum = momentum + self._lars_coeff = float(lars_coeff) + self._lars_weight_decay = float(lars_weight_decay) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._velocity_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + velocity_acc = self._get_accumulator(self._velocity_acc_str, + param_and_grad[0]) + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Velocity": velocity_acc, + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={ + "ParamOut": param_and_grad[0], + "VelocityOut": velocity_acc + }, + attrs={ + "mu": self._momentum, + "lars_coeff": self._lars_coeff, + "lars_weight_decay": self._lars_weight_decay + }) + + return momentum_op + + class AdagradOptimizer(Optimizer): """ **Adaptive Gradient Algorithm (Adagrad)** @@ -1221,6 +1307,7 @@ DecayedAdagrad = DecayedAdagradOptimizer Adadelta = AdadeltaOptimizer RMSProp = RMSPropOptimizer Ftrl = FtrlOptimizer +LarsMomentum = LarsMomentumOptimizer class ModelAverage(Optimizer): diff --git a/python/paddle/fluid/tests/CMakeLists.txt b/python/paddle/fluid/tests/CMakeLists.txt index 7ad923d3321ec8a88b60d7f4f7777e12fad8faa6..d24417bbacb503d9ea70e68e7e0edb59e7dddbde 100644 --- a/python/paddle/fluid/tests/CMakeLists.txt +++ b/python/paddle/fluid/tests/CMakeLists.txt @@ -1,5 +1,3 @@ -set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") - file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/image_classification/CMakeLists.txt index 673c965b662a022739f8d489c331f4de9455a926..91c1d17eb5391ea37a41a886594cc71c6e6c56bd 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/CMakeLists.txt +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/CMakeLists.txt @@ -1,7 +1,19 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -# default test -foreach(src ${TEST_OPS}) - py_test(${src} SRCS ${src}.py) -endforeach() +if(NOT APPLE) + # default test + foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) + endforeach() +else() + foreach(src ${TEST_OPS}) + if(${src} STREQUAL "test_image_classification_vgg") + message(WARNING "These tests has been disabled in OSX for random fail: \n" ${src}) + elseif(${src} STREQUAL "test_image_classification_resnet") + message(WARNING "These tests has been disabled in OSX for random fail: \n" ${src}) + elseif() + py_test(${src} SRCS ${src}.py) + endif() + endforeach() +endif() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index cf54bc2dbe788f3757a7ef93f26156d118a0cd02..3a4128284d801512ff1c4863550a369476fca2b6 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -17,6 +17,10 @@ if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_listen_and_serv_op) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist) LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec) + LIST(REMOVE_ITEM TEST_OPS test_dist_ctr) + LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow) + LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_batch_merge) + LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification) endif(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290 @@ -88,4 +92,6 @@ py_test_modules(test_parallel_executor_crf MODULES test_parallel_executor_crf SE py_test_modules(test_parallel_executor_fetch_feed MODULES test_parallel_executor_fetch_feed SERIAL) set_tests_properties(test_parallel_executor_fetch_feed PROPERTIES TIMEOUT 150) py_test_modules(test_parallel_executor_transformer MODULES test_parallel_executor_transformer SERIAL) -py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL) +if(NOT APPLE) + py_test_modules(test_image_classification_resnet MODULES test_image_classification_resnet SERIAL) +endif() diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 877d21ae882ab4efb49beb6a846ab71a22c2aab7..01e9795d8b1beb67270f45fe7ba2819bf8c3be3e 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -95,7 +95,7 @@ class TestDistMnist2x2(TestDistRunnerBase): # Reader train_reader = paddle.batch( - paddle.dataset.mnist.train(), batch_size=batch_size) + paddle.dataset.mnist.test(), batch_size=batch_size) test_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size) opt.minimize(avg_cost) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..d386e75fd887a898f5a13e48e378e08ff6c99ea0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + +DTYPE = "float32" + + +def test_merge_reader(repeat_batch_size=8): + orig_reader = paddle.dataset.mnist.test() + record_batch = [] + b = 0 + for d in orig_reader(): + if b >= repeat_batch_size: + break + record_batch.append(d) + b += 1 + while True: + for d in record_batch: + yield d + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) + + # Reader + train_reader = paddle.batch(test_merge_reader, batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_lars.py b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py new file mode 100644 index 0000000000000000000000000000000000000000..977e17c37f7676ae81d9ab29b6b36089ccbeeacf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py @@ -0,0 +1,73 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + +DTYPE = "float32" +paddle.dataset.mnist.fetch() + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.LarsMomentumOptimizer( + learning_rate=0.001, momentum=0.9) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index ab44954811562b8f74e368a551e855948f90af87..27c67edf4f62dd3c5d396826348f8da4513667ba 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -1159,6 +1159,7 @@ def prepare_encoder(src_word, name=pos_enc_param_name, trainable=False, initializer=fluid.initializer.ConstantInitializer(0.001))) + src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc return layers.dropout( enc_input, diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 04924bec057e301bfb342a62bb4c1e0b3c3aff4c..0836518401faa425821ee908cff3af575d92a2b4 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -26,10 +26,11 @@ import argparse import paddle.fluid as fluid RUN_STEP = 10 +DEFAULT_BATCH_SIZE = 2 class TestDistRunnerBase(object): - def get_model(self, batch_size=2): + def get_model(self, batch_size=DEFAULT_BATCH_SIZE): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -48,8 +49,7 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): - - self.get_model(batch_size=2) + self.get_model(batch_size=args.batch_size) # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, fluid.default_main_program(), args.endpoints, @@ -65,7 +65,7 @@ class TestDistRunnerBase(object): def run_trainer(self, args): test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ - self.get_model(batch_size=2) + self.get_model(batch_size=args.batch_size) if args.mem_opt: fluid.memory_optimize(fluid.default_main_program(), skip_grads=True) @@ -92,6 +92,11 @@ class TestDistRunnerBase(object): strategy.allow_op_delay = False build_stra = fluid.BuildStrategy() + if args.batch_merge_repeat > 1: + pass_builder = build_stra._create_passes_from_strategy() + mypass = pass_builder.insert_pass( + len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") + mypass.set_int("num_repeats", args.batch_merge_repeat) if args.use_reduce: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce @@ -145,6 +150,9 @@ def runtime_main(test_class): parser.add_argument('--use_reduce', action='store_true') parser.add_argument( '--use_reader_alloc', action='store_true', required=False, default=True) + parser.add_argument('--batch_size', required=False, type=int, default=2) + parser.add_argument( + '--batch_merge_repeat', required=False, type=int, default=1) args = parser.parse_args() @@ -180,7 +188,7 @@ class TestDistBase(unittest.TestCase): self._pservers = 2 self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._find_free_port(), self._find_free_port()) - self._python_interp = "python" + self._python_interp = sys.executable self._sync_mode = True self._enforce_place = None self._mem_opt = False @@ -244,9 +252,18 @@ class TestDistBase(unittest.TestCase): (e, retry_times)) retry_times -= 1 - def _run_local(self, model, envs, check_error_log): + def _run_local(self, + model, + envs, + check_error_log=False, + batch_size=DEFAULT_BATCH_SIZE, + batch_merge_repeat=1): cmd = "%s %s --role trainer" % (self._python_interp, model) + if batch_size != DEFAULT_BATCH_SIZE: + cmd += " --batch_size %d" % batch_size + if batch_merge_repeat > 1: + cmd += " --batch_merge_repeat %d" % batch_merge_repeat if self.__use_cuda: cmd += " --use_cuda" diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index 3575fd07fc727bd6c6b07a19a60b1df6656ae9e2..b2d979729bc9b2546375cb657f78abe0d8c2dcc7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -18,14 +18,14 @@ import unittest from test_dist_base import TestDistBase +# FIXME(tangwei): sum op can not handle when inputs is empty. class TestDistCTR2x2(TestDistBase): def _setup_config(self): self._sync_mode = True self._enforce_place = "CPU" - -def test_dist_ctr(self): - self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) + def test_dist_ctr(self): + self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 94b66a40233be4378e1a003f01d9375d00794743..922dd838f8996adfc15afffcd44c1acca2bc14a9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -26,6 +26,15 @@ class TestDistMnist2x2(TestDistBase): self.check_with_place("dist_mnist.py", delta=1e-5) +class TestDistMnist2x2Lars(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + + def test_se_resnext(self): + self.check_with_place("dist_mnist_lars.py", delta=1e-5) + + class TestDistMnist2x2WithMemopt(TestDistBase): def _setup_config(self): self._sync_mode = True @@ -40,8 +49,7 @@ class TestDistMnistAsync(TestDistBase): self._sync_mode = False self._use_reduce = False - # FIXME(typhoonzero): fix async mode test later - def no_test_dist_train(self): + def test_dist_train(self): self.check_with_place("dist_mnist.py", delta=200) diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..22d4b7929033529c5cea60064e6d9de57eddeb8e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py @@ -0,0 +1,67 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase +import os + + +class TestDistMnist2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + + def test_dist_train(self): + self.check_with_place("dist_mnist_batch_merge.py", delta=1e-5) + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + # TODO(typhoonzero): should auto adapt GPU count on the machine. + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_cudnn_deterministic": "1", + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "7" + required_envs["GLOG_logtostderr"] = "1" + + no_merge_losses = self._run_local( + model_file, + required_envs, + check_error_log=check_error_log, + batch_size=4) + + batch_merge_losses = self._run_local( + model_file, + required_envs, + check_error_log=check_error_log, + batch_size=2, + batch_merge_repeat=2) + # Ensure both result have values. + self.assertGreater(len(no_merge_losses), 1) + self.assertEqual(len(no_merge_losses), len(batch_merge_losses)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py index c1e60dc9e420d11677468e0c62357437ecdf9e35..c0989ca709e100d8f147a08970b0e858c81ce09b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/test_dist_se_resnext.py @@ -40,8 +40,7 @@ class TestDistSeResneXt2x2Async(TestDistBase): self._sync_mode = False self._use_reader_alloc = False - #FIXME(typhoonzero): fix async mode later - def no_test_dist_train(self): + def test_dist_train(self): self.check_with_place("dist_se_resnext.py", delta=100) diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py index e1e6ef61090dfb439a3b43c4baf5ba88f61310ba..102a4dab05fe1adc6a503920714f50415b29dc19 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -42,7 +42,6 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): self._sync_mode = False self._enforce_place = "CPU" - #FIXME(typhoonzero): fix async tests later def no_test_simnet_bow(self): need_envs = { "IS_DISTRIBUTED": '0', @@ -79,8 +78,7 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): self._sync_mode = False self._enforce_place = "CPU" - #FIXME(typhoonzero): fix async tests later - def no_test_simnet_bow(self): + def test_simnet_bow(self): need_envs = { "IS_DISTRIBUTED": '0', "IS_SPARSE": '1', @@ -94,7 +92,6 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): # FIXME(tangwei): Learningrate variable is not created on pserver. -""" class TestDistSimnetBow2x2LookupTableSync(TestDistBase): def _setup_config(self): self._sync_mode = True @@ -147,7 +144,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase): delta=1e-5, check_error_log=False, need_envs=need_envs) -""" + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 54a1c68a37f6929890aab697b48d621e6effb7d8..c4511a98b0667ecccaa8f63b3064c4fc4e86cc78 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -480,7 +480,7 @@ class TestDistLookupTable(TestDistLookupTableBase): def transpiler_test_impl(self): pserver1, startup1 = self.get_pserver(self.pserver1_ep) - self.assertEqual(len(pserver1.blocks), 6) + self.assertEqual(len(pserver1.blocks), 5) # 0 listen_and_serv # 1 optimize for fc_w or fc_b adam self.assertEqual([op.type for op in pserver1.blocks[1].ops], @@ -491,26 +491,32 @@ class TestDistLookupTable(TestDistLookupTableBase): # 3 prefetch -> lookup_sparse_table for data0 self.assertEqual([op.type for op in pserver1.blocks[3].ops], ["lookup_sparse_table"]) - # 4 prefetch -> lookup_sparse_table for data1 - self.assertEqual([op.type for op in pserver1.blocks[4].ops], - ["lookup_sparse_table"]) - # 5 save table - self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) + # 4 save table + self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"]) - trainer, _ = self.get_trainer() + trainer, trainer_startup = self.get_trainer() self.assertEqual(len(trainer.blocks), 1) ops = [ - 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', - 'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', - 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', - 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', - 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', - 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', - 'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv', - 'fetch_barrier' + 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', + 'sequence_pool', 'concat', 'mul', 'elementwise_add', + 'cross_entropy', 'mean', 'fill_constant', 'mean_grad', + 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', + 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', + 'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids', + 'send', 'send_barrier', 'recv', 'recv', 'fetch_barrier' ] self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) + startup_ops = [ + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', + 'fetch_barrier', 'fake_init' + ] + self.assertEqual([op.type for op in trainer_startup.blocks[0].ops], + startup_ops) + class TestAsyncLocalLookupTable(TestDistLookupTableBase): def net_conf(self): @@ -553,7 +559,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) - self.assertEqual(len(pserver1.blocks), 6) + self.assertEqual(len(pserver1.blocks), 5) # 0 listen_and_serv # 1 optimize for fc_w or fc_b adam self.assertEqual([op.type for op in pserver1.blocks[1].ops], @@ -563,22 +569,19 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): # 3 prefetch -> lookup_sparse_table for data0 self.assertEqual([op.type for op in pserver1.blocks[3].ops], ["lookup_sparse_table"]) - # 4 prefetch -> lookup_sparse_table for data1 - self.assertEqual([op.type for op in pserver1.blocks[4].ops], - ["lookup_sparse_table"]) - # 5 save table - self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) + # 4 save table + self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"]) trainer, _ = self.get_trainer(config) self.assertEqual(len(trainer.blocks), 1) ops = [ - 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', - 'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', - 'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', - 'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', - 'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', - 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', - 'sum', 'split_ids', 'send', 'recv', 'recv' + 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', + 'sequence_pool', 'concat', 'mul', 'elementwise_add', + 'cross_entropy', 'mean', 'fill_constant', 'mean_grad', + 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad', + 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad', + 'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids', + 'send', 'recv', 'recv' ] self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) diff --git a/python/paddle/fluid/tests/unittests/test_fake_init_op.py b/python/paddle/fluid/tests/unittests/test_fake_init_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a62b7aed66b59940b4ba654d98479e3e35c7b78b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fake_init_op.py @@ -0,0 +1,52 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle.fluid.core as core +from paddle.fluid.op import Operator + + +class TestFakeInitOpSelectedRows(unittest.TestCase): + def check_with_place(self, place, is_selected_rows): + scope = core.Scope() + + out_var_name = 'Out' + if is_selected_rows: + out_tensor = scope.var(out_var_name).get_selected_rows().get_tensor( + ) + else: + out_tensor = scope.var(out_var_name).get_tensor() + + var_shape = [4, 784] + + # create and run fake_init_op + fake_init_op = Operator("fake_init", Out=out_var_name, shape=var_shape) + fake_init_op.run(scope, place) + + self.assertEqual(var_shape, out_tensor._get_dims()) + + def test_fake_init_selected_rows(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + for place in places: + for is_selected_rows in [True, False]: + self.check_with_place(place, is_selected_rows) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1130ea39c42204283885ab1072a52db8c22f8b2e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_hash_op.py @@ -0,0 +1,57 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestScaleOp(OpTest): + def setUp(self): + self.op_type = "hash" + self.init_test_case() + self.inputs = {'X': (self.in_seq, self.lod)} + self.attrs = {'num_hash': 4, 'mod_by': 10000} + self.outputs = {'Out': (self.out_seq, self.lod)} + + def init_test_case(self): + np.random.seed = 1 + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + # self.out_seq = np.ones([30, 4, 1], dtype=np.int32) + self.out_seq = [ + [[9662], [9217], [1129], [8487]], [[9662], [9217], [1129], [8487]], + [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], + [[9407], [6715], [6949], [8094]], [[8473], [694], [5142], [2479]], + [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], + [[4372], [9456], [8204], [6695]], [[6897], [3218], [2013], [1241]], + [[8473], [694], [5142], [2479]], [[4372], [9456], [8204], [6695]], + [[4372], [9456], [8204], [6695]], [[8473], [694], [5142], [2479]], + [[9407], [6715], [6949], [8094]], [[9369], [4525], [8935], [9210]], + [[4372], [9456], [8204], [6695]], [[4372], [9456], [8204], [6695]], + [[9369], [4525], [8935], [9210]], [[6897], [3218], [2013], [1241]], + [[9038], [7951], [5953], [8657]], [[9407], [6715], [6949], [8094]], + [[9662], [9217], [1129], [8487]], [[9369], [4525], [8935], [9210]], + [[9038], [7951], [5953], [8657]], [[9662], [9217], [1129], [8487]], + [[9369], [4525], [8935], [9210]], [[1719], [5986], [9919], [3421]], + [[4372], [9456], [8204], [6695]], [[9038], [7951], [5953], [8657]] + ] + self.out_seq = np.array(self.out_seq) + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py index 26ce7024117162e8bad403a9d8b8518c27578c83..b109e4ea62669c735128f4824eb9d02ad43900e0 100644 --- a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py @@ -22,15 +22,28 @@ from op_test import OpTest class TestMergeIdsOp(OpTest): def setUp(self): self.op_type = "merge_ids" - ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') - x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32') - x1 = np.array([]).astype('float32') - x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], - [0.5, 0.6]]).astype('float32') - out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3], - [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32') - self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]} - self.outputs = {'Out': out} + ids1 = np.array([[0], [2], [5], [6]]).astype('int64') + ids2 = np.array([[0], [2], [2], [3]]).astype('int64') + + rows1 = np.array([[0], [2]]).astype('int64') + rows2 = np.array([[3], [5]]).astype('int64') + rows3 = np.array([[6]]).astype('int64') + + x0 = np.array([[0.1, 0.2], [0.2, 0.3]]).astype('float32') + x1 = np.array([[0.3, 0.4], [0.4, 0.5]]).astype('float32') + x2 = np.array([[0.5, 0.6]]).astype('float32') + + out1 = np.array( + [[0.1, 0.2], [0.2, 0.3], [0.4, 0.5], [0.5, 0.6]]).astype('float32') + out2 = np.array( + [[0.1, 0.2], [0.2, 0.3], [0.2, 0.3], [0.3, 0.4]]).astype('float32') + + self.inputs = { + 'Ids': [('ids1', ids1), ('ids2', ids2)], + "Rows": [('rows1', rows1), ('rows2', rows2), ('rows3', rows3)], + "X": [('x0', x0), ('x1', x1), ('x2', x2)] + } + self.outputs = {'Out': [('out1', out1), ('out2', out2)]} def test_check_output(self): self.check_output() diff --git a/python/paddle/fluid/tests/unittests/test_metrics.py b/python/paddle/fluid/tests/unittests/test_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..ec27884cae2b0462951f6597b1b83e58d1c8af5d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_metrics.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +from paddle.fluid.framework import Program, program_guard + + +class TestMetricsDetectionMap(unittest.TestCase): + def test_detection_map(self): + program = fluid.Program() + with program_guard(program): + detect_res = fluid.layers.data( + name='detect_res', + shape=[10, 6], + append_batch_size=False, + dtype='float32') + label = fluid.layers.data( + name='label', + shape=[10, 1], + append_batch_size=False, + dtype='float32') + box = fluid.layers.data( + name='bbox', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + map_eval = fluid.metrics.DetectionMAP( + detect_res, label, box, class_num=21) + cur_map, accm_map = map_eval.get_map_var() + self.assertIsNotNone(cur_map) + self.assertIsNotNone(accm_map) + print(str(program)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index a3d89610b40ff9bd5002e843f8667ada87e67981..cf4346cf2e7a099334ec273546901a91d0ad925d 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -90,6 +90,45 @@ class TestMomentumOp2(OpTest): self.check_output() +class TestLarsMomentumOp(OpTest): + def setUp(self): + self.op_type = "lars_momentum" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + mu = 0.0001 + lars_coeff = 0.001 + lars_weight_decay = 0.0005 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Velocity': velocity, + 'LearningRate': learning_rate + } + + self.attrs = { + 'mu': mu, + 'lars_coeff': lars_coeff, + 'lars_weight_decay': lars_weight_decay + } + + pnorm = np.sqrt(np.square(param).sum()) + gnorm = np.sqrt(np.square(grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * param) + velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * + param) + param_out = param - velocity_out + + self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + + def test_check_output(self): + self.check_output() + + class TestSparseMomentumOp(unittest.TestCase): def setUp(self): self.use_nesterov = False diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py index 4c3d0258980fd8595704a65219deb520b96e222e..d674dad2293921c06135b4ee528538d266cb2904 100644 --- a/python/paddle/fluid/tests/unittests/test_split_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -25,18 +25,21 @@ from paddle.fluid.op import Operator class TestSplitIdsOp(OpTest): def setUp(self): self.op_type = "split_ids" - ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + ids1 = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + ids2 = np.array([[6], [2], [3], [3], [5], [2], [6]]).astype('int64') + ids3 = np.array([[2], [2], [2], [3], [5], [5], [6]]).astype('int64') + out0 = np.array([[0], [3], [6]]).astype('int64') out1 = np.array([[]]).astype('int64') - out2 = np.array([[2], [2], [5], [5]]).astype('int64') - self.inputs = {'Ids': ids} + out2 = np.array([[2], [5]]).astype('int64') + self.inputs = {'Ids': [('ids1', ids1), ('ids2', ids2), ('ids3', ids3)]} self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]} def test_check_output(self): self.check_output() -class TestSpliteIds(unittest.TestCase): +class TestSplitSelectedRows(unittest.TestCase): def get_places(self): places = [core.CPUPlace()] return places diff --git a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py index 41a5ee59ea523b1f6c5015974a12c526e883fa35..50204b8a77c187aa695da83860960566448d290f 100644 --- a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py @@ -99,7 +99,6 @@ class TestSpliteSelectedRows(unittest.TestCase): out0_grad.set_height(height) out0_grad_tensor = out0_grad.get_tensor() np_array = np.ones((len(rows0), row_numel)).astype("float32") - np_array[0, 0] = 2.0 out0_grad_tensor.set(np_array, place) out1_grad = scope.var("out1@GRAD").get_selected_rows() @@ -108,7 +107,6 @@ class TestSpliteSelectedRows(unittest.TestCase): out1_grad.set_height(height) out1_grad_tensor = out1_grad.get_tensor() np_array = np.ones((len(rows1), row_numel)).astype("float32") - np_array[0, 1] = 4.0 out1_grad_tensor.set(np_array, place) x_grad = scope.var("X@GRAD").get_selected_rows() @@ -121,11 +119,13 @@ class TestSpliteSelectedRows(unittest.TestCase): grad_op.run(scope, place) - self.assertEqual(x_grad.rows(), rows0 + rows1) + merged_rows = set(rows0 + rows1) + self.assertEqual(set(x_grad.rows()), set(rows0 + rows1)) self.assertEqual(x_grad.height(), height) + print(np.array(x_grad.get_tensor())) self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0]) - self.assertAlmostEqual(4.0, np.array(x_grad.get_tensor())[2, 1]) + self.assertAlmostEqual(1.0, np.array(x_grad.get_tensor())[2, 1]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 74797bb65678404b7b35d06eecc7f9a12b2a346e..e20418ff1c8d21f3a3e4ba15ff2aa9d54f37f4b2 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -45,16 +45,30 @@ class TestSumOp(OpTest): class TestSelectedRowsSumOp(OpTest): - def check_with_place(self, place): - scope = core.Scope() - self.check_input_and_optput(scope, place, True, True, True) - self.check_input_and_optput(scope, place, False, True, True) - self.check_input_and_optput(scope, place, False, False, True) - self.check_input_and_optput(scope, place, False, False, False) + def check_with_place(self, place, inplace): + self.height = 10 + self.row_numel = 12 + self.rows = [0, 1, 2, 3, 4, 5, 6] + + self.check_input_and_optput(core.Scope(), place, inplace, True, True, + True) + self.check_input_and_optput(core.Scope(), place, inplace, False, True, + True) + self.check_input_and_optput(core.Scope(), place, inplace, False, False, + True) + self.check_input_and_optput(core.Scope(), place, inplace, False, False, + False) + + def _get_array(self, row_num, row_numel): + array = np.ones((row_num, row_numel)).astype("float32") + for i in range(row_num): + array[i] *= i + return array def check_input_and_optput(self, scope, place, + inplace, w1_has_data=False, w2_has_data=False, w3_has_data=False): @@ -64,35 +78,43 @@ class TestSelectedRowsSumOp(OpTest): self.create_selected_rows(scope, place, "W3", w3_has_data) # create Out Variable - out = scope.var('Out').get_selected_rows() + if inplace: + out_var_name = "W1" + else: + out_var_name = "Out" + out = scope.var(out_var_name).get_selected_rows() # create and run sum operator - sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') + sum_op = Operator("sum", X=["W1", "W2", "W3"], Out=out_var_name) sum_op.run(scope, place) has_data_w_num = 0 - for w in [w1_has_data, w2_has_data, w3_has_data]: - if not w: + for has_data in [w1_has_data, w2_has_data, w3_has_data]: + if has_data: has_data_w_num += 1 - self.assertEqual(7 * has_data_w_num, len(out.rows())) + if has_data_w_num > 0: + self.assertEqual(len(out.rows()), 7) + self.assertTrue( + np.array_equal( + np.array(out.get_tensor()), + self._get_array(len(self.rows), self.row_numel) * + has_data_w_num)) + else: + self.assertEqual(len(out.rows()), 0) - def create_selected_rows(self, scope, place, var_name, isEmpty): + def create_selected_rows(self, scope, place, var_name, has_data): # create and initialize W Variable - if not isEmpty: - rows = [0, 1, 2, 3, 4, 5, 6] - row_numel = 12 + if has_data: + rows = self.rows else: rows = [] - row_numel = 12 var = scope.var(var_name) w_selected_rows = var.get_selected_rows() - w_selected_rows.set_height(len(rows)) + w_selected_rows.set_height(self.height) w_selected_rows.set_rows(rows) - w_array = np.ones((len(rows), row_numel)).astype("float32") - for i in range(len(rows)): - w_array[i] *= i + w_array = self._get_array(len(rows), self.row_numel) w_tensor = w_selected_rows.get_tensor() w_tensor.set(w_array, place) @@ -100,9 +122,11 @@ class TestSelectedRowsSumOp(OpTest): def test_w_is_selected_rows(self): places = [core.CPUPlace()] - # currently only support CPU + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) for place in places: - self.check_with_place(place) + for inplace in [True, False]: + self.check_with_place(place, inplace) if __name__ == "__main__": diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 28d7df8e45d4f3209fb6205b76f8d1be8e73392b..8daac0f43b41b9497812a07fa2f96bffb727413d 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -475,6 +475,26 @@ class DistributeTranspiler(object): delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), lr_ops) + # delete table init op + if self.has_distributed_lookup_table: + table_var = self.startup_program.global_block().vars[ + self.table_name] + table_param_init_op = [] + for op in self.startup_program.global_block().ops: + if self.table_name in op.output_arg_names: + table_param_init_op.append(op) + init_op_num = len(table_param_init_op) + if init_op_num != 1: + raise ValueError("table init op num should be 1, now is " + str( + init_op_num)) + table_init_op = table_param_init_op[0] + self.startup_program.global_block().append_op( + type="fake_init", + inputs={}, + outputs={"Out": table_var}, + attrs={"shape": table_init_op.attr('shape')}) + delete_ops(self.startup_program.global_block(), table_param_init_op) + self.origin_program.__str__() if wait_port: @@ -713,7 +733,7 @@ in a single call.") for _, op in enumerate(self.optimize_ops): # optimizer is connected to itself if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \ - op not in global_ops: + op not in global_ops: log("append opt op: ", op.type, op.input_arg_names, merged_var) __append_optimize_op__(op, per_opt_block, @@ -1034,15 +1054,11 @@ to transpile() call.") def _replace_lookup_table_op_with_prefetch(self, program, pserver_endpoints): # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op - # self.all_prefetch_input_vars = - # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1] - # [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]] + self.all_in_ids_vars = [] self.all_prefetch_input_vars = [] - - # self.all_prefetch_input_vars = - # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1] - # [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]] self.all_prefetch_output_vars = [] + self.all_out_emb_vars = [] + lookup_table_op_index = -1 continue_search_lookup_table_op = True while continue_search_lookup_table_op: @@ -1052,72 +1068,68 @@ to transpile() call.") if op.type == LOOKUP_TABLE_TYPE: continue_search_lookup_table_op = True - lookup_table_op_index = list(all_ops).index(op) + lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list( + all_ops).index(op) ids_name = op.input("Ids") out_name = op.output("Out") ids_var = program.global_block().vars[ids_name[0]] - prefetch_input_vars = self._create_splited_vars( - source_var=ids_var, - block=program.global_block(), - tag="_prefetch_in_") - self.all_prefetch_input_vars.append(prefetch_input_vars) + self.all_in_ids_vars.append(ids_var) out_var = program.global_block().vars[out_name[0]] - prefetch_output_vars = self._create_splited_vars( - source_var=out_var, - block=program.global_block(), - tag="_prefetch_out_") - self.all_prefetch_output_vars.append(prefetch_output_vars) - - # insert split_ids_op - program.global_block()._insert_op( - index=lookup_table_op_index, - type="split_ids", - inputs={ - 'Ids': [ - program.global_block().vars[varname] - for varname in ids_name - ] - }, - outputs={"Out": prefetch_input_vars}) - - # insert prefetch_op - program.global_block()._insert_op( - index=lookup_table_op_index + 1, - type="prefetch", - inputs={'X': prefetch_input_vars}, - outputs={"Out": prefetch_output_vars}, - attrs={ - "epmap": pserver_endpoints, - # FIXME(qiao) temporarily disable this config because prefetch - # is not act as other rpc op, it's more like a forward op - # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE - }) - - # insert concat_op - program.global_block()._insert_op( - index=lookup_table_op_index + 2, - type="merge_ids", - inputs={ - 'Ids': [ - program.global_block().vars[varname] - for varname in ids_name - ], - 'X': prefetch_output_vars - }, - outputs={ - "Out": [ - program.global_block().vars[varname] - for varname in out_name - ] - }) + self.all_out_emb_vars.append(out_var) # delete lookup_table_op delete_ops(program.global_block(), [op]) # break for loop break + for index in range(len(self.pserver_endpoints)): + in_var = program.global_block().create_var( + name=str("prefetch_compress_in_tmp_" + str(index)), + type=self.all_in_ids_vars[0].type, + shape=self.all_in_ids_vars[0].shape, + dtype=self.all_in_ids_vars[0].dtype) + self.all_prefetch_input_vars.append(in_var) + + out_var = program.global_block().create_var( + name=str("prefetch_compress_out_tmp_" + str(index)), + type=self.all_out_emb_vars[0].type, + shape=self.all_out_emb_vars[0].shape, + dtype=self.all_out_emb_vars[0].dtype) + self.all_prefetch_output_vars.append(out_var) + + # insert split_ids_op + program.global_block()._insert_op( + index=lookup_table_op_index, + type="split_ids", + inputs={'Ids': self.all_in_ids_vars}, + outputs={"Out": self.all_prefetch_input_vars}) + + # insert prefetch_op + program.global_block()._insert_op( + index=lookup_table_op_index + 1, + type="prefetch", + inputs={'X': self.all_prefetch_input_vars}, + outputs={"Out": self.all_prefetch_output_vars}, + attrs={ + "epmap": pserver_endpoints, + # FIXME(qiao) temporarily disable this config because prefetch + # is not act as other rpc op, it's more like a forward op + # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + # insert concat_op + program.global_block()._insert_op( + index=lookup_table_op_index + 2, + type="merge_ids", + inputs={ + 'Ids': self.all_in_ids_vars, + 'Rows': self.all_prefetch_input_vars, + 'X': self.all_prefetch_output_vars + }, + outputs={"Out": self.all_out_emb_vars}) + def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): # 2. add split_ids_op and send_op to send gradient to pservers @@ -1134,7 +1146,8 @@ to transpile() call.") inputs={ 'Ids': [program.global_block().vars[table_grad_name]] }, - outputs={"Out": self.trainer_side_table_grad_list}) + outputs={"Out": self.trainer_side_table_grad_list}, + attrs={RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE}) program.global_block()._insert_op( index=op_index + 2, type="send", @@ -1160,32 +1173,31 @@ to transpile() call.") # STEP: create prefetch block table_var = pserver_program.global_block().vars[self.table_name] prefetch_var_name_to_block_id = [] - for index in range(len(self.all_prefetch_input_vars)): - prefetch_block = pserver_program._create_block(optimize_block.idx) - trainer_ids = self.all_prefetch_input_vars[index][pserver_index] - pserver_ids = pserver_program.global_block().create_var( - name=trainer_ids.name, - type=trainer_ids.type, - shape=trainer_ids.shape, - dtype=trainer_ids.dtype) - trainer_out = self.all_prefetch_output_vars[index][pserver_index] - pserver_out = pserver_program.global_block().create_var( - name=trainer_out.name, - type=trainer_out.type, - shape=trainer_out.shape, - dtype=trainer_out.dtype) - prefetch_block.append_op( - type="lookup_sparse_table", - inputs={'Ids': pserver_ids, - "W": table_var}, - outputs={"Out": pserver_out}, - attrs={ - "is_sparse": True, # has no effect on lookup_table op - "is_distributed": True, - "padding_idx": -1 - }) - prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str( - prefetch_block.idx)) + prefetch_block = pserver_program._create_block(optimize_block.idx) + trainer_ids = self.all_prefetch_input_vars[pserver_index] + pserver_ids = pserver_program.global_block().create_var( + name=trainer_ids.name, + type=trainer_ids.type, + shape=trainer_ids.shape, + dtype=trainer_ids.dtype) + trainer_out = self.all_prefetch_output_vars[pserver_index] + pserver_out = pserver_program.global_block().create_var( + name=trainer_out.name, + type=trainer_out.type, + shape=trainer_out.shape, + dtype=trainer_out.dtype) + prefetch_block.append_op( + type="lookup_sparse_table", + inputs={'Ids': pserver_ids, + "W": table_var}, + outputs={"Out": pserver_out}, + attrs={ + "is_sparse": True, # has no effect on lookup_table op + "is_distributed": True, + "padding_idx": -1 + }) + prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str( + prefetch_block.idx)) return prefetch_var_name_to_block_id def _create_table_optimize_block(self, pserver_index, pserver_program, @@ -1364,16 +1376,6 @@ to transpile() call.") program.global_block()._sync_with_cpp() return var_mapping - def _create_splited_vars(self, source_var, block, tag): - return [ - block.create_var( - name=str(source_var.name + tag + str(index)), - type=source_var.type, - shape=source_var.shape, - dtype=source_var.dtype) - for index in range(len(self.pserver_endpoints)) - ] - def _clone_var(self, block, var, persistable=True): return block.create_var( name=var.name, @@ -1431,7 +1433,7 @@ to transpile() call.") elif op_type == "adamax": if varkey in ["Moment", "InfNorm"]: return param_shape - elif op_type == "momentum": + elif op_type in ["momentum", "lars_momentum"]: if varkey == "Velocity": return param_shape elif op_type == "rmsprop": @@ -1442,6 +1444,10 @@ to transpile() call.") return param_shape elif op_type == "sgd": pass + else: + raise ValueError( + "Not supported optimizer for distributed training: %s" % + op_type) return orig_shape def _get_varname_parts(self, varname): diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 861bb5fae5d7a8561ded1f547fbb86ae1e1a073e..c9f1be934773cc28f026f2b867b9e3a4f7aa8472 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -171,7 +171,7 @@ class ControlFlowGraph(object): self._live_out[i] |= self._live_in[s] self._live_in[i] = self._uses[i] | ( self._live_out[i] - self._defs[i]) - if live_in[i] != self._live_in[i]: + if live_in[i] != set(self._live_in[i]): for d in self._presuccessors[i]: worklist.append(d) @@ -321,8 +321,7 @@ class ControlFlowGraph(object): if not compare_shape(x_shape, cache_shape, level): continue - # TODO(qijun): actually, we should compare - # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype] + # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype] if x_dtype != cache_dtype: continue @@ -487,7 +486,6 @@ def memory_optimize(input_program, skip_opt_set = grad_set else: skip_opt_set.update(grad_set) - cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index 5de6f966a038543ffffdf955251f587e3eb15cad..db6fe2d5fff4ed1617d793faee23f01395841768 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from plot import Ploter +from .plot import Ploter __all__ = ['dump_config', 'Ploter']