diff --git a/.gitignore b/.gitignore index 90138f996cf9cacc3c1cbff0cf2600eefca3f305..fa0c8882606b76ac71b43dcde7e1df6770c46c31 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ third_party/ build_* # clion workspace. cmake-build-* +model_test diff --git a/CMakeLists.txt b/CMakeLists.txt index 6aa2e1715b92d73aa4e5e97d5e52ffac51451d80..d3379a663db4613e529cdba4ce22111765ff59cc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) option(WITH_INFERENCE "Compile fluid inference library" ON) +option(ON_INFER "Turn on inference optimization." OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) @@ -179,6 +180,7 @@ include(external/eigen) # download eigen3 include(external/pybind11) # download pybind11 include(external/cares) include(external/cub) +include(external/xxhash) # download xxhash if (NOT WIN32) # there is no official support of snappystream, warpctc, nccl, cupti in windows @@ -301,3 +303,8 @@ if(WITH_DOC) find_python_module(recommonmark REQUIRED) add_subdirectory(doc) endif() + +if (ON_INFER) + message(WARNING "On inference mode, will take place some specific optimization.") + add_definitions(-DPADDLE_ON_INFERENCE) +endif() diff --git a/Dockerfile b/Dockerfile index 738bba9bc2e1ab19709722fe04f1490b1b13bd4f..c8b9eed6d60e5d3b32fc14c0c7af80a785145d1b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -75,14 +75,14 @@ RUN pip3 install -U wheel && \ pip3 install -U docopt PyYAML sphinx==1.5.6 && \ pip3 install sphinx-rtd-theme==0.1.9 recommonmark && \ easy_install -U pip && \ - pip install -U wheel && \ + pip install -U pip setuptools wheel && \ pip install -U docopt PyYAML sphinx==1.5.6 && \ pip install sphinx-rtd-theme==0.1.9 recommonmark -RUN pip3 install pre-commit 'ipython==5.3.0' && \ +RUN pip3 install 'pre-commit==1.10.4' 'ipython==5.3.0' && \ pip3 install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip3 install opencv-python && \ - pip install pre-commit 'ipython==5.3.0' && \ + pip install 'pre-commit==1.10.4' 'ipython==5.3.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install opencv-python diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py index 9540900b112f54594bbfdbc8d7cd3b6e1f5269dd..ff616ddbb2cb1cb7f348d6d164815823b08b7629 100644 --- a/benchmark/fluid/args.py +++ b/benchmark/fluid/args.py @@ -142,5 +142,10 @@ def parse_args(): choices=['reduce', 'all_reduce'], default='all_reduce', help='Specify the reduce strategy, can be reduce, all_reduce') + parser.add_argument( + '--fuse_broadcast_op', + action='store_true', + help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.' + ) args = parser.parse_args() return args diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py index ddd9fe809853a830ca676cc98f1819f683866def..5f3ce300acc44ad8d2898c27296b866c403f3cc8 100644 --- a/benchmark/fluid/fluid_benchmark.py +++ b/benchmark/fluid/fluid_benchmark.py @@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, else: build_strategy.reduce_strategy = fluid.BuildStrategy( ).ReduceStrategy.AllReduce + build_strategy.fuse_broadcast_op = args.fuse_broadcast_op avg_loss = train_args[0] @@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog, if args.use_fake_data or args.use_reader_op: try: - fetch_ret = exe.run(fetch_list) except fluid.core.EOFException as eof: break diff --git a/cmake/external/xxhash.cmake b/cmake/external/xxhash.cmake new file mode 100644 index 0000000000000000000000000000000000000000..4deaab7545c20002fedcad1cca6df54fe9783eb0 --- /dev/null +++ b/cmake/external/xxhash.cmake @@ -0,0 +1,46 @@ +INCLUDE(ExternalProject) + +set(XXHASH_SOURCE_DIR ${THIRD_PARTY_PATH}/xxhash) +set(XXHASH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/xxhash) +set(XXHASH_INCLUDE_DIR "${XXHASH_INSTALL_DIR}/include") + +IF(WITH_STATIC_LIB) + SET(BUILD_CMD make lib) +ELSE() + SET(BUILD_CMD sed -i "s/-Wstrict-prototypes -Wundef/-Wstrict-prototypes -Wundef -fPIC/g" ${XXHASH_SOURCE_DIR}/src/extern_xxhash/Makefile && make lib) +ENDIF() + +ExternalProject_Add( + extern_xxhash + ${EXTERNAL_PROJECT_LOG_ARGS} + GIT_REPOSITORY "https://github.com/Cyan4973/xxHash" + GIT_TAG "v0.6.5" + PREFIX ${XXHASH_SOURCE_DIR} + DOWNLOAD_NAME "xxhash" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_IN_SOURCE 1 + PATCH_COMMAND + BUILD_COMMAND ${BUILD_CMD} + INSTALL_COMMAND export PREFIX=${XXHASH_INSTALL_DIR}/ && make install + TEST_COMMAND "" +) + +set(XXHASH_LIBRARIES "${XXHASH_INSTALL_DIR}/lib/libxxhash.a") +INCLUDE_DIRECTORIES(${XXHASH_INCLUDE_DIR}) + +add_library(xxhash STATIC IMPORTED GLOBAL) +set_property(TARGET xxhash PROPERTY IMPORTED_LOCATION ${XXHASH_LIBRARIES}) +include_directories(${XXHASH_INCLUDE_DIR}) +add_dependencies(xxhash extern_xxhash) + +LIST(APPEND external_project_dependencies xxhash) + +IF(WITH_C_API) + INSTALL(DIRECTORY ${XXHASH_INCLUDE_DIR} DESTINATION third_party/xxhash) + IF(ANDROID) + INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib/${ANDROID_ABI}) + ELSE() + INSTALL(FILES ${XXHASH_LIBRARIES} DESTINATION third_party/xxhash/lib) + ENDIF() +ENDIF() diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 67cca09b64c1ed7a503a886e78347d786eae0de7..1047b6f998a74e42114b9deab4f0e7ba1af36835 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -14,6 +14,9 @@ # make package for paddle fluid shared and static library function(copy TARGET) + if (NOT ON_INFER) + message(WARNING "Turn on the ON_INFER flag when building inference_lib only.") + endif() set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DSTS DEPS) @@ -31,7 +34,7 @@ function(copy TARGET) foreach(index RANGE ${len}) list(GET copy_lib_SRCS ${index} src) list(GET copy_lib_DSTS ${index} dst) - add_custom_command(TARGET ${TARGET} PRE_BUILD + add_custom_command(TARGET ${TARGET} PRE_BUILD COMMAND mkdir -p "${dst}" COMMAND cp -r "${src}" "${dst}" COMMENT "copying ${src} -> ${dst}") @@ -67,6 +70,13 @@ copy(boost_lib DEPS boost ) +set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/xxhash") +copy(xxhash_lib + SRCS ${XXHASH_INCLUDE_DIR} ${XXHASH_LIBRARIES} + DSTS ${dst_dir} ${dst_dir}/lib + DEPS xxhash +) + if(NOT PROTOBUF_FOUND) set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/protobuf") copy(protobuf_lib @@ -186,7 +196,7 @@ copy(cmake_cache DSTS ${FLUID_INSTALL_DIR}) # This command generates a complete fluid library for both train and inference -add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) +add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep}) # Following commands generate a inference-only fluid library # third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR} diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 19ef23cdfa90912ff6fbd050a685d10861d909d2..2b8b82e74fc49d454b5331460acbffd0e9404fb5 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -86,7 +86,7 @@ paddle.fluid.layers.reduce_prod ArgSpec(args=['input', 'dim', 'keep_dim', 'name' paddle.fluid.layers.sequence_first_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_last_step ArgSpec(args=['input'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.sequence_slice ArgSpec(args=['input', 'offset', 'length', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name'], varargs=None, keywords=None, defaults=(False, None, None)) +paddle.fluid.layers.dropout ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed', 'name', 'dropout_implementation'], varargs=None, keywords=None, defaults=(False, None, None, 'downgrade_in_infer')) paddle.fluid.layers.split ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)) paddle.fluid.layers.ctc_greedy_decoder ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)) @@ -107,7 +107,7 @@ paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)) -paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None)) +paddle.fluid.layers.reshape ArgSpec(args=['x', 'shape', 'actual_shape', 'act', 'inplace', 'name'], varargs=None, keywords=None, defaults=(None, None, False, None)) paddle.fluid.layers.squeeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.unsqueeze ArgSpec(args=['input', 'axes', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lod_reset ArgSpec(args=['x', 'y', 'target_lod'], varargs=None, keywords=None, defaults=(None, None)) @@ -174,7 +174,9 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) +paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) @@ -353,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None) +paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None)) +paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)) paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,)) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index e0a3ef5a9c6c53c42ebea1a41cac0d18a77781b2..17188ac5f301102ae79c6ace676b84ee66e28801 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -16,12 +16,14 @@ if(WITH_GPU) dynload_cuda variable_visitor) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) + nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) + cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) endif() cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) @@ -34,7 +36,7 @@ if(WITH_GPU) endif() cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) if(WITH_GPU) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass) @@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo cc_library(build_strategy SRCS build_strategy.cc DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass - fuse_elewise_add_act_pass) + fuse_elewise_add_act_pass multi_batch_merge_pass) diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 4fdab5cd94358d08eac7f8b041bf16d09042f0bd..5b5a10e22776bee5c61a55c163c1732692551e36 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -48,16 +48,23 @@ void BroadcastOpHandle::RunImpl() { var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); } + BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes); +} + +void BroadcastOpHandle::BroadcastOneVar( + const VarHandle &in_var_handle, + const std::vector &out_var_handles, + const std::vector &var_scopes) { auto *in_var = - var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_); + var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); PADDLE_ENFORCE_NOT_NULL(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); - InitOutputValue(*in_var_handle, out_var_handles); + InitOutputValue(in_var_handle, out_var_handles); if (platform::is_cpu_place(in_tensor.place())) { for (auto *out_var_handle : out_var_handles) { - if (out_var_handle->IsTheSameVar(*in_var_handle)) { + if (out_var_handle->IsTheSameVar(in_var_handle)) { continue; } auto &out_p = out_var_handle->place_; @@ -114,12 +121,12 @@ void BroadcastOpHandle::RunImpl() { } } - if (!out_handle->IsTheSameVar(*in_var_handle)) { - auto out_var = var_scopes.at(in_var_handle->scope_idx_) + if (!out_handle->IsTheSameVar(in_var_handle)) { + auto out_var = var_scopes.at(in_var_handle.scope_idx_) ->FindVar(out_var_handles[0]->name_); paddle::framework::TensorCopy( - in_tensor, in_var_handle->place_, - *(dev_ctxes_.at(in_var_handle->place_)), + in_tensor, in_var_handle.place_, + *(dev_ctxes_.at(in_var_handle.place_)), &VariableVisitor::GetMutableTensor(out_var)); } }); diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index fe4e733e43417977df324fde808f52b228a27d19..020d351e891c7afab37c59c0ff8d8e5e7ba184f2 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; - private: + void BroadcastOneVar(const VarHandle &in_var_handle, + const std::vector &out_var_handles, + const std::vector &var_scopes); + std::vector local_scopes_; std::vector places_; #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 6a6b497fa897e3882995688bf36704b1d77ea962..fefd27fc86fb8dce3311fa580d90f518906dd862 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -121,6 +121,7 @@ std::unique_ptr BuildStrategy::Apply( USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); +USE_PASS(multi_batch_merge_pass); USE_PASS(multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 02c4bea16916d58a6d0fce8918f8fceb9ff9356e..f3ffaf6ecd7c4dd99c40fe58ba88c0cbdc14bde7 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -69,6 +69,8 @@ struct BuildStrategy { bool enable_data_balance_{false}; + bool fuse_broadcast_op_{false}; + // User normally doesn't need to call this API. // The PassBuilder allows for more customized insert, remove of passes // from python side. diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..51dfa2d0711f49aaefab0af3549283dbf77eee4a --- /dev/null +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace framework { +namespace details { + +void FusedBroadcastOpHandle::RunImpl() { + platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second); + + if (places_.size() == 1UL) return; + + auto in_var_handles = DynamicCast(inputs_); + auto out_var_handles = DynamicCast(outputs_); + + WaitInputVarGenerated(); + + std::vector var_scopes; + for (auto *s : local_scopes_) { + var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); + } + + size_t place_num = places_.size(); + PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size()); + + for (size_t i = 0; i < in_var_handles.size(); ++i) { + BroadcastOneVar( + *in_var_handles[i], + std::vector(out_var_handles.begin() + i * place_num, + out_var_handles.begin() + (i + 1) * place_num), + var_scopes); + } +} + +std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; } + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.h b/paddle/fluid/framework/details/fused_broadcast_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..e37259526a5f6f57d51a0ca8bca96a18211a4790 --- /dev/null +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.h @@ -0,0 +1,57 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/details/broadcast_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/device_context.h" + +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +struct FusedBroadcastOpHandle : public BroadcastOpHandle { + public: +#ifdef PADDLE_WITH_CUDA + FusedBroadcastOpHandle(ir::Node *node, + const std::vector local_scopes, + const std::vector &places, + const platform::NCCLContextMap *nccl_ctx) + : BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {} +#else + FusedBroadcastOpHandle(ir::Node* node, const std::vector local_scopes, + const std::vector& places) + : BroadcastOpHandle(node, local_scopes, places) {} +#endif + std::string Name() const override; + + protected: + void RunImpl() override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 134fcee826715672a6e021e9bf694bb771ebb830..f2d5b182e577714d6138e99932af637a711cc9f2 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" +#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" @@ -252,9 +253,9 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::vector sorted_ret; for (size_t i = 0; i < ret.size(); ++i) { if (i < last_backward) { - if (boost::get(ret[i]->Op()->GetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kOptimize)) { + if (static_cast(boost::get(ret[i]->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kOptimize))) { optimize_ops.push_back(ret[i]); } else { sorted_ret.push_back(ret[i]); @@ -347,7 +348,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( BuildStrategy::GradientScaleStrategy::kCustomized) { // TODO(paddle-dev): Why is there no input for this op_handle? auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; - CreateScaleLossGradOp(&result, loss_grad_name); + CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]); } // This assumes the backward generating code will ensure IsScaleLossOp // is true only for the op that scale the final scalar loss. @@ -436,10 +437,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( if ((use_gpu && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) || is_dist_train) { - for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { - auto &to_bcast_set = bcast_var_name_set[dev_id]; - for (auto &bcast_name : to_bcast_set) { - CreateBroadcastOp(&result, bcast_name, dev_id); + if (strategy_.fuse_broadcast_op_) { + CreateFusedBroadcastOp(&result, bcast_var_name_set); + } else { + for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { + auto &to_bcast_set = bcast_var_name_set[dev_id]; + for (auto &bcast_name : to_bcast_set) { + CreateBroadcastOp(&result, bcast_name, dev_id); + } } } } @@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, } } +void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp( + ir::Graph *result, + const std::vector> &bcast_varnames) const { +#ifdef PADDLE_WITH_CUDA + auto *op_handle = new FusedBroadcastOpHandle( + result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), + local_scopes_, places_, nccl_ctxs_); +#else + auto *op_handle = new FusedBroadcastOpHandle( + result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation), + local_scopes_, places_); +#endif + result->Get(kGraphOps).emplace_back(op_handle); + + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + SetCommunicationContext(op_handle, p); + } + + for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) { + for (auto &p_name : bcast_varnames[dev_id]) { + auto *in = + result->Get(kGraphVars).at(dev_id).at(p_name).back().get(); + op_handle->AddInput(in); + for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) { + auto &p = places_[out_dev_id]; + auto &vars = + result->Get(kGraphVars).at(out_dev_id).at(p_name); + auto *out_var = new VarHandle( + result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), + vars.size(), out_dev_id, p_name, p); + vars.emplace_back(out_var); + op_handle->AddOutput(out_var); + } + } + } +} + void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const { @@ -602,7 +645,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( - ir::Graph *result, const std::string &loss_grad_name) const { + ir::Graph *result, const std::string &loss_grad_name, + ir::Node *out_var_node) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); @@ -617,10 +661,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); - CreateOpOutput( - result, op_handle, - result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable), - places_[i], i); + CreateOpOutput(result, op_handle, + result->CreateVarNode(out_var_node->Var()), places_[i], i); } } diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index cdf9f13cde608b546d17a1e53e0f6acea9e12566..03b2de2f04da4bac8d342a76c80fd12beaeba4b7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -61,7 +61,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass { size_t num_places) const; void CreateScaleLossGradOp(ir::Graph *result, - const std::string &loss_grad_name) const; + const std::string &loss_grad_name, + ir::Node *out_var_node) const; VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const; @@ -78,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass { void CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const; + void CreateFusedBroadcastOp( + ir::Graph *result, + const std::vector> &bcast_varnames) const; + bool IsSparseGradient(const std::string &og) const; size_t GetAppropriateDeviceID( diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 25f0ba418433571343c5b2bbfdbf9fb940eaec52..c99406799ba5f664c4b9f80e0567b293e4ffea51 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -80,7 +80,6 @@ message OpProto { optional bool duplicable = 3 [ default = false ]; optional bool intermediate = 4 [ default = false ]; optional bool dispensable = 5 [ default = false ]; - optional string reuse = 6; } // AttrProto describes the C++ type Attribute. diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 3aa2c7b9ea013dd977ef0051700df54e26a81307..ce006b7a3fbc16f3c9149933390969b14a46b484 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -36,18 +36,17 @@ pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) +pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base) pass_library(conv_bias_mkldnn_fuse_pass inference) pass_library(conv_relu_mkldnn_fuse_pass inference) + pass_library(conv_elementwise_add_mkldnn_fuse_pass inference) endif() cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) -if(WITH_MKLDNN) - pass_library(conv_elementwise_add_mkldnn_fuse_pass inference) -endif() set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 398f7095968e62f92d610f560d7574b27706d13e..265a128e95e6205159de67d87d5cb8ca1ad7d475 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -24,79 +24,23 @@ namespace paddle { namespace framework { namespace ir { -std::vector FindDistTrainSendVars( - const std::vector &nodes) { - std::vector send_vars; - // since parameters are all in block 0, - // it's enough to only scan send ops in block 0 - for (auto &node : nodes) { - auto op_vars = node->Op()->InputArgumentNames(); - send_vars.reserve(send_vars.size() + - std::distance(op_vars.begin(), op_vars.end())); - send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); - } - return send_vars; -} - -std::vector FindDistTrainRecvVars( - const std::vector &nodes) { - std::vector recv_vars; - for (auto &node : nodes) { - auto op_vars = node->Op()->OutputArgumentNames(); - recv_vars.reserve(recv_vars.size() + - std::distance(op_vars.begin(), op_vars.end())); - recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); - } - return recv_vars; -} - -bool IsDistTrainOp(ir::Node *node, const std::vector &send_vars, - const std::vector &recv_vars) { - if (send_vars.size() == 0 || recv_vars.size() == 0) { - return false; - } - - /** - * Check any of opvars contains `.block` and in sendvars - */ - auto checker = [](const std::vector &opvars, - const std::vector &rpc_vars) -> bool { - for (auto &var : opvars) { - // a variable name with the suffix `.block` means it's a splited - // variable by (DistributeTranspiler) - // [python/paddle/fluid/transpiler/distribute_transpiler.py] - if (var.find(".block") != std::string::npos && - std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { - return true; - } - } - return false; - }; - - std::vector input_var_names; - std::vector output_var_names; - for (ir::Node *input : node->inputs) { - input_var_names.push_back(input->Name()); - } - for (ir::Node *output : node->outputs) { - output_var_names.push_back(output->Name()); - } - - return checker(output_var_names, send_vars) || - checker(input_var_names, recv_vars); -} - Graph::Graph(const ProgramDesc &program) : program_(program) { // Make the nodes id start from 0. Node::ResetId(); + auto var_nodes = InitFromProgram(program_); + ResolveHazard(var_nodes); +} +std::map> Graph::InitFromProgram( + const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); std::unordered_map all_vars; + // var nodes for each var name, will have multiple versions in SSA + std::map> var_nodes; for (auto *var : program.Block(0).AllVars()) { all_vars.emplace(var->Name(), var); } - std::map> var_nodes; for (auto *op : program.Block(0).AllOps()) { ir::Node *node = CreateOpNode(op); // For input args, reuse the same var name if it was created before. @@ -134,7 +78,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var->inputs.push_back(node); } } + return std::move(var_nodes); +} +void Graph::ResolveHazard( + const std::map> &var_nodes) { /** * We should handle write after read(WAR) and write after write(WAW) here. * Because some of the operators of the program can be executed parallelly. @@ -153,6 +101,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { auto it_old = versions.rbegin(); ++it_old; for (; it_old != versions.rend(); it_new = it_old, ++it_old) { + VLOG(3) << "deal with var: " << (*it_new)->Name(); ir::Node *write_op = (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0]; const auto &read_ops = (*it_old)->outputs; diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index ab687e760a761d4e445726bd5149966adc2403d0..9d7aa5d32deb274fbf29481b0d4754c05d1e21b5 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -160,6 +160,12 @@ class Graph { return nullptr; } + std::map> InitFromProgram( + const ProgramDesc &program); + + void ResolveHazard( + const std::map> &var_nodes); + private: // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index c54766d95a61ac1a4b61566c6de62cbc86685a1d..01e878089171e4620f32b57a65d92d1c86d307db 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -120,19 +120,25 @@ size_t GraphNum(const Graph &graph) { std::deque q_nodes; std::vector> graph_nodes; std::unordered_set g_nodes; + // q_set used to record records in the queue. + std::unordered_set q_set; size_t graph_count = 0; - auto traverse_nodes = [&visited_nodes, - &q_nodes](const std::vector &nodes) { - std::copy_if( - nodes.begin(), nodes.end(), std::back_inserter(q_nodes), - [&visited_nodes](Node *node) { return !visited_nodes.count(node); }); + auto traverse_nodes = [&visited_nodes, &q_nodes, + &q_set](const std::vector &nodes) { + for (auto n : nodes) { + if (visited_nodes.count(n) == 0 && q_set.count(n) == 0) { + q_nodes.push_back(n); + q_set.insert(n); + } + } }; while (visited_nodes.size() != nodes.size()) { if (!q_nodes.empty()) { auto cur_node = q_nodes.front(); q_nodes.pop_front(); + q_set.erase(cur_node); visited_nodes.insert(cur_node); g_nodes.insert(cur_node); traverse_nodes(cur_node->inputs); @@ -146,6 +152,7 @@ size_t GraphNum(const Graph &graph) { for (auto &n : nodes) { if (visited_nodes.count(n) == 0) { q_nodes.push_back(n); + q_set.insert(n); break; } } diff --git a/paddle/fluid/framework/ir/graph_helper_test.cc b/paddle/fluid/framework/ir/graph_helper_test.cc index cea902809339f9d45b0e2525163f08a3c1c44c95..260a73ae763bd2cdea9948e4d928377a7c718dda 100644 --- a/paddle/fluid/framework/ir/graph_helper_test.cc +++ b/paddle/fluid/framework/ir/graph_helper_test.cc @@ -200,15 +200,15 @@ TEST(GraphHelperTest, GraphNum) { Graph g(prog); BuildZeroGraph(&g); - ASSERT_EQ(GraphNum(g), 0); + ASSERT_EQ(GraphNum(g), 0UL); Graph g2(prog); BuildOneGraph(&g2); - ASSERT_EQ(GraphNum(g2), 1); + ASSERT_EQ(GraphNum(g2), 1UL); Graph g3(prog); BuildTwoGraphs(&g3); - ASSERT_EQ(GraphNum(g3), 2); + ASSERT_EQ(GraphNum(g3), 2UL); } } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index cadda49c399a6d65079cacedfea61f4fd580a69a..7ed2f96eb24239d87965192d73f4ba200ff5dbeb 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -124,7 +124,7 @@ TEST(GraphTest, Basic) { ASSERT_EQ(n->outputs.size(), 0UL); } } - ASSERT_EQ(nodes.size(), 5); + ASSERT_EQ(nodes.size(), 5UL); } TEST(GraphTest, WriteAfterRead) { diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd5b76426eb55cebdabfccd700439a4c418a10f0 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -0,0 +1,315 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/multi_batch_merge_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +static const char kNumRepeats[] = "num_repeats"; +typedef std::unordered_map> SSAVarList; + +ir::Node* SameNameVar(std::unordered_set all, ir::Node* target) { + for (auto n : all) { + if (target->IsVar() && target->Name() == n->Name()) { + return n; + } + } + return nullptr; +} + +VarDesc CopyVarDesc(VarDesc* var_desc) { + VarDesc repeated_var(var_desc->Name()); + // copy other variable attributes + if (var_desc->GetType() != proto::VarType::READER) { + repeated_var.SetType(var_desc->GetType()); + repeated_var.SetShape(var_desc->GetShape()); + repeated_var.SetDataType(var_desc->GetDataType()); + repeated_var.SetLoDLevel(var_desc->GetLoDLevel()); + repeated_var.SetPersistable(var_desc->Persistable()); + } else { + // TODO(typhoonzero): copy reader var + } + return repeated_var; +} + +VarDesc UpdateGradVarDesc( + VarDesc* var_desc, int repeat, + const std::unordered_set& grad_names, + const std::unordered_set& bn_vars_need_rename) { + if (grad_names.find(var_desc->Name()) != grad_names.end() || + bn_vars_need_rename.find(var_desc->Name()) != bn_vars_need_rename.end()) { + std::string new_gname = + string::Sprintf("%s.repeat.%d", var_desc->Name(), repeat); + VarDesc repeated_var = CopyVarDesc(var_desc); + repeated_var.SetName(new_gname); + VLOG(3) << "update " << var_desc->Name() << " to repeat " << repeat; + return repeated_var; + } + return *var_desc; +} + +std::unique_ptr BatchMergePass::ApplyImpl( + std::unique_ptr graph) const { + int num_repeats = Get(kNumRepeats); + std::vector forward_backward_ops; + std::vector optimize_ops; + std::vector lr_ops; // ops other than forward/backward/optimize + std::unordered_set grad_names; + + std::vector nodes = TopologySortOperations(*graph); + auto origin_nodes = graph->ReleaseNodes(); + VLOG(3) << "origin nodes count: " << origin_nodes.size(); + ir::Graph& result = *graph; + + // 1. record op nodes of different roles + for (auto node : nodes) { + if (node->IsVar()) continue; + int op_role = boost::get(node->Op()->GetAttr( + framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + if ((op_role == static_cast(framework::OpRole::kForward)) || + (op_role & static_cast(framework::OpRole::kBackward)) || + (op_role & static_cast(framework::OpRole::kLoss))) { + forward_backward_ops.push_back(node); + } else if ((op_role & static_cast(framework::OpRole::kOptimize)) || + (op_role & static_cast(framework::OpRole::kDist)) || + (op_role & static_cast(framework::OpRole::kRPC))) { + optimize_ops.push_back(node); + auto op_role_var = node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName()); + auto op_role_vars = boost::get>(op_role_var); + for (size_t i = 0; i < op_role_vars.size(); i += 2) { + grad_names.insert(op_role_vars[i + 1]); + } + } else if (op_role & static_cast(framework::OpRole::kLRSched)) { + lr_ops.push_back(node); + } else { // NOLINT + PADDLE_THROW("Invalid op_role: %d", static_cast(op_role)); + } + } + + // 2. copy forward backward + ir::Node* prev_repeat_last_op_node = nullptr; + // record origin_grad -> repeated grad list map. + std::map> grad_repeated_map; + std::map> created; + std::unordered_set bn_vars_need_rename; + for (int i = 0; i < num_repeats; ++i) { + std::unordered_set copied; + for (size_t node_idx = 0; node_idx < forward_backward_ops.size(); + ++node_idx) { + auto node = forward_backward_ops[node_idx]; + OpDesc repeated_op(*(node->Op()), node->Op()->Block()); + // 3. rename grad outputs to current repeat. + for (auto outname : repeated_op.OutputArgumentNames()) { + if (grad_names.find(outname) != grad_names.end()) { + std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i); + repeated_op.RenameOutput(outname, new_gname); + } + } + // 3.5 let batch_norm ops use independent vars, note batch_norm_grad do + // not need this update + if (node->Name() == "batch_norm") { + // NOTE: assume bn op created by layers use save var as output mean and + // variance + std::string new_mean_name = + string::Sprintf("%s.repeat.%d", repeated_op.Input("Mean")[0], i); + std::string new_var_name = string::Sprintf( + "%s.repeat.%d", repeated_op.Input("Variance")[0], i); + bn_vars_need_rename.insert(repeated_op.Input("Mean")[0]); + bn_vars_need_rename.insert(repeated_op.Input("Variance")[0]); + VLOG(3) << "renaming " << repeated_op.Input("Mean")[0] << " to " + << new_mean_name; + repeated_op.RenameInput(repeated_op.Input("Mean")[0], new_mean_name); + repeated_op.RenameInput(repeated_op.Input("Variance")[0], new_var_name); + repeated_op.RenameOutput(repeated_op.Output("MeanOut")[0], + new_mean_name); + repeated_op.RenameOutput(repeated_op.Output("VarianceOut")[0], + new_var_name); + } + + // 3.9 do copy + auto repeated_node = result.CreateOpNode(&repeated_op); + copied.insert(node); + + // 4. add deps between repeats + if (node_idx == forward_backward_ops.size() - 1) { + prev_repeat_last_op_node = repeated_node; + } + if (node_idx == 0 && prev_repeat_last_op_node) { + auto* depvar = result.CreateControlDepVar(); + prev_repeat_last_op_node->outputs.push_back(depvar); + depvar->inputs.push_back(prev_repeat_last_op_node); + repeated_node->inputs.push_back(depvar); + depvar->outputs.push_back(repeated_node); + } + + for (auto in_node : node->inputs) { + if (in_node->IsCtrlVar()) { + continue; + } + ir::Node* var = nullptr; + auto updated_var = UpdateGradVarDesc(in_node->Var(), i, grad_names, + bn_vars_need_rename); + // should be initialized by startup, how to initilize tensor in the + // scope? + if (node->Name() == "batch_norm" && + bn_vars_need_rename.find(in_node->Name()) != + bn_vars_need_rename.end()) { + // Create bn mean/variance for each repeat + var = result.CreateVarNode(&updated_var); + created[updated_var.Name()].push_back(var); + copied.insert(in_node); + repeated_node->inputs.push_back(var); + var->outputs.push_back(repeated_node); + continue; + } + + // for other ops + if (in_node->inputs.empty() && i > 0) { + // do not copy head vars (inputs, params) in repeats > 0 + var = created.at(in_node->Name()).back(); + } else { + if (copied.find(in_node) == copied.end()) { + var = result.CreateVarNode(&updated_var); + if (grad_names.find(in_node->Var()->Name()) != grad_names.end()) { + grad_repeated_map[in_node].push_back(var); + } + copied.insert(in_node); + created[updated_var.Name()].push_back(var); + } else { + var = created.at(updated_var.Name()).back(); + } + } + repeated_node->inputs.push_back(var); + var->outputs.push_back(repeated_node); + } + for (auto out_node : node->outputs) { + if (out_node->IsCtrlVar()) { + continue; + } + ir::Node* var = nullptr; + auto updated_var = UpdateGradVarDesc(out_node->Var(), i, grad_names, + bn_vars_need_rename); + if (copied.find(out_node) == copied.end()) { + var = result.CreateVarNode(&updated_var); + if (grad_names.find(out_node->Var()->Name()) != grad_names.end()) { + grad_repeated_map[out_node].push_back(var); + } + copied.insert(out_node); + created[updated_var.Name()].push_back(var); + } else { + var = created.at(updated_var.Name()).back(); + } + repeated_node->outputs.push_back(var); + var->inputs.push_back(repeated_node); + } + } + } + + // 5. create GRAD merge op node + for (auto kv : grad_repeated_map) { + OpDesc sum_op; + sum_op.SetType("sum"); + std::vector repeated_grad_names; + for (auto r : kv.second) { + repeated_grad_names.push_back(r->Var()->Name()); + } + sum_op.SetInput("X", repeated_grad_names); + sum_op.SetOutput("Out", {kv.first->Var()->Name()}); + sum_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kBackward)); + auto sum_op_node = result.CreateOpNode(&sum_op); + for (auto r : kv.second) { + sum_op_node->inputs.push_back(r); + r->outputs.push_back(sum_op_node); + } + auto sum_out_var_node = result.CreateVarNode(kv.first->Var()); + sum_op_node->outputs.push_back(sum_out_var_node); + sum_out_var_node->inputs.push_back(sum_op_node); + created[sum_out_var_node->Name()].push_back(sum_out_var_node); + + OpDesc scale_op; + scale_op.SetType("scale"); + scale_op.SetInput("X", {sum_out_var_node->Var()->Name()}); + // NOTE: inplace scale. + scale_op.SetOutput("Out", {sum_out_var_node->Var()->Name()}); + scale_op.SetAttr("scale", static_cast(1.0f / num_repeats)); + scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kBackward)); + auto scale_op_node = result.CreateOpNode(&scale_op); + scale_op_node->inputs.push_back(sum_out_var_node); + sum_out_var_node->outputs.push_back(scale_op_node); + auto scale_out_var_node = result.CreateVarNode(sum_out_var_node->Var()); + scale_op_node->outputs.push_back(scale_out_var_node); + scale_out_var_node->inputs.push_back(scale_op_node); + created[scale_out_var_node->Name()].push_back(scale_out_var_node); + } + // 6. add optimize ops + { + auto copy_node = [&result, &created](ir::Node* node) { + auto op_node = result.CreateOpNode(node->Op()); + // copy op ins/outs + // NOTE: for send/recv ops, the OpDesc uses ctrldepvar to describe + // dependencies, so create those depvars if OpDesc have in/outs. + for (auto in_node : node->inputs) { + if (in_node->IsCtrlVar() && !in_node->Var()) { + continue; + } + ir::Node* var = nullptr; + if (created.find(in_node->Name()) == created.end()) { + var = result.CreateVarNode(in_node->Var()); + created[in_node->Name()].push_back(var); + } else { + var = created.at(in_node->Name()).back(); + } + op_node->inputs.push_back(var); + var->outputs.push_back(op_node); + } + for (auto out_node : node->outputs) { + if (out_node->IsCtrlVar() && !out_node->Var()) { + continue; + } + auto var = result.CreateVarNode(out_node->Var()); + created[out_node->Name()].push_back(var); + op_node->outputs.push_back(var); + var->inputs.push_back(op_node); + } + }; + for (auto node : lr_ops) { + copy_node(node); + } + for (auto node : optimize_ops) { + copy_node(node); + } + } + + result.ResolveHazard(created); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(multi_batch_merge_pass, paddle::framework::ir::BatchMergePass) + .RequirePassAttr(paddle::framework::ir::kNumRepeats); diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.h b/paddle/fluid/framework/ir/multi_batch_merge_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..c1e5aef20dbc60c18ed03038818bfd8ab217bf28 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +// BatchMergePass is used to copy forward and backward ops for several +// times to run several batches to simulate large batch size training +// as if we have more than 1 GPUs. +// User can define how many batches to run, gradients will be merged +// through those repeats, and then do optimization using merged gradients. +// This pass is extremely useful when doing large batch-size distributed +// sync training, we can simulate even large batch size as if we have more +// GPUs. + +class BatchMergePass : public Pass { + public: + virtual ~BatchMergePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 5d6da9f1d76a3c0fc64b7ff35264e385cf19a14b..d6d42f5e92080aa57445e2d6ce59aa3faa89d22d 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -44,6 +44,7 @@ class Node { return op_desc_.get(); } + // Please don't use this API! int id() const { return id_; } bool IsOp() const { return type_ == Type::kOperation; } @@ -92,6 +93,7 @@ class Node { Node() = delete; static int count_; + // Please don't use this API or make this public. static void ResetId() { count_ = 0; } DISABLE_COPY_AND_ASSIGN(Node); }; diff --git a/paddle/fluid/framework/lod_tensor_array.h b/paddle/fluid/framework/lod_tensor_array.h index 6d7b6a4ada8729e3698dab5d2b1861aac632be79..0ad6a709008406257d6c0a220bce38bb24e188cd 100644 --- a/paddle/fluid/framework/lod_tensor_array.h +++ b/paddle/fluid/framework/lod_tensor_array.h @@ -18,6 +18,82 @@ limitations under the License. */ namespace paddle { namespace framework { + +// NOTE The vector can't be replaced with the class LoDTensorArray +// directly, because there are many vector used accross the project, +// and some of them are treated as LoDTensorArray. +#if !defined(PADDLE_ON_INFERENCE) + using LoDTensorArray = std::vector; -} + +#else // !PADDLE_ON_INFERENCE + +#pragma message "LoDTensorArray is replaced with the inference one." +/* + * A LoDTensorArray which will not deallocate buffer when resized, fix the data + * diff in inference, and more performance friendly in the concurrency + * scenerios. + */ +class LoDTensorArray { + public: + LoDTensorArray() = default; + + using iterator = std::vector::iterator; + using const_iterator = std::vector::const_iterator; + + const_iterator begin() const { return array_.begin(); } + const_iterator end() const { return array_.begin() + size_; } + iterator begin() { return array_.begin(); } + iterator end() { return array_.begin() + size_; } + + void push_back(const LoDTensor& x) { + if (size_ < array_.size()) { + array_[size_++] = x; + } else { + array_.push_back(x); + ++size_; + } + } + void resize(size_t size) { + if (array_.size() < size) { + array_.resize(size); + } + size_ = size; + } + + void emplace_back() { array_.emplace_back(); } + + void emplace_back(LoDTensor&& x) { array_.emplace_back(std::move(x)); } + + LoDTensor& back() { return array_.back(); } + + size_t space() const { return array_.size(); } + + void reserve(size_t size) { + // Naive warning to tell user this array might be to large. The memory and + // buffer used by this TensorArray will not be deleted during the training + // and inference phase, so attention not to make it expand too long. + if (size > 800UL) { + LOG(WARNING) << "TensorArray has more than 800 items"; + } + array_.reserve(size); + } + + bool empty() const { return size_ == 0UL; } + void clear() { size_ = 0UL; } + + LoDTensor& operator[](size_t id) { return array_[id]; } + const LoDTensor& operator[](size_t id) const { return array_[id]; } + LoDTensor& at(size_t id) { return array_.at(id); } + const LoDTensor& at(size_t id) const { return array_.at(id); } + + size_t size() const { return size_; } + + private: + size_t size_{0}; + std::vector array_; +}; +#endif // !PADDLE_ON_INFERENCE + +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/mixed_vector.h b/paddle/fluid/framework/mixed_vector.h index 77386f4f069489b6ff7b927a281bdc286ff816e0..e1aac6dc5a92fb616f00de5806f044b83c2f503f 100644 --- a/paddle/fluid/framework/mixed_vector.h +++ b/paddle/fluid/framework/mixed_vector.h @@ -542,6 +542,33 @@ class CPUVector : public std::vector> { this->reserve(this->size() + size_t(end - begin)); this->insert(this->end(), begin, end); } + + const T *CUDAData(platform::Place place) const { + PADDLE_THROW( + "Vector::CUDAData() method is not supported in CPU-only version"); + } + + T *CUDAMutableData(platform::Place place) { + PADDLE_THROW( + "Vector::CUDAMutableData() method is not supported in CPU-only " + "version"); + } + + const T *Data(platform::Place place) const { + PADDLE_ENFORCE( + platform::is_cpu_place(place), + "Vector::Data() method is not supported when not in CPUPlace"); + return this->data(); + } + + T *MutableData(platform::Place place) { + PADDLE_ENFORCE( + platform::is_cpu_place(place), + "Vector::MutableData() method is not supported when not in CPUPlace"); + return this->data(); + } + + const void *Handle() const { return static_cast(this); } }; template diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 2840d503f1454271afb309efdd435225ab077dc0..7fb42feb95b4d54aec693228721c052f683f4d80 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -146,22 +146,5 @@ void NaiveExecutor::CleanFeedFetchOps() { ops_.swap(ops); } -void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) { -#ifdef PADDLE_WITH_MKLDNN - VLOG(3) << "use_mkldnn=True"; - for (size_t block_id = 0; block_id < program.Size(); ++block_id) { - auto *block = const_cast(program).MutableBlock(block_id); - for (auto *op : block->AllOps()) { - if (op->HasAttr("use_mkldnn")) { - op->SetAttr("use_mkldnn", true); - } - } - } -#else - LOG(WARNING) - << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; -#endif -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 9374f3f4a35cc0f90e5b2d6e8b397784b8eae123..ddfa6e1f4d8b73f594fc381ab505797491cdd378 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -48,8 +48,6 @@ class NaiveExecutor { void CleanFeedFetchOps(); - void EnableMKLDNN(const ProgramDesc& program); - protected: void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 440e0509be727ec2b84abc76fca44edda11f8a0a..30c8a26c3d2f0068674aa70b4ff875a2f73c1dca 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -121,10 +121,6 @@ class OpDesc { BlockDesc *Block() { return this->block_; } - const BlockDesc &BlockRef() const { return *this->block_; } - - void SetBlock(BlockDesc *block) { this->block_ = block; } - private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index df2a7a27ca4a6011b214202ac9bf4f30dc482ece..ca31303f77c4a30eb64c43404e214779ea78aeaf 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -21,7 +21,6 @@ namespace framework { void OpProtoAndCheckerMaker::Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); - CheckReuseVars(); } OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( @@ -40,40 +39,6 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( return OpProtoAndCheckerMaker::VariableBuilder{output}; } -void OpProtoAndCheckerMaker::Reuse(const std::string& name, - const std::string& reused_name) { - bool found = false; - proto::OpProto::Var* var; - - for (auto& var : proto_->inputs()) { - if (var.name() == reused_name) { - found = true; - break; - } - } - PADDLE_ENFORCE(found == true, - "Input/Output name: %s reused_name: %s, one of them is not " - "exists or not matched.", - name, reused_name); - - found = false; - for (int i = 0; i < proto_->outputs().size(); ++i) { - var = proto_->mutable_outputs()->Mutable(i); - if (var->name() == name) { - PADDLE_ENFORCE(!var->has_reuse(), - "Output(%s) has been set reused var of %s", name, - var->reuse()); - found = true; - var->set_reuse(reused_name); - break; - } - } - PADDLE_ENFORCE(found == true, - "Input/Output name: %s reused_name: %s, one of them is not " - "exists or not matched.", - name, reused_name); -} - void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { std::unordered_set names; auto checker = [&](const std::string& name) { @@ -91,24 +56,6 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { } } -void OpProtoAndCheckerMaker::CheckReuseVars() { - std::unordered_set names; - for (auto& input : proto_->inputs()) { - names.insert(input.name()); - } - auto checker = [&](const std::string& name, const std::string& reused) { - PADDLE_ENFORCE( - names.count(reused), - "Output [%s] reuse Input [%s], but the input is not registered.", name, - reused); - }; - for (auto& output : proto_->outputs()) { - if (output.has_reuse()) { - checker(output.name(), output.reuse()); - } - } -} - void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, OpAttrChecker* attr_checker) { proto_ = proto; @@ -124,6 +71,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, static_cast(OpRole::kLoss) | static_cast(OpRole::kForward), static_cast(OpRole::kLoss) | static_cast(OpRole::kBackward), + static_cast(OpRole::kOptimize) | + static_cast(OpRole::kLRSched), static_cast(OpRole::kNotSpecified)}) .SetDefault(static_cast(OpRole::kNotSpecified)); AddAttr>(OpRoleVarAttrName(), diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 4ed3cc45d66849267ef4945a03da1db76b53e4ea..4c59c73d8779eceb267ad532aabccabbd54b0df2 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -14,25 +14,26 @@ limitations under the License. */ #pragma once #include -#include - #include "glog/logging.h" #include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/framework.pb.h" namespace paddle { namespace framework { +////////////////////////// +// Don't add more roles to make this too complicated! +////////////////////////// enum class OpRole { kForward = 0x0000, kBackward = 0x0001, kOptimize = 0x0002, // RPC role is for send/recv releated op - kRPC = 0x0003, + kRPC = 0x0004, // Dist role is for split_byref/split_selected_rows/concat // used for distributed training. - kDist = 0x0004, + kDist = 0x0008, // Tag all learning rate scheduler operators. - kLRSched = 0x0005, + kLRSched = 0x0010, kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and @@ -73,11 +74,6 @@ class OpProtoAndCheckerMaker { var_->set_dispensable(true); return *this; } - - VariableBuilder &Reuse(const std::string &name) { - var_->set_reuse(name); - return *this; - } }; VariableBuilder AddInput(const std::string &name, const std::string &comment); @@ -85,8 +81,6 @@ class OpProtoAndCheckerMaker { VariableBuilder AddOutput(const std::string &name, const std::string &comment); - void Reuse(const std::string &name, const std::string &reused_name); - template TypedAttrChecker &AddAttr(const std::string &name, const std::string &comment, @@ -105,8 +99,6 @@ class OpProtoAndCheckerMaker { void CheckNoDuplicatedInOutAttrs(); void Validate(); - void CheckReuseVars(); - proto::OpProto *proto_; OpAttrChecker *op_checker_; bool validated_{false}; diff --git a/paddle/fluid/framework/op_proto_maker_test.cc b/paddle/fluid/framework/op_proto_maker_test.cc index b71c7b646857e11f291748c4c7c2af92b6d53231..a8030d377fdb4d4aef74b315e21792dad10fac96 100644 --- a/paddle/fluid/framework/op_proto_maker_test.cc +++ b/paddle/fluid/framework/op_proto_maker_test.cc @@ -47,120 +47,3 @@ TEST(ProtoMaker, DuplicatedInOut) { ASSERT_THROW(proto_maker(&op_proto, &op_checker), paddle::platform::EnforceNotMet); } - -class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "input of test op"); - AddOutput("XOut", "output of test op").Reuse("X"); - } -}; - -class TestInplaceProtoMaker2 - : public paddle::framework::OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "input of test op"); - AddOutput("XOut", "output of test op").Reuse("X"); - AddOutput("NoOut", "output of test op").Reuse("NotExists"); - } -}; - -TEST(ProtoMaker, InplaceOutput) { - paddle::framework::proto::OpProto op_proto, op_proto2; - paddle::framework::OpAttrChecker op_checker; - TestInplaceProtoMaker proto_maker; - TestInplaceProtoMaker2 proto_maker2; - - proto_maker(&op_proto, &op_checker); - - ASSERT_THROW(proto_maker2(&op_proto2, &op_checker), - paddle::platform::EnforceNotMet); -} - -// normal reuse -class TestReuseProtoMaker : public paddle::framework::OpProtoAndCheckerMaker { - public: - void Make() { - AddInput("X", "input of test op"); - AddInput("Y", "input of test op"); - AddOutput("Out", "output of test op"); - AddOutput("XOut", "output of test op"); - // avoid destructor exception. - // Validate(); - TestReuse(); - } - - virtual void TestReuse() {} -}; - -// test duplicate reuse error -class TestReuseProtoMaker2 : public TestReuseProtoMaker { - public: - void TestReuse() { - Reuse("Out", "X"); - Reuse("Out", "Y"); - } -}; - -// NotExists Input -class TestReuseProtoMaker3 : public TestReuseProtoMaker { - public: - void TestReuse() { - Reuse("Out", "NotExists"); - Reuse("XOut", "X"); - } -}; - -// NotExists Output -class TestReuseProtoMaker4 : public TestReuseProtoMaker { - public: - void TestReuse() { Reuse("NotExists", "X"); } -}; - -TEST(ProtoMaker, Reuse) { - paddle::framework::proto::OpProto op_proto; - paddle::framework::OpAttrChecker op_checker; - TestReuseProtoMaker proto_maker; - proto_maker(&op_proto, &op_checker); -} - -// NOTE(dzhwinter): -// There is a Fatal CHECK on base class destructor, which will call abort inside -// instead of -// throw an exception. If we throw an exception in Make(), we will trigger the -// CHECK and terminate the tests. -// -// I had tried to replace the default CHECK with a exception, however, it's -// still not supported by glog. -// the details: -// https://github.com/google/glog/issues/249 -// https://github.com/facebookresearch/TensorComprehensions/issues/351 -/* -TEST(ProtoMaker, ReuseWithException) { - paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4; - paddle::framework::OpAttrChecker op_checker; - TestReuseProtoMaker2 proto_maker2; - TestReuseProtoMaker3 proto_maker3; - TestReuseProtoMaker4 proto_maker4; - EXPECT_THROW(proto_maker2(&op_proto2, &op_checker), - paddle::platform::EnforceNotMet); - - EXPECT_THROW(proto_maker3(&op_proto3, &op_checker), - paddle::platform::EnforceNotMet); - - EXPECT_THROW(proto_maker4(&op_proto4, &op_checker), - paddle::platform::EnforceNotMet); -} - -void FailureFunction() { - throw std::runtime_error("Check failed in destructor."); - // return 0; -} - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - google::InstallFailureFunction(&FailureFunction); - return RUN_ALL_TESTS(); -} -*/ diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 093108cb54779eb0cf35dd83e63eb0b1abb66dcd..cffb96bedf7638ee52856f21662437085035eca6 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -109,18 +109,9 @@ ParallelExecutor::ParallelExecutor( if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { BCastParamsToDevices(bcast_vars); } - // Startup Program has been run. All local scopes has correct parameters. +// Startup Program has been run. All local scopes has correct parameters. - // Step 2. Create vars in each scope; - std::vector var_infos; - for (auto *var : main_program.Block(0).AllVars()) { - var_infos.emplace_back(); - var_infos.back().name_ = var->Name(); - var_infos.back().type_ = var->GetType(); - var_infos.back().persistable_ = var->Persistable(); - } - -// Step 3. Convert main_program to SSA form and dependency graph. Also, insert +// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp #ifdef PADDLE_WITH_CUDA std::unique_ptr graph = build_strategy.Apply( @@ -156,13 +147,22 @@ ParallelExecutor::ParallelExecutor( params, member_->local_scopes_, member_->use_cuda_); #endif - if (VLOG_IS_ON(5)) { - // If the loss_var_name is given, the number of graph should be only one. - if (loss_var_name.size()) { - PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, - "The number of graph should be only one"); + // Step 3. Create vars in each scope. Passes may also create new vars. + // skip control vars and empty vars + std::vector var_infos; + for (auto &node : graph->Nodes()) { + if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { + var_infos.emplace_back(); + var_infos.back().name_ = node->Var()->Name(); + var_infos.back().type_ = node->Var()->GetType(); + var_infos.back().persistable_ = node->Var()->Persistable(); } } + // If the loss_var_name is given, the number of graph should be only one. + if (loss_var_name.size()) { + PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, + "The number of graph should be only one"); + } if (exec_strategy.type_ == ExecutionStrategy::kDefault) { member_->executor_.reset(new details::ThreadedSSAGraphExecutor( diff --git a/paddle/fluid/framework/program_desc_test.cc b/paddle/fluid/framework/program_desc_test.cc index 7e689a37da8a16bd9b1ac6650b9322d2eb5a2c85..48bde2785e6a51afc0d2905ac31fe20a3c3019b6 100644 --- a/paddle/fluid/framework/program_desc_test.cc +++ b/paddle/fluid/framework/program_desc_test.cc @@ -103,7 +103,7 @@ TEST(ProgramDesc, copy_ctor) { ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); found_sub_block = true; - ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size()); + ASSERT_EQ(2UL, op->GetBlocksAttrIds("sub_blocks").size()); found_sub_blocks = true; } } diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc index 50aca4b5a4ba7a93a1584a03cc16fe5d712a32b5..d812417a38200bcfdbdeac78800190647510a144 100644 --- a/paddle/fluid/framework/reader_test.cc +++ b/paddle/fluid/framework/reader_test.cc @@ -40,7 +40,7 @@ TEST(READER, decorate_chain) { auto endpoints = root->GetEndPoints(); ASSERT_EQ(endpoints.size(), 2U); ASSERT_NE(endpoints.count(end_point1.get()), 0UL); - ASSERT_NE(endpoints.count(end_point2.get()), 0); + ASSERT_NE(endpoints.count(end_point2.get()), 0UL); } { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 14f9f36812d690fc4a7440f2e7e6a85e9993a535..9462620e829ec815e1553f6378a67463ea3b8aa3 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -78,6 +78,8 @@ class Scope { /// Drop all kids scopes belonged to this scope. void DropKids(); + std::list& kids() const { return kids_; } + /// Find if a scope exists in the kid scopes bool HasKid(const Scope* scope) const; diff --git a/paddle/fluid/framework/threadpool.cc b/paddle/fluid/framework/threadpool.cc index 18cdca3a658a6a89e6ab637a7f38825756acfea8..a588cb417aebe94bd4aeda02b1bc8ba07a04b960 100644 --- a/paddle/fluid/framework/threadpool.cc +++ b/paddle/fluid/framework/threadpool.cc @@ -25,7 +25,6 @@ DEFINE_int32(dist_threadpool_size, 0, namespace paddle { namespace framework { - std::unique_ptr ThreadPool::threadpool_(nullptr); std::once_flag ThreadPool::init_flag_; @@ -47,8 +46,7 @@ void ThreadPool::Init() { } } -ThreadPool::ThreadPool(int num_threads) - : total_threads_(num_threads), idle_threads_(num_threads), running_(true) { +ThreadPool::ThreadPool(int num_threads) : running_(true) { threads_.resize(num_threads); for (auto& thread : threads_) { // TODO(Yancey1989): binding the thread on the specify CPU number @@ -59,6 +57,7 @@ ThreadPool::ThreadPool(int num_threads) ThreadPool::~ThreadPool() { { // notify all threads to stop running + std::lock_guard l(mutex_); running_ = false; scheduled_.notify_all(); } @@ -69,36 +68,24 @@ ThreadPool::~ThreadPool() { } } -void ThreadPool::Wait() { - std::unique_lock lock(mutex_); - completed_.wait(lock, [=] { return Done() == true; }); -} - void ThreadPool::TaskLoop() { - while (running_) { + while (true) { std::unique_lock lock(mutex_); - scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; }); - if (!running_) { - break; + scheduled_.wait( + lock, [this] { return !this->tasks_.empty() || !this->running_; }); + + if (!running_ || tasks_.empty()) { + return; } + // pop a task from the task queue auto task = std::move(tasks_.front()); tasks_.pop(); - - --idle_threads_; lock.unlock(); // run the task task(); - - { - std::unique_lock lock(mutex_); - ++idle_threads_; - if (Done()) { - completed_.notify_all(); - } - } } } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 94111ee335b1a5df327b3e46d62069b4735c54f6..0687e628aaa4fb7b2e67938fa09a319c8bb35aff 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -57,15 +57,6 @@ class ThreadPool { ~ThreadPool(); - // Returns the number of threads created by the constructor. - size_t Threads() const { return total_threads_; } - - // Returns the number of currently idle threads. - size_t IdleThreads() { - std::unique_lock lock(mutex_); - return idle_threads_; - } - // Run pushes a function to the task queue and returns a std::future // object. To wait for the completion of the task, call // std::future::wait(). @@ -94,25 +85,13 @@ class ThreadPool { }); std::future> f = task.get_future(); tasks_.push(std::move(task)); - lock.unlock(); scheduled_.notify_one(); return f; } - // Wait until all the tasks are completed. - void Wait(); - private: DISABLE_COPY_AND_ASSIGN(ThreadPool); - // If the task queue is empty and avaialbe is equal to the number of - // threads, means that all tasks are completed. Note: this function - // is not thread-safe. Returns true if all tasks are completed. - // Note: don't delete the data member total_threads_ and use - // threads_.size() instead; because you'd need to lock the mutex - // before accessing threads_. - bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; } - // The constructor starts threads to run TaskLoop, which retrieves // and runs tasks from the queue. void TaskLoop(); @@ -125,14 +104,11 @@ class ThreadPool { static std::once_flag init_flag_; std::vector> threads_; - const size_t total_threads_; - size_t idle_threads_; std::queue tasks_; std::mutex mutex_; bool running_; std::condition_variable scheduled_; - std::condition_variable completed_; }; class ThreadPoolIO : ThreadPool { diff --git a/paddle/fluid/framework/threadpool_test.cc b/paddle/fluid/framework/threadpool_test.cc index 27a4ffd4fcbf293a3dea1744b29384d0bee0c137..884d61e23428a0ad758946295ca9c470767e93ef 100644 --- a/paddle/fluid/framework/threadpool_test.cc +++ b/paddle/fluid/framework/threadpool_test.cc @@ -19,10 +19,11 @@ limitations under the License. */ namespace framework = paddle::framework; -void do_sum(framework::ThreadPool* pool, std::atomic* sum, int cnt) { - std::vector> fs; +void do_sum(std::vector>* fs, std::mutex* mu, + std::atomic* sum, int cnt) { for (int i = 0; i < cnt; ++i) { - fs.push_back(framework::Async([sum]() { sum->fetch_add(1); })); + std::lock_guard l(*mu); + fs->push_back(framework::Async([sum]() { sum->fetch_add(1); })); } } @@ -40,18 +41,21 @@ TEST(ThreadPool, ConcurrentInit) { } TEST(ThreadPool, ConcurrentRun) { - framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); std::atomic sum(0); std::vector threads; + std::vector> fs; + std::mutex fs_mu; int n = 50; // sum = (n * (n + 1)) / 2 for (int i = 1; i <= n; ++i) { - std::thread t(do_sum, pool, &sum, i); + std::thread t(do_sum, &fs, &fs_mu, &sum, i); threads.push_back(std::move(t)); } for (auto& t : threads) { t.join(); } - pool->Wait(); + for (auto& t : fs) { + t.wait(); + } EXPECT_EQ(sum, ((n + 1) * n) / 2); } diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 9794a193bcfaae19552b1f6fbdf2dab2898033d5..dbbe8bcba69a1d87e21c8eae18834fb708e8b1e4 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -30,7 +30,7 @@ if (WITH_GPU AND TENSORRT_FOUND) endif() # Create static library -cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor) +cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array) if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. @@ -40,7 +40,7 @@ endif() # Create shared library cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} - DEPS ${fluid_modules} paddle_fluid_api) + DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) if(NOT APPLE) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 2e79d495d5ff00000000029ac0f6eb486aaea94a..ef4142f334e503380dc7ccd74c348404ffe52ee6 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -107,6 +107,9 @@ void Analyzer::Run(Argument* argument) { passes.push_back("mkldnn_placement_pass"); } #endif + // infer_clean_graph_pass should be the first default pass + // after mkldnn_placement_pass. + passes.push_back("infer_clean_graph_pass"); for (auto& pass : ir_passes_) { if (!disabled_ir_passes_.count(pass)) { passes.push_back(pass); diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index c51a4fdb2f6b27e54637481c23bf6f1f6ec97718..7114f5222c5904bb4422b9e67ad035b85bbb770c 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -67,7 +67,6 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // "attention_lstm_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // "embedding_fc_lstm_fuse_pass", // diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 0ddd5d53f836131fe37d412fc867cb38f11ee2b5..e2027b7cb4d584ffcc48624d2c01e65a61829975 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -18,7 +18,8 @@ if(APPLE) endif(APPLE) -set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB}) +set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager naive_executor ${GLOB_PASS_LIB} + ) if(WITH_GPU AND TENSORRT_FOUND) set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine analysis_predictor) @@ -31,10 +32,17 @@ function(inference_api_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cc_test(${TARGET_NAME} - SRCS ${inference_test_SRC} - DEPS "${inference_deps}" - ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + if (WITH_GPU) + cc_test(${TARGET_NAME} + SRCS ${inference_test_SRC} + DEPS "${inference_deps}" + ARGS --dirname=${PYTHON_TESTS_DIR}/book/ --fraction_of_gpu_memory_to_use=0.15) + else() + cc_test(${TARGET_NAME} + SRCS ${inference_test_SRC} + DEPS "${inference_deps}" + ARGS --dirname=${PYTHON_TESTS_DIR}/book/) + endif() if(inference_test_ARGS) set_tests_properties(${TARGET_NAME} PROPERTIES DEPENDS "${inference_test_ARGS}") @@ -42,7 +50,8 @@ function(inference_api_test TARGET_NAME) endif(WITH_TESTING) endfunction(inference_api_test) -cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope) +cc_library(reset_tensor_array SRCS details/reset_tensor_array.cc DEPS lod_tensor scope) +cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS reset_tensor_array lod_tensor scope) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS paddle_inference_api) cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc DEPS paddle_inference_api) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index eec665767164dc6e79738890947c54d7f7217037..54c37fe64590aa82d7100c93c4c5c4ee36491421 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -82,6 +82,7 @@ bool AnalysisPredictor::Init( // Get the feed_target_names and fetch_target_names PrepareFeedFetch(); + return true; } @@ -109,6 +110,10 @@ bool AnalysisPredictor::Run(const std::vector &inputs, return false; } VLOG(3) << "predict cost: " << timer.toc() << "ms"; + + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } @@ -322,6 +327,9 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( bool AnalysisPredictor::ZeroCopyRun() { executor_->Run(); + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 5a9f4d36959d4ee7ca16dec769d9d1283b8787cb..b7dc2067332278c1c38df4beefb5059efe76417f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/string/printf.h" @@ -88,6 +89,7 @@ class AnalysisPredictor : public PaddlePredictor { // Memory buffer for feed inputs. The temporary LoDTensor will cause serious // concurrency problems, so cache them. std::vector feed_tensors_; + details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 7cda9c5d8a8366bd097491f37f5352a10e4fb16c..d06ab8f8c8e3c0adf4a4074eb1450012126e83ea 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -157,6 +158,10 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, return false; } VLOG(3) << "predict cost: " << timer.toc() << "ms"; + + // Fix TensorArray reuse not cleaned bug. + tensor_array_batch_cleaner_.CollectTensorArrays(scope_.get()); + tensor_array_batch_cleaner_.ResetTensorArray(); return true; } diff --git a/paddle/fluid/inference/api/api_impl.h b/paddle/fluid/inference/api/api_impl.h index 7882f6a53c7ce9a2486158ea9b50c018d1814091..4e4ab47ca9c5e37f2714ebd48d250c23c7e9b117 100644 --- a/paddle/fluid/inference/api/api_impl.h +++ b/paddle/fluid/inference/api/api_impl.h @@ -26,11 +26,11 @@ limitations under the License. */ #include #include -#include "paddle/fluid/inference/api/paddle_inference_api.h" - #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/init.h" @@ -77,6 +77,7 @@ class NativePaddlePredictor : public PaddlePredictor { std::vector fetchs_; // Do not use unique_ptr, use parent scope to delete framework::Scope *sub_scope_{nullptr}; + details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 03f0f726eb61c2619c7719a865383090f86b5b7f..49683eab07a2f5bc008272038a27bdb277396284 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -52,6 +52,7 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") if (NOT WIN32) include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") @@ -61,8 +62,8 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") -if (NOT WIN32) - if (USE_TENSORRT AND WITH_GPU) +if (NOT WIN32) + if (USE_TENSORRT AND WITH_GPU) include_directories("${TENSORRT_INCLUDE_DIR}") link_directories("${TENSORRT_LIB_DIR}") endif() @@ -77,13 +78,14 @@ endif(NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/paddle/lib") add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) if(WITH_MKL) include_directories("${PADDLE_LIB}/third_party/install/mklml/include") - set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} + set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn") if(EXISTS ${MKLDNN_PATH}) @@ -107,7 +109,7 @@ if (NOT WIN32) set(EXTERNAL_LIB "-lrt -ldl -lpthread") set(DEPS ${DEPS} ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf snappystream snappy z + glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) else() set(DEPS ${DEPS} @@ -120,7 +122,7 @@ endif(NOT WIN32) if(WITH_GPU) if(NOT WIN32) - if (USE_TENSORRT) + if (USE_TENSORRT) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) endif() diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 67994aad70a40c0e0c8a311914d4ea40b96eaf1e..340e84d9312c20e2d10eb4c0a13066511d5d2eb5 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -16,12 +16,12 @@ if [ $2 == ON ]; then fi if [ $3 == ON ]; then use_gpu_list='true false' -else +else use_gpu_list='false' fi USE_TENSORRT=OFF -if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then +if [ -d "$TENSORRT_INCLUDE_DIR" -a -d "$TENSORRT_LIB_DIR" ]; then USE_TENSORRT=ON fi @@ -60,7 +60,8 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_MKL=$TURN_ON_MKL \ -DDEMO_NAME=simple_on_word2vec \ -DWITH_GPU=$TEST_GPU_CPU \ - -DWITH_STATIC_LIB=$WITH_STATIC_LIB + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DON_INFER=ON make -j word2vec_model=${PADDLE_ROOT}'/build/python/paddle/fluid/tests/book/word2vec.inference.model' if [ -d $word2vec_model ]; then @@ -80,10 +81,11 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_MKL=$TURN_ON_MKL \ -DDEMO_NAME=vis_demo \ -DWITH_GPU=$TEST_GPU_CPU \ - -DWITH_STATIC_LIB=$WITH_STATIC_LIB + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DON_INFER=ON make -j for use_gpu in $use_gpu_list; do - for vis_demo_name in $vis_demo_list; do + for vis_demo_name in $vis_demo_list; do ./vis_demo \ --modeldir=$DATA_DIR/$vis_demo_name/model \ --data=$DATA_DIR/$vis_demo_name/data.txt \ @@ -95,7 +97,7 @@ for WITH_STATIC_LIB in ON OFF; do fi done done - + # --------tensorrt mobilenet------ if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then rm -rf * @@ -106,8 +108,9 @@ for WITH_STATIC_LIB in ON OFF; do -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ -DUSE_TENSORRT=$USE_TENSORRT \ -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ - -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR - make -j + -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR \ + -DON_INFER=ON + make -j ./trt_mobilenet_demo \ --modeldir=$DATA_DIR/mobilenet/model \ --data=$DATA_DIR/mobilenet/data.txt \ diff --git a/paddle/fluid/inference/api/details/reset_tensor_array.cc b/paddle/fluid/inference/api/details/reset_tensor_array.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ae6c6dc9f44650c1c62f5be5448864d817513b1 --- /dev/null +++ b/paddle/fluid/inference/api/details/reset_tensor_array.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/api/details/reset_tensor_array.h" + +namespace paddle { +namespace details { + +// Should be called after the parameters are loaded. +void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) { + if (flag_) { + for (auto &var_name : scope->LocalVarNames()) { + auto *var = scope->FindVar(var_name); + // TODO(Superjomn) should avoid the case when a TensorArray is a + // parameter. + if (var_name == "feed" || var_name == "fetch") continue; + if (var->Type() == typeid(framework::LoDTensorArray)) { + VLOG(4) << "collect " << var_name; + arrays_.push_back(var->GetMutable()); + } + } + for (auto *kid : scope->kids()) { + CollectTensorArrays(kid); + } + + VLOG(3) << "Collect " << arrays_.size() << " arrays"; + flag_ = false; + } +} + +// Should be called when `Run` finished. +void TensorArrayBatchCleaner::ResetTensorArray() { + for (auto *arr : arrays_) { + arr->clear(); + } +} + +} // namespace details +} // namespace paddle diff --git a/paddle/fluid/inference/api/details/reset_tensor_array.h b/paddle/fluid/inference/api/details/reset_tensor_array.h new file mode 100644 index 0000000000000000000000000000000000000000..a39449ff0e67786815dfb8d2d30d79dcdba757d7 --- /dev/null +++ b/paddle/fluid/inference/api/details/reset_tensor_array.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace details { + +// Clean the TensorArray each batch to make the behavior the same with the +// training phase. +struct TensorArrayBatchCleaner { + // Fix the tensor array not clear in the inference scenarios. + void CollectTensorArrays(framework::Scope *scope); + void ResetTensorArray(); + + private: + bool flag_{true}; + std::vector arrays_; +}; + +} // namespace details +} // namespace paddle diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 24f59cf43a9700ff1732e1ef6ad82e1a6294eede..e46dc1326951f68fd030f2208b9bea1647d0026d 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid, double latency, int epoch = 1) { LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat << ", threads: " << num_threads << ", thread id: " << tid - << ", latency: " << latency << "ms ======"; + << ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f) + << " ======"; if (epoch > 1) { int samples = batch_size * epoch; LOG(INFO) << "====== sample number: " << samples diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 07ee6e72d1053d2271b8f8d69ce38003f5e038a0..a755ccb93bdee018dfeaf91157e7971b4d4cd832 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -124,7 +124,7 @@ class ZeroCopyTensor { std::vector> lod() const; protected: - ZeroCopyTensor(void* scope) : scope_{scope} {} + explicit ZeroCopyTensor(void* scope) : scope_{scope} {} void SetName(const std::string& name) { name_ = name; } void* FindTensor() const; @@ -259,12 +259,6 @@ struct AnalysisConfig : public NativeConfig { kExclude // Specify the disabled passes in `ir_passes`. }; - void SetIncludeMode() { - ir_mode = IrPassMode::kInclude; - // this pass has to be run at the beginning of all fuse passes - ir_passes = {"infer_clean_graph_pass"}; - } - // Determine whether to perform graph optimization. bool enable_ir_optim = true; // Manually determine the IR passes to run. diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index f9bb66a6e9f81a10368db7710108c319860e940a..677f85152f202b514d0563f885d872c84faba19a 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -42,16 +42,22 @@ class Pool2dOpConverter : public OpConverter { boost::get>(op_desc.GetAttr("strides")); std::vector paddings = boost::get>(op_desc.GetAttr("paddings")); + bool ceil_mode = boost::get(op_desc.GetAttr("ceil_mode")); + nvinfer1::Dims input_shape = input1->getDimensions(); + int nbDims = input_shape.nbDims; nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); + nvinfer1::DimsHW nv_strides(strides[0], strides[1]); + nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); + if (global_pooling == true) { - nvinfer1::Dims input_shape = input1->getDimensions(); - int nbDims = input_shape.nbDims; nv_ksize.d[0] = input_shape.d[nbDims - 2]; nv_ksize.d[1] = input_shape.d[nbDims - 1]; + nv_strides.h() = 1; + nv_strides.w() = 1; + nv_paddings.h() = 0; + nv_paddings.w() = 0; } - const nvinfer1::DimsHW nv_strides(strides[0], strides[1]); - const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]); PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL); @@ -64,6 +70,36 @@ class Pool2dOpConverter : public OpConverter { PADDLE_THROW("TensorRT unsupported pooling type!"); } + if (ceil_mode) { + nvinfer1::DimsHW pre_pad(0, 0); + nvinfer1::DimsHW post_pad(0, 0); + int input_height = input_shape.d[nbDims - 2]; + int input_width = input_shape.d[nbDims - 1]; + int floor_h_output_size = + (input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1; + int ceil_h_output_size = + (input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) / + strides[0] + + 1; + + int floor_w_output_size = + (input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1; + int ceil_w_output_size = + (input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) / + strides[1] + + 1; + if (floor_h_output_size != ceil_h_output_size) { + post_pad.h() = strides[0] - 1; + } + + if (floor_w_output_size != ceil_w_output_size) { + post_pad.w() = strides[1] - 1; + } + auto* layer = TRT_ENGINE_ADD_LAYER( + engine_, Padding, *const_cast(input1), pre_pad, + post_pad); + input1 = layer->getOutput(0); + } auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *const_cast(input1), nv_pool_type, nv_ksize); diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index aedd6b62df040eeee4e48f628128511cd8bf4439..ee597f8465c218c0fb6648374c128cabf7b033fb 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -20,18 +20,20 @@ namespace paddle { namespace inference { namespace tensorrt { -void test_pool2d(bool global_pooling) { +void test_pool2d(bool global_pooling, bool ceil_mode) { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(5, parameters, scope, 1 << 15); // The ITensor's Dims should not contain the batch size. // So, the ITensor's Dims of input and output should be C * H * W. - validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); + validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 13, 14)); if (global_pooling) validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); + else if (ceil_mode) + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 7)); else - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 6)); // Prepare Op description framework::OpDesc desc; @@ -39,7 +41,7 @@ void test_pool2d(bool global_pooling) { desc.SetInput("X", {"pool2d-X"}); desc.SetOutput("Out", {"pool2d-Out"}); - std::vector ksize({2, 2}); + std::vector ksize({3, 3}); std::vector strides({2, 2}); std::vector paddings({0, 0}); std::string pooling_t = "max"; @@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) { desc.SetAttr("strides", strides); desc.SetAttr("paddings", paddings); desc.SetAttr("global_pooling", global_pooling); + desc.SetAttr("ceil_mode", ceil_mode); LOG(INFO) << "set OP"; validator.SetOp(*desc.Proto()); @@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) { validator.Execute(3); } -TEST(Pool2dOpConverter, normal) { test_pool2d(false); } +TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); } +TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); } -TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); } +TEST(Pool2dOpConverter, test_ceil_mode) { test_pool2d(false, true); } } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index 67668298440e9af279e792f786a8123b71172a66..c2151eea0823f80feb17b014c1f739d2a15ae862 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -71,7 +71,7 @@ void profile(bool use_mkldnn = false) { } TEST(Analyzer_resnet50, profile) { profile(); } -#ifndef PADDLE_WITH_MKLDNN +#ifdef PADDLE_WITH_MKLDNN TEST(Analyzer_resnet50, profile_mkldnn) { profile(true /* use_mkldnn */); } #endif diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc index 6399476680c0af83a6d26aea952c58543bdce9ae..e0416ff953b61f56a2ca1a45cb382d40a6cffa4a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc @@ -228,6 +228,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_rnn1, profile) { contrib::AnalysisConfig cfg; SetConfig(&cfg); + cfg.use_gpu = false; std::vector outputs; std::vector> input_slots_all; diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index b1ee1080030b23e1ef7adefe3a0880f38e9099f5..19c3f532d5dcb7588793fa21fa179f6b48649103 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -50,7 +50,7 @@ void CompareResult(const std::vector &outputs, auto &ref_out = ref_outputs[i]; size_t size = VecReduceToInt(out.shape); size_t ref_size = VecReduceToInt(ref_out.shape); - EXPECT_GT(size, 0); + EXPECT_GT(size, 0UL); EXPECT_EQ(size, ref_size); EXPECT_EQ(out.dtype, ref_out.dtype); switch (out.dtype) { @@ -139,6 +139,9 @@ void TestMultiThreadPrediction( } for (int tid = 0; tid < num_threads; ++tid) { threads.emplace_back([&, tid]() { +#ifdef PADDLE_WITH_MKLDNN + platform::set_cur_thread_id(static_cast(tid) + 1); +#endif // Each thread should have local inputs and outputs. // The inputs of each thread are all the same. std::vector> inputs_tid = inputs; diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 78ef6f207eadea6799864fe22889103b468d1780..0d51cb92618170cb422cb49ba63ba54ae6608ef4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() +op_library(hash_op DEPS xxhash) op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index bbf52bea1358c32596ab6f14eeaa419735d19fc6..9ddb3a5d29f973047507855b43b226913a3600b5 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -28,7 +28,7 @@ using paddle::framework::Tensor; public: \ void Make() override { \ AddInput("X", "Input of " #OP_NAME " operator"); \ - AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \ + AddOutput("Out", "Output of " #OP_NAME " operator"); \ AddAttr("use_mkldnn", \ "(bool, default false) Only used in mkldnn kernel") \ .SetDefault(false); \ diff --git a/paddle/fluid/operators/adam_op.cc b/paddle/fluid/operators/adam_op.cc index 5d670fe3b9d99a31a628ff707ff860564eca952e..f3717af630017eba18aa265f3dbb496e18280a57 100644 --- a/paddle/fluid/operators/adam_op.cc +++ b/paddle/fluid/operators/adam_op.cc @@ -92,9 +92,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); - AddOutput("ParamOut", "(Tensor) Output parameter").Reuse("Param"); - AddOutput("Moment1Out", "(Tensor) Output first moment").Reuse("Moment1"); - AddOutput("Moment2Out", "(Tensor) Output second moment").Reuse("Moment2"); + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("Moment1Out", "(Tensor) Output first moment"); + AddOutput("Moment2Out", "(Tensor) Output second moment"); AddAttr("beta1", "(float, default 0.9) " diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 5912a1a17cbd29c3ebd83f37133c044f0905c8bd..3eb473832577bd348b33ba9b0be9e597b78f26bc 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -135,15 +135,13 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Variance", "The global variance (for training) " "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization").Reuse("X"); + AddOutput("Y", "result after normalization"); AddOutput("MeanOut", "Share memory with Mean. " - "Store the global mean when training") - .Reuse("Mean"); + "Store the global mean when training"); AddOutput("VarianceOut", "Share memory with Variance. " - "Store the global Variance when training") - .Reuse("Variance"); + "Store the global Variance when training"); AddOutput("SavedMean", "Mean of the current mini batch, " "will apply to output when training") diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index b6cb935814e25b31d4104f9ce24fe952680cb491..0d32cae0e1e5ff274793df50e854283d8e2f7bf8 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -79,6 +79,9 @@ struct BeamSearchDecodeFunctor { bool tensor_on_gpu_; size_t beam_size_; int end_id_; + // TODO(Superjomn) Here might result serious performance issue in the + // concurrency + // scenarios. const LoDTensorArray& step_ids_origin_; const LoDTensorArray& step_scores_origin_; LoDTensorArray step_ids_ = LoDTensorArray(); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 8f2561fcc389922f05093055cba4b43dbd4e4536..2cd9979bd3426a15af34a49002d5db2fdd9aeec7 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -130,8 +130,7 @@ void Conv2DOpMaker::Make() { .AsDispensable(); AddOutput("Output", "(Tensor) The output tensor of convolution operator. " - "The format of output tensor is also NCHW.") - .Reuse("Input"); + "The format of output tensor is also NCHW."); AddInput("ResidualData", "(Tensor) Tensor with residual data " "to which convolution output will be added." @@ -238,8 +237,7 @@ void Conv3DOpMaker::Make() { "input image channels divided by the groups."); AddOutput("Output", "(Tensor) The output tensor of convolution operator." - "The format of output tensor is also NCDHW.") - .Reuse("Input"); + "The format of output tensor is also NCDHW."); AddAttr>("strides", "(vector, default:{1, 1, 1}), the " "strides(d_stride, h_stride, w_stride) of " diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index a69d9c9a529f26b3981ca8d1ba226fda71b8820a..709c2dfc4b7c67d7d04074c58ce6da85b6e790fe 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -284,7 +284,7 @@ static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, selected_indices.push_back(idx); ++selected_num; } - sorted_indices.erase(sorted_indices.end()); + sorted_indices.erase(sorted_indices.end() - 1); if (flag && eta < 1 && adaptive_threshold > 0.5) { adaptive_threshold *= eta; } diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index dda423efd35b96f5e1d7c55389818f46ef3d8694..46fff9d338b7759496faaf6dd9960d34887755ba 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { PADDLE_ENFORCE( ctx->HasOutput("TargetBBox"), "Output(TargetBBox) of RpnTargetAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasOutput("BBoxInsideWeight"), + "Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null"); auto anchor_dims = ctx->GetInputDim("Anchor"); auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); @@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ScoreIndex", {-1}); ctx->SetOutputDim("TargetLabel", {-1, 1}); ctx->SetOutputDim("TargetBBox", {-1, 4}); + ctx->SetOutputDim("BBoxInsideWeight", {-1, 4}); } protected: @@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, const float rpn_positive_overlap, const float rpn_negative_overlap, std::vector* fg_inds, std::vector* bg_inds, std::vector* tgt_lbl, + std::vector* fg_fake, std::vector* bbox_inside_weight, std::minstd_rand engine, bool use_random) { float epsilon = 0.00001; int anchor_num = anchor_to_gt_max.dims()[0]; @@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, // Reservoir Sampling int fg_num = static_cast(rpn_fg_fraction * rpn_batch_size_per_im); ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); - fg_num = static_cast(fg_inds_fake.size()); - for (int64_t i = 0; i < fg_num; ++i) { + int fg_fake_num = static_cast(fg_inds_fake.size()); + for (int64_t i = 0; i < fg_fake_num; ++i) { target_label[fg_inds_fake[i]] = 1; } - int bg_num = rpn_batch_size_per_im - fg_num; + int bg_num = rpn_batch_size_per_im - fg_fake_num; for (int64_t i = 0; i < anchor_num; ++i) { if (anchor_to_gt_max_data[i] < rpn_negative_overlap) { bg_inds_fake.push_back(i); @@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, } ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); bg_num = static_cast(bg_inds_fake.size()); + int fake_num = 0; for (int64_t i = 0; i < bg_num; ++i) { + // fg fake found + if (target_label[bg_inds_fake[i]] == 1) { + fake_num++; + fg_fake->emplace_back(fg_inds_fake[0]); + for (int j = 0; j < 4; ++j) { + bbox_inside_weight->emplace_back(T(0.)); + } + } target_label[bg_inds_fake[i]] = 0; } + for (int64_t i = 0; i < (fg_fake_num - fake_num) * 4; ++i) { + bbox_inside_weight->emplace_back(T(1.)); + } + for (int64_t i = 0; i < anchor_num; ++i) { - if (target_label[i] == 1) fg_inds->emplace_back(i); + if (target_label[i] == 1) { + fg_inds->emplace_back(i); + fg_fake->emplace_back(i); + } if (target_label[i] == 0) bg_inds->emplace_back(i); } fg_num = fg_inds->size(); @@ -248,7 +269,8 @@ std::vector SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, std::vector bg_inds; std::vector gt_inds; std::vector tgt_lbl; - + std::vector fg_fake; + std::vector bbox_inside_weight; // Calculate the max IoU between anchors and gt boxes // Map from anchor to gt box that has highest overlap auto place = ctx.GetPlace(); @@ -275,32 +297,37 @@ std::vector SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, // Follow the Faster RCNN's implementation ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap, - rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, engine, - use_random); + rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, &fg_fake, + &bbox_inside_weight, engine, use_random); int fg_num = fg_inds.size(); int bg_num = bg_inds.size(); - gt_inds.reserve(fg_num); - for (int i = 0; i < fg_num; ++i) { - gt_inds.emplace_back(argmax[fg_inds[i]]); + int fg_fake_num = fg_fake.size(); + gt_inds.reserve(fg_fake_num); + for (int i = 0; i < fg_fake_num; ++i) { + gt_inds.emplace_back(argmax[fg_fake[i]]); } - - Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t; - int* loc_index_data = loc_index_t.mutable_data({fg_num}, place); + Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t; + int* loc_index_data = loc_index_t.mutable_data({fg_fake_num}, place); int* score_index_data = score_index_t.mutable_data({fg_num + bg_num}, place); int* tgt_lbl_data = tgt_lbl_t.mutable_data({fg_num + bg_num}, place); - int* gt_inds_data = gt_inds_t.mutable_data({fg_num}, place); - std::copy(fg_inds.begin(), fg_inds.end(), loc_index_data); + int* gt_inds_data = gt_inds_t.mutable_data({fg_fake_num}, place); + T* bbox_inside_weight_data = + bbox_inside_weight_t.mutable_data({fg_fake_num, 4}, place); + std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data); std::copy(fg_inds.begin(), fg_inds.end(), score_index_data); std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num); std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data); std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data); + std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(), + bbox_inside_weight_data); std::vector loc_score_tgtlbl_gt; loc_score_tgtlbl_gt.emplace_back(loc_index_t); loc_score_tgtlbl_gt.emplace_back(score_index_t); loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t); loc_score_tgtlbl_gt.emplace_back(gt_inds_t); + loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t); return loc_score_tgtlbl_gt; } @@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel { auto* score_index = context.Output("ScoreIndex"); auto* tgt_bbox = context.Output("TargetBBox"); auto* tgt_lbl = context.Output("TargetLabel"); + auto* bbox_inside_weight = context.Output("BBoxInsideWeight"); PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, "RpnTargetAssignOp gt_boxes needs 1 level of LoD"); @@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel { score_index->mutable_data({max_num}, place); tgt_bbox->mutable_data({max_num, 4}, place); tgt_lbl->mutable_data({max_num, 1}, place); - + bbox_inside_weight->mutable_data({max_num, 4}, place); auto& dev_ctx = context.device_context(); std::random_device rnd; @@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel { Tensor sampled_score_index = loc_score_tgtlbl_gt[1]; Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2]; Tensor sampled_gt_index = loc_score_tgtlbl_gt[3]; + Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4]; int loc_num = sampled_loc_index.dims()[0]; int score_num = sampled_score_index.dims()[0]; @@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel { AppendRpns(score_index, total_score_num, &sampled_score_index_unmap); AppendRpns(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox); AppendRpns(tgt_lbl, total_score_num, &sampled_tgtlbl); + AppendRpns(bbox_inside_weight, total_loc_num * 4, + &sampled_bbox_inside_weight); total_loc_num += loc_num; total_score_num += score_num; @@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel { score_index->set_lod(loc_score); tgt_bbox->set_lod(lod_loc); tgt_lbl->set_lod(loc_score); + bbox_inside_weight->set_lod(lod_loc); loc_index->Resize({total_loc_num}); score_index->Resize({total_score_num}); tgt_bbox->Resize({total_loc_num, 4}); tgt_lbl->Resize({total_score_num, 1}); + bbox_inside_weight->Resize({total_loc_num, 4}); } }; @@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { "TargetLabel", "(Tensor), The target labels of each anchor with shape " "[F + B, 1], F and B are sampled foreground and backgroud number."); + AddOutput("BBoxInsideWeight", + "(Tensor), The bbox inside weight with shape " + "[F, 4], F is the sampled foreground number."); AddComment(R"DOC( This operator can be, for a given set of ground truth bboxes and the anchors, to assign classification and regression targets to each prediction. diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 07322e720f26213ea777be3cd22f2fead28507f0..3c28ef30922e6d6ba09b96282619eef15867631e 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/dropout_op.h" +#include namespace paddle { namespace operators { @@ -57,6 +58,29 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { "will be dropped.") .SetDefault(false); AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_prob)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * dropout_prob" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_prob )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string& type) { + PADDLE_ENFORCE( + type == "downgrade_in_infer" || type == "upscale_in_train", + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train"); + }); AddComment(R"DOC( Dropout Operator. @@ -104,7 +128,9 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); REGISTER_OP_CPU_KERNEL( - dropout, ops::CPUDropoutKernel); + dropout, ops::CPUDropoutKernel, + ops::CPUDropoutKernel); REGISTER_OP_CPU_KERNEL( dropout_grad, - ops::DropoutGradKernel); + ops::DropoutGradKernel, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index 1dd66e0280c46c0624ff70e822cb6fa6f06b7aa9..e011f47e086183a4ef3a3373c17acd6c21b6cf7e 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/float16.h" @@ -26,7 +27,8 @@ namespace operators { template __global__ void RandomGenerator(const size_t n, const int seed, const float dropout_prob, const T* src, - T* mask_data, T* dst) { + T* mask_data, T* dst, + bool is_upscale_in_train) { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); @@ -47,7 +49,11 @@ __global__ void RandomGenerator(const size_t n, const int seed, if (dist(rng) < dropout_prob) { mask = static_cast(0); } else { - mask = static_cast(1); + if (is_upscale_in_train) { + mask = static_cast(1.0f / (1.0f - dropout_prob)); + } else { + mask = static_cast(1); + } } dest = s * mask; mask_data[idx] = mask; @@ -67,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel { y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); auto& place = *context.template device_context().eigen_device(); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); @@ -83,11 +91,16 @@ class GPUDropoutKernel : public framework::OpKernel { int grid = (x->numel() + threads - 1) / threads; RandomGenerator< T><<>>( - size, seed, dropout_prob, x_data, mask_data, y_data); + size, seed, dropout_prob, x_data, mask_data, y_data, + (dropout_implementation == "upscale_in_train")); } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); - Y.device(place) = X * static_cast(1.0f - dropout_prob); + if (dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } } } }; @@ -99,6 +112,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL(dropout_grad, - ops::DropoutGradKernel); + ops::GPUDropoutKernel, + ops::GPUDropoutKernel); +REGISTER_OP_CUDA_KERNEL( + dropout_grad, ops::DropoutGradKernel, + ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 0628b4b826d2730a8e3fb4842e4ae550b8c00569..6c629b7b6d255828023ed25680675ca104a33e12 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -36,6 +37,8 @@ class CPUDropoutKernel : public framework::OpKernel { auto* y_data = y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); + auto dropout_implementation = + context.Attr("dropout_implementation"); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); @@ -49,14 +52,20 @@ class CPUDropoutKernel : public framework::OpKernel { engine.seed(seed); std::uniform_real_distribution dist(0, 1); + size_t size = framework::product(mask->dims()); for (size_t i = 0; i < size; ++i) { if (dist(engine) < dropout_prob) { mask_data[i] = 0; y_data[i] = 0; } else { - mask_data[i] = 1; - y_data[i] = x_data[i]; + if (dropout_implementation == "upscale_in_train") { + mask_data[i] = 1.0f / static_cast(1.0f - dropout_prob); + y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); + } else { + mask_data[i] = 1; + y_data[i] = x_data[i]; + } } } } else { @@ -64,7 +73,11 @@ class CPUDropoutKernel : public framework::OpKernel { auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); - Y.device(place) = X * (1.0f - dropout_prob); + if (dropout_implementation == "upscale_in_train") { + Y.device(place) = X; + } else { + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } } } }; diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index 7e5975ead64ab39a9c618a33e300c4fce55a5b22..68c6e315cc3b5fa932f8946f6d4f838f4d3fc5a5 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -80,8 +80,6 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { void Make() final { AddInput("X", "(Tensor), The first input tensor of elementwise op."); AddInput("Y", "(Tensor), The second input tensor of elementwise op."); - // AddOutput("SavedShape", "(Tensor), save X, Y shape for grad to save - // memory.").AsIntermediate(); AddOutput("Out", "The output of elementwise op."); AddAttr("axis", "(int, default -1). The start dimension index " @@ -129,13 +127,11 @@ But the output only shares the LoD information with the input $X$. )DOC", GetName(), GetEquation())); - SetReuse(); } protected: virtual std::string GetName() const = 0; virtual std::string GetEquation() const = 0; - virtual void SetReuse() {} }; class ElementwiseOpGrad : public framework::OperatorWithKernel { @@ -269,7 +265,6 @@ class ElemwiseGradKernel : public framework::OpKernel { protected: \ virtual std::string GetName() const { return op_name; } \ virtual std::string GetEquation() const { return equation; } \ - virtual void SetReuse() { Reuse(__VA_ARGS__); } \ }; \ REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ __ElemwiseOp##op_type##Maker__, \ diff --git a/paddle/fluid/operators/fusion_gru_op.cc b/paddle/fluid/operators/fusion_gru_op.cc index a04c1c1263fba659e2d3f623b607e9f476bb40ed..120b2ab440156f6020fd6005dd64a48e9a6918ec 100644 --- a/paddle/fluid/operators/fusion_gru_op.cc +++ b/paddle/fluid/operators/fusion_gru_op.cc @@ -16,10 +16,9 @@ limitations under the License. */ #include // for memcpy #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" -#include "paddle/fluid/platform/cpu_info.h" namespace paddle { namespace operators { @@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel { } } -#define INIT_VEC_FUNC \ - std::function act_gate, act_state; \ - std::function cross; \ - auto& act_gate_str = ctx.Attr("gate_activation"); \ - auto& act_state_str = ctx.Attr("activation"); \ - if (platform::jit::MayIUse(platform::jit::avx)) { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_state = act_functor(act_state_str); \ - cross = math::vec_cross; \ - } else { \ - math::VecActivations act_functor; \ - act_gate = act_functor(act_gate_str); \ - act_state = act_functor(act_state_str); \ - cross = math::vec_cross; \ - } - -#define INIT_BASE_INPUT_OUTPUT \ - auto* h0 = ctx.Input("H0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - bool is_reverse = ctx.Attr("is_reverse"); - -#define INIT_BASE_SIZES \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 3D*/ \ - const int total_T = x_dims[0]; \ - const int M = x_dims[1]; \ - const int D = wh_dims[0]; \ - const int D3 = wh_dims[1]; \ - const int D2 = D * 2; +#define INIT_BASE_DEFINES \ + auto* x = ctx.Input("X"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* xx = ctx.Output("XX"); \ + auto x_lod = x->lod(); \ + auto x_dims = x->dims(); /* T x M*/ \ + auto wh_dims = wh->dims(); /* D x 3D*/ \ + const int total_T = x_dims[0]; \ + const int D3 = wh_dims[1] + +#define INIT_OTHER_DEFINES \ + auto* h0 = ctx.Input("H0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* bias = ctx.Input("Bias"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + bool is_reverse = ctx.Attr("is_reverse"); \ + const int M = x_dims[1]; \ + const int D = wh_dims[0]; \ + const int D2 = D * 2; \ + const auto& ker = math::jitkernel::KernelPool::Instance() \ + .template Get, \ + const std::string&, const std::string&>( \ + ctx.Attr("gate_activation"), \ + ctx.Attr("activation"), D); \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + auto place = ctx.GetPlace(); \ + T* xx_data = xx->mutable_data(place) void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - INIT_VEC_FUNC - - auto x_lod = x->lod(); + INIT_BASE_DEFINES; + INIT_OTHER_DEFINES; const int N = x_lod[0].size() - 1; - const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); const T* wh_state_data = wh_data + D * D2; - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); - + T* hidden_out_data = hidden_out->mutable_data(place); auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D3, M, x_data, wx_data, xx_data, @@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel { if (h0_data) { prev_hidden_data = h0_data + bid * D; } else { - // W: {W_update, W_reset; W_state} - // update gate - act_gate(D, xx_data, xx_data); - // state gate - act_state(D, xx_data + D2, xx_data + D2); - // out = a*b - blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data); - // save prev + ker->ComputeH1(xx_data, hidden_out_data); prev_hidden_data = hidden_out_data; tstart = 1; move_step(); @@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast(1), prev_hidden_data, D, wh_data, D2, static_cast(1), xx_data, D3); - act_gate(D2, xx_data, xx_data); - // rt = rt*ht_1 inplace result - blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data); - + ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data); // gemm rt * Ws blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast(1), hidden_out_data, D, wh_state_data, D, static_cast(1), xx_data + D2, D3); - act_state(D, xx_data + D2, xx_data + D2); - // out = zt*ht~ + (1-zt)*ht_1 - cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data); + ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data); // save prev prev_hidden_data = hidden_out_data; move_step(); @@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel { void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; - auto* x = ctx.Input("X"); - INIT_BASE_INPUT_OUTPUT - INIT_BASE_SIZES - if (x->lod()[0].size() == 2) { + INIT_BASE_DEFINES; + if (x_lod[0].size() == 2) { xx->Resize({total_T, D3}); SeqCompute(ctx); return; } - INIT_VEC_FUNC - + INIT_OTHER_DEFINES; auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* batched_input = ctx.Output("BatchedInput"); auto* batched_out = ctx.Output("BatchedOut"); - - const T* x_data = x->data(); - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); - T* xx_data = xx->mutable_data(ctx.GetPlace()); - T* batched_input_data = batched_input->mutable_data(ctx.GetPlace()); - T* batched_out_data = batched_out->mutable_data(ctx.GetPlace()); - hidden_out->mutable_data(ctx.GetPlace()); - + T* batched_input_data = batched_input->mutable_data(place); + T* batched_out_data = batched_out->mutable_data(place); + hidden_out->mutable_data(place); auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); math::LoDTensor2BatchFunctor to_batch; @@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel { T* prev_hidden_data = nullptr; if (h0) { // reorder h0 - T* reordered_h0_data = reordered_h0->mutable_data(ctx.GetPlace()); + T* reordered_h0_data = reordered_h0->mutable_data(place); const T* h0_data = h0->data(); prev_hidden_data = reordered_h0_data; size_t sz = sizeof(T) * D; @@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel { T* cur_out_data = batched_out_data; // W: {W_update, W_reset; W_state} for (int i = 0; i < max_bs; ++i) { - // update gate - act_gate(D, cur_in_data, cur_in_data); - // state gate - act_state(D, cur_in_data + D2, cur_in_data + D2); - // out = a*b - blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data); + ker->ComputeH1(cur_in_data, cur_out_data); // add offset cur_in_data += D3; cur_out_data += D; @@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel { T* cur_out_data = batched_out_data; T* cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { - act_gate(D2, cur_batched_data, cur_batched_data); - // rt = rt*ht_1 inplace result - blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data); - + ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; @@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel { cur_prev_hidden_data = prev_hidden_data; for (int i = 0; i < cur_bs; ++i) { - // ht~ = act_state(...) - act_state(D, cur_batched_data + D2, cur_batched_data + D2); - // out = zt*ht~ + (1-zt)*ht_1 - cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data, - cur_out_data); - + ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data, + cur_out_data); cur_batched_data += D3; cur_prev_hidden_data += D; cur_out_data += D; @@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel { batched_out->set_lod(batched_lod); to_seq(dev_ctx, *batched_out, hidden_out); } -#undef INIT_VEC_FUNC -#undef INIT_BASE_SIZES -#undef INIT_BASE_INPUT_OUTPUT +#undef INIT_OTHER_DEFINES +#undef INIT_BASE_DEFINES }; } // namespace operators diff --git a/paddle/fluid/operators/hash_op.cc b/paddle/fluid/operators/hash_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9ebe71a3d7ae270a10a45f4805652415078b363 --- /dev/null +++ b/paddle/fluid/operators/hash_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/hash_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class HashOp : public framework::OperatorWithKernel { + public: + HashOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of HashOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of HashOp should not be null."); + + auto dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(dims.size(), 2UL, + "The input of hash_op's dimensions must be 2"); + std::vector out_dims; + out_dims.reserve(dims.size() + 1); + // copy all dims except the last one + for (size_t i = 0u; i != dims.size() - 1; ++i) { + out_dims.emplace_back(dims[i]); + } + int num_hash = ctx->Attrs().Get("num_hash"); + out_dims.emplace_back(num_hash); + // keep the last dim to 1 + out_dims.emplace_back(1); + + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class HashOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input tensor of scale operator."); + AddOutput("Out", "(Tensor) Output tensor of scale operator."); + AddComment(R"DOC( +**Hash Operator** +$$Out = scale * X$$ +)DOC"); + AddAttr("num_hash", "").SetDefault(1); + AddAttr("mod_by", "").SetDefault(100000); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker); +REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel, ops::HashKerel); diff --git a/paddle/fluid/operators/hash_op.h b/paddle/fluid/operators/hash_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9781bb0f453642cefb3eb59a05389c339a7de39d --- /dev/null +++ b/paddle/fluid/operators/hash_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +extern "C" { +#include +} +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +// template +template +class HashKerel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + auto* out_t = context.Output("Out"); + auto* in_t = context.Input("X"); + int mod_by = context.Attr("mod_by"); + int num_hash = context.Attr("num_hash"); + auto* output = out_t->mutable_data(context.GetPlace()); + + auto in_dims = in_t->dims(); + auto in_lod = in_t->lod(); + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + auto seq_length = in_dims[0]; + auto last_dim = in_dims[in_dims.size() - 1]; + auto* input = in_t->data(); + for (int idx = 0; idx < seq_length; ++idx) { + for (int ihash = 0; ihash != num_hash; ++ihash) { + output[idx * num_hash + ihash] = + XXH64(input, sizeof(int) * last_dim, ihash) % mod_by; + } + input += last_dim; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lars_momentum_op.cc b/paddle/fluid/operators/lars_momentum_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8dda93902448fa1bd21b719ffd9c9b500caf755 --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/lars_momentum_op.h" +#include "paddle/fluid/operators/momentum_op.h" + +namespace paddle { +namespace operators { + +class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", + "(LoDTensor, default LoDTensor) " + "Input parameter that has to be updated"); + AddInput("Grad", + "(LoDTensor, default LoDTensor) " + "Input gradient of the parameter"); + AddInput("Velocity", + "(LoDTensor, default LoDTensor) " + "Input velocity (corresponding to the parameter) " + "that has to be updated"); + AddInput("LearningRate", + "(LoDTensor, default LoDTensor) " + "Input learning rate"); + + AddOutput("ParamOut", + "(LoDTensor) This output is updated parameter. " + "It shared memory with Input(Param)."); + AddOutput("VelocityOut", + "(LoDTensor) This output is updated velocity. " + "It shared memory with Input(Velocity)."); + + AddAttr("mu", "(float) Momentum coefficient"); + AddAttr("lars_coeff", "(float, default 0.001) LARS coefficient.") + .SetDefault(0.001); + AddAttr("lars_weight_decay", + "(float, default 0.0005) LARS weight decay") + .SetDefault(0.0005); + + AddComment(R"DOC( +Lars Momentum Optimizer. + +This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each +weight using a local learning rate: + +$$ +local\_lr = \eta * + \frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\ +velocity = mu * velocity + + local\_lr * (grad + \beta * param) \\ +param = param - velocity. \\ +$$ + +Note that we use lars_weight_decay here to decay weights, you may need not to +use L2 regularizers in case of using LARS. + +)DOC"); + } +}; + +class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override {} +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker, + paddle::framework::EmptyGradOpMaker, + ops::LarsMomentumOpVarTypeInference); +REGISTER_OP_CPU_KERNEL(lars_momentum, ops::LarsMomentumOpKernel, + ops::LarsMomentumOpKernel); diff --git a/paddle/fluid/operators/lars_momentum_op.cu b/paddle/fluid/operators/lars_momentum_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..eb346851a2f690fa05422c84ddcb08307539048f --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.cu @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/lars_momentum_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v, + const T* learning_rate, const T mu, + const int64_t num, const T lars_coeff, + const T lars_weight_decay, const T* p_norm, + const T* g_norm, T* p_out, T* v_out) { + T lr = learning_rate[0]; + T local_lr = learning_rate[0]; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; + i += blockDim.x * gridDim.x) { + if (p_norm[0] > 0 && g_norm[0] > 0) { + local_lr = lr * lars_coeff * p_norm[0] / + (g_norm[0] + lars_weight_decay * p_norm[0]); + } + T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]); + v_out[i] = v_new; + p_out[i] = p[i] - v_new; + } +} + +template +class LarsMomentumOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto velocity_out = ctx.Output("VelocityOut"); + auto param = ctx.Input("Param"); + auto velocity = ctx.Input("Velocity"); + auto grad = ctx.Input("Grad"); + auto learning_rate = ctx.Input("LearningRate"); + + T* p_out = param_out->mutable_data(ctx.GetPlace()); + T* v_out = velocity_out->mutable_data(ctx.GetPlace()); + + T mu = static_cast(ctx.Attr("mu")); + T lars_coeff = ctx.Attr("lars_coeff"); + T lars_weight_decay = ctx.Attr("lars_weight_decay"); + + auto* p = param->data(); + auto* v = velocity->data(); + auto* g = grad->data(); + auto* lr = learning_rate->data(); + + int block = 512; + int grid = (param->numel() + block - 1) / block; + + auto eigen_p = framework::EigenVector::Flatten(*param); + auto eigen_g = framework::EigenVector::Flatten(*grad); + // calculate norms using eigein and launch the kernel. + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + auto* p_norm_data = p_norm_t.mutable_data(ctx.GetPlace()); + auto* g_norm_data = g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + + auto* place = ctx.template device_context().eigen_device(); + ep_norm.device(*place) = eigen_p.square().sum().sqrt(); + eg_norm.device(*place) = eigen_g.square().sum().sqrt(); + MomentumLarsKernel<<>>( + p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay, + p_norm_data, g_norm_data, p_out, v_out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + lars_momentum, + ops::LarsMomentumOpCUDAKernel, + ops::LarsMomentumOpCUDAKernel); diff --git a/paddle/fluid/operators/lars_momentum_op.h b/paddle/fluid/operators/lars_momentum_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e85be99fc42522e461a7915847d82144d8195a96 --- /dev/null +++ b/paddle/fluid/operators/lars_momentum_op.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class LarsMomentumOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out = ctx.Output("ParamOut"); + auto velocity_out = ctx.Output("VelocityOut"); + auto param = ctx.Input("Param"); + auto velocity = ctx.Input("Velocity"); + auto learning_rate = ctx.Input("LearningRate"); + auto* grad_var = ctx.InputVar("Grad"); + // only support dense for now. + PADDLE_ENFORCE(grad_var->IsType()); + auto grad = ctx.Input("Grad"); + + param_out->mutable_data(ctx.GetPlace()); + velocity_out->mutable_data(ctx.GetPlace()); + + T mu = static_cast(ctx.Attr("mu")); + T lars_coeff = ctx.Attr("lars_coeff"); + T lars_weight_decay = ctx.Attr("lars_weight_decay"); + + auto p_out = framework::EigenVector::Flatten(*param_out); + auto v_out = framework::EigenVector::Flatten(*velocity_out); + + auto p = framework::EigenVector::Flatten(*param); + auto v = framework::EigenVector::Flatten(*velocity); + auto g = framework::EigenVector::Flatten(*grad); + auto* lr = learning_rate->data(); + + framework::Tensor p_norm_t, g_norm_t; + p_norm_t.Resize({1}); + g_norm_t.Resize({1}); + p_norm_t.mutable_data(ctx.GetPlace()); + g_norm_t.mutable_data(ctx.GetPlace()); + auto ep_norm = framework::EigenScalar::From(p_norm_t); + auto eg_norm = framework::EigenScalar::From(g_norm_t); + + ep_norm = p.square().sum().sqrt(); + eg_norm = g.square().sum().sqrt(); + T local_lr = lr[0]; + if (ep_norm(0) > 0 && eg_norm(0) > 0) { + local_lr = lr[0] * lars_coeff * ep_norm(0) / + (eg_norm(0) + lars_weight_decay * ep_norm(0)); + } + v_out = v * mu + local_lr * (g + lars_weight_decay * p); + p_out = p - v_out; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index b9ac54e446811889b647397ae1fbb11c28f46777..d7f6cd5ab0acd2b677a3e5bd51bbcffe82eb1e50 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -81,6 +81,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { "Otherwise the given value indicates padding the output " "with zeros whenever lookup encounters it in Ids.") .SetDefault(kNoPadding); + // NOTE(minqiyang): grad_inplace is an temporal attribute, + // please do NOT set this attribute in python layer. + AddAttr("grad_inplace", + "(boolean, default false) " + "If the grad op reuse the input's variable.") + .SetDefault(false); AddComment(R"DOC( Lookup Table Operator. diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 58463dc4d6fd7cc3454de766814a947fee161070..e504c4f0cd5c0feaef4a251fad57b389a10a2ce7 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { @@ -68,6 +69,7 @@ class LookupTableKernel : public framework::OpKernel { const auto *table = table_t.value().data(); auto *output = output_t->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); for (int64_t i = 0; i < ids_numel; ++i) { if (padding_idx != kNoPadding && ids[i] == padding_idx) { memset(output + i * row_width, 0, row_width * sizeof(T)); @@ -75,8 +77,8 @@ class LookupTableKernel : public framework::OpKernel { PADDLE_ENFORCE_GE(ids[i], 0); auto id_index = table_t.Index(ids[i]); PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); - memcpy(output + i * row_width, table + id_index * row_width, - row_width * sizeof(T)); + blas.VCOPY(row_width, table + id_index * row_width, + output + i * row_width); } } } @@ -111,27 +113,37 @@ class LookupTableGradKernel : public framework::OpKernel { auto *ids_data = ids->data(); int64_t ids_num = ids->numel(); - framework::Vector new_rows; - new_rows.reserve(ids_num); - for (int64_t i = 0; i < ids_num; i++) { - new_rows.push_back(ids_data[i]); - } + std::vector new_rows; + new_rows.resize(ids_num); + std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); d_table_value->Resize({ids_num, table_dim[1]}); - d_table_value->mutable_data(context.GetPlace()); - - d_table->set_height(table_dim[0]); - - auto *d_output_data = d_output->data(); - auto *d_table_data = d_table_value->data(); - - auto d_output_dims = d_output->dims(); - PADDLE_ENFORCE_EQ( - d_table_value->dims(), - framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); - memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + // FIXME(minqiyang): + // memory optimization will NOT reuse Tensor with SelectedRows + // so we could just share the tensor here directly. + // However, the InferVarType method will infer the output SelectedRows + // to Tensor sometimes, which is a bug, so we will add an attribute + // here to indicate the inplace and remove this attribute after + // the InferVarType's bug was fixed + bool grad_inplace = context.Attr("grad_inplace"); + if (grad_inplace) { + d_table_value->ShareDataWith(*d_output); + } else { + d_table_value->mutable_data(context.GetPlace()); + + d_table->set_height(table_dim[0]); + + auto *d_output_data = d_output->data(); + auto *d_table_data = d_table_value->data(); + + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); + memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); + } } else { auto *ids = context.Input("Ids"); auto *d_output = context.Input(framework::GradVarName("Out")); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 5d0c0b4228d8e2890c8b8d8bd10e0df080251350..55e2ea760158cda631ec07e2c7d318ec1cf79b77 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) +cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling) if(WITH_GPU) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function) @@ -75,6 +76,6 @@ endif() cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_library(jit_kernel - SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc + SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc DEPS cpu_info cblas) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h index 262469beea7449eb5820b86de1ac4f790a833e79..2e75b6abce5e1f43742ee15bff1dac4801186cd4 100644 --- a/paddle/fluid/operators/math/algorithm.h +++ b/paddle/fluid/operators/math/algorithm.h @@ -39,6 +39,52 @@ HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { return -1; } +template +HOSTDEVICE inline size_t LowerBound(const T *x, size_t num, const T &val) { +#ifdef __CUDA_ARCH__ + // The following code is from + // https://en.cppreference.com/w/cpp/algorithm/lower_bound + auto *first = x; + int64_t count = static_cast(num); + while (count > 0) { + int64_t step = (count >> 1); + auto *it = first + step; + if (*it < val) { + first = ++it; + count -= (step + 1); + } else { + count = step; + } + } + return static_cast(first - x); +#else + return static_cast(std::lower_bound(x, x + num, val) - x); +#endif +} + +template +HOSTDEVICE inline size_t UpperBound(const T *x, size_t num, const T &val) { +#ifdef __CUDA_ARCH__ + // The following code is from + // https://en.cppreference.com/w/cpp/algorithm/upper_bound + auto *first = x; + int64_t count = static_cast(num); + while (count > 0) { + auto step = (count >> 1); + auto *it = first + step; + if (val < *it) { + count = step; + } else { + first = ++it; + count -= (step + 1); + } + } + return static_cast(first - x); +#else + return static_cast(std::upper_bound(x, x + num, val) - x); +#endif +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index e91e4e8e5adfdfff8163efe7fc1451bc602504e0..9088d0c7a6307c3fbd9707c719ec9e6f6c85fbdb 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { const T *wp_data = nullptr) const = 0; }; +template +class GRUKernel : public Kernel { + public: + // compute h1 without h0 + virtual void ComputeH1(T *gates, T *ht) const = 0; + virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0; + virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0; +}; + } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/jit_kernel_lstm.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc similarity index 65% rename from paddle/fluid/operators/math/jit_kernel_lstm.cc rename to paddle/fluid/operators/math/jit_kernel_rnn.cc index 26bd26e2e171feea569fbd646a9caf03bebbaa46..fab293f7d03eb923995fa4cd99af955a34faa6a4 100644 --- a/paddle/fluid/operators/math/jit_kernel_lstm.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -136,6 +136,23 @@ static std::shared_ptr> GetActKernel( return nullptr; } +#ifdef __AVX__ +template +static std::unique_ptr GetAVXAct(const std::string& type) { + if (type == "sigmoid") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "relu") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "tanh") { + return std::unique_ptr(new AVXActImpl()); + } else if (type == "identity" || type == "") { + return std::unique_ptr(new AVXActImpl()); + } + PADDLE_THROW("Not support type: %s", type); + return nullptr; +} +#endif + /* LSTM JitKernel */ template class LSTMKernelImpl : public LSTMKernel { @@ -192,61 +209,49 @@ class LSTMKernelImpl : public LSTMKernel { #endif }; -#define INTRI8_FLOAT(isa) \ - template <> \ - LSTMKernelImpl::LSTMKernelImpl( \ - const std::string& act_gate, const std::string& act_cand, \ - const std::string& act_cell, int d) \ - : LSTMKernel() { \ - auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr { \ - if (type == "sigmoid") { \ - return std::unique_ptr(new AVXActImpl()); \ - } else if (type == "relu") { \ - return std::unique_ptr(new AVXActImpl()); \ - } else if (type == "tanh") { \ - return std::unique_ptr(new AVXActImpl()); \ - } else if (type == "identity" || type == "") { \ - return std::unique_ptr(new AVXActImpl()); \ - } \ - PADDLE_THROW("Not support type: %s", type); \ - }; \ - avx_act_gate_ = GetAVXAct(act_gate); \ - avx_act_cand_ = GetAVXAct(act_cand); \ - avx_act_cell_ = GetAVXAct(act_cell); \ - } \ - template <> \ - void LSTMKernelImpl::ComputeCtHt( \ - float* gates, const float* ct_1, float* ct, float* ht, \ - const float* wp_data, float* checked) const { \ - /* gates: W_ch, W_ih, W_fh, W_oh */ \ - __m256 c, i, f, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - f = _mm256_loadu_ps(gates + 16); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ - i = _mm256_loadu_ps(ct_1); \ - f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ - f = _mm256_add_ps(c, f); \ - _mm256_storeu_ps(ct, f); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ - } \ - template <> \ - void LSTMKernelImpl::ComputeC1H1( \ - float* gates, float* ct, float* ht, const float* wp_data) const { \ - __m256 c, i, o; \ - c = _mm256_loadu_ps(gates); \ - i = _mm256_loadu_ps(gates + 8); \ - o = _mm256_loadu_ps(gates + 24); \ - /* C_t = igated * cgated*/ \ - c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ - _mm256_storeu_ps(ct, c); \ - /* H_t = act_cell(C_t) * ogated */ \ - o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \ - _mm256_storeu_ps(ht, o); \ +#define INTRI8_FLOAT(isa) \ + template <> \ + LSTMKernelImpl::LSTMKernelImpl( \ + const std::string& act_gate, const std::string& act_cand, \ + const std::string& act_cell, int d) \ + : LSTMKernel() { \ + avx_act_gate_ = GetAVXAct(act_gate); \ + avx_act_cand_ = GetAVXAct(act_cand); \ + avx_act_cell_ = GetAVXAct(act_cell); \ + } \ + template <> \ + void LSTMKernelImpl::ComputeCtHt( \ + float* gates, const float* ct_1, float* ct, float* ht, \ + const float* wp_data, float* checked) const { \ + /* gates: W_ch, W_ih, W_fh, W_oh */ \ + __m256 c, i, f, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + f = _mm256_loadu_ps(gates + 16); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ + i = _mm256_loadu_ps(ct_1); \ + f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ + f = _mm256_add_ps(c, f); \ + _mm256_storeu_ps(ct, f); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ + } \ + template <> \ + void LSTMKernelImpl::ComputeC1H1( \ + float* gates, float* ct, float* ht, const float* wp_data) const { \ + __m256 c, i, o; \ + c = _mm256_loadu_ps(gates); \ + i = _mm256_loadu_ps(gates + 8); \ + o = _mm256_loadu_ps(gates + 24); \ + /* C_t = igated * cgated*/ \ + c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \ + _mm256_storeu_ps(ct, c); \ + /* H_t = act_cell(C_t) * ogated */ \ + o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \ + _mm256_storeu_ps(ht, o); \ } // TODO(TJ): optimize keq16 @@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, #undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_KEY_LSTM #undef JITKERNEL_NEW_LSTM_IMPL + +/* GRU JitKernel */ +template +class GRUKernelImpl : public GRUKernel { + public: + explicit GRUKernelImpl(const std::string& act_gate, + const std::string& act_state, int d) + : GRUKernel() { + d_ = d; + d2_ = d * 2; + act_gate_d2_ = GetActKernel(act_gate, d2_); + act_gate_d_ = GetActKernel(act_gate, d); + act_state_d_ = GetActKernel(act_state, d); + vmul_d_ = KernelPool::Instance().template Get>(d); + } + + void ComputeH1(T* gates, T* ht) const override { + act_gate_d_->Compute(gates, gates); + act_state_d_->Compute(gates + d2_, gates + d2_); + vmul_d_->Compute(gates, gates + d2_, ht); + } + + void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override { + // W: {W_update, W_reset; W_state} + act_gate_d2_->Compute(gates, gates); + vmul_d_->Compute(ht_1, gates + d_, ht); + } + + void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override { + T* y = gates + d2_; + act_state_d_->Compute(y, y); + // out = zt*ht~ + (1-zt)*ht_1 + for (int i = 0; i < d_; ++i) { + ht[i] = gates[i] * y[i] + (static_cast(1) - gates[i]) * ht_1[i]; + } + } + + private: + int d_, d2_; + std::shared_ptr> act_gate_d2_, act_gate_d_, act_state_d_; + std::shared_ptr> vmul_d_; +#ifdef __AVX__ + std::unique_ptr avx_act_gate_, avx_act_state_; +#endif +}; + +#define INTRI8_FLOAT(isa) \ + template <> \ + GRUKernelImpl::GRUKernelImpl( \ + const std::string& act_gate, const std::string& act_state, int d) \ + : GRUKernel() { \ + avx_act_gate_ = GetAVXAct(act_gate); \ + avx_act_state_ = GetAVXAct(act_state); \ + } \ + template <> \ + void GRUKernelImpl::ComputeH1(float* gates, float* ht) \ + const { \ + __m256 u, s; \ + /* W: {W_update, W_reset; W_state} */ \ + u = _mm256_loadu_ps(gates); \ + s = _mm256_loadu_ps(gates + 16); \ + s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \ + _mm256_storeu_ps(ht, s); \ + } \ + template <> \ + void GRUKernelImpl::ComputeHtPart1( \ + float* gates, const float* ht_1, float* ht) const { \ + /* not exactly equal the any implementation */ \ + __m256 r, ht0; \ + r = _mm256_loadu_ps(gates + 8); \ + ht0 = _mm256_loadu_ps(ht_1); \ + r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \ + _mm256_storeu_ps(ht, r); \ + } \ + template <> \ + void GRUKernelImpl::ComputeHtPart2( \ + float* gates, const float* ht_1, float* ht) const { \ + /* not exactly equal the any implementation */ \ + __m256 u, s, ht0; \ + u = _mm256_loadu_ps(gates); \ + s = _mm256_loadu_ps(gates + 16); \ + ht0 = _mm256_loadu_ps(ht_1); \ + u = avx_act_gate_->Compute(u); \ + s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \ + u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \ + u = _mm256_mul_ps(u, ht0); \ + u = _mm256_add_ps(s, u); \ + _mm256_storeu_ps(ht, u); \ + } + +#ifdef __AVX__ +INTRI8_FLOAT(jit::avx); +#endif +#ifdef __AVX2__ +INTRI8_FLOAT(jit::avx2); +#endif +#ifdef __AVX512F__ +INTRI8_FLOAT(jit::avx512f); +#endif + +#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \ + template <> \ + std::shared_ptr> KernelPool::Get< \ + GRUKernel, const std::string&, const std::string&, int>( \ + const std::string& act_gate, const std::string& act_state, int d) + +#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \ + #ker_key #dtype_key + std::to_string(d) + act_gate + act_state + +#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \ + p = std::dynamic_pointer_cast>( \ + std::make_shared>(act_gate, act_state, d)); + +REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU, + JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL); + +#undef INTRI8_FLOAT +#undef JITKERNEL_NEW_GRU_IMPL +#undef JITKERNEL_KEY_GRU +#undef JITKERNEL_DECLARE_GRU } // namespace jitkernel } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 235b5405fb7d016f4bd8c738f75b303522183116..7be8539a7b0f1890898fd386a3056601fda8a7c3 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -157,6 +157,31 @@ class FirstSeqPoolFunctor { } }; +template +class SumSeqPoolGradFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& out_grad, + framework::LoDTensor* in_grad) { + auto lod = in_grad->lod()[0]; + int64_t out_w = out_grad.numel() / out_grad.dims()[0]; + int64_t in_w = in_grad->numel() / in_grad->dims()[0]; + PADDLE_ENFORCE(in_w == out_w); + const T* out_g_data = out_grad.data(); + T* in_g_data = in_grad->mutable_data(context.GetPlace()); + auto blas = math::GetBlas(context); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + int64_t h = static_cast(lod[i + 1] - lod[i]); + int64_t in_offset = lod[i] * in_w; + const T* out_pos = out_g_data + i * out_w; + T* in_pos = in_g_data + in_offset; + for (int r = 0; r != h; ++r) { + blas.VCOPY(in_w, out_pos, in_pos + r * in_w); + } + } + } +}; + template class SequencePoolFunctor { public: @@ -231,9 +256,15 @@ class SequencePoolGradFunctor { math::SetConstant functor; functor(context, in_grad, 0); } + + if (pooltype == "SUM") { + math::SumSeqPoolGradFunctor sum_pool_grad; + sum_pool_grad(context, out_grad, in_grad); + return; + } + auto lod = in_grad->lod()[0]; auto& place = *context.eigen_device(); - auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { auto in_g_t = in_grad->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -247,12 +278,6 @@ class SequencePoolGradFunctor { if (pooltype == "AVERAGE") { in_g_e.device(place) = (out_g_e / static_cast(h)).broadcast(bcast); - } else if (pooltype == "SUM") { - const T* out_g_data = out_g_t.data(); - T* in_g_data = in_g_t.mutable_data(context.GetPlace()); - for (int r = 0; r != h; ++r) { - blas.VCOPY(w, out_g_data, in_g_data + r * w); - } } else if (pooltype == "SQRT") { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); diff --git a/paddle/fluid/operators/math/sequence_pooling_test.cc b/paddle/fluid/operators/math/sequence_pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bc008dd34ffcfe93a00bd4a8cde61626d91e235 --- /dev/null +++ b/paddle/fluid/operators/math/sequence_pooling_test.cc @@ -0,0 +1,126 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/sequence_pooling.h" +#include +#include + +template +void TestSequencePoolingSum(const paddle::framework::LoD& lod) { + paddle::framework::LoDTensor cpu_out_grad; + paddle::framework::LoDTensor cpu_in_grad; + paddle::framework::LoDTensor out_grad; + paddle::framework::LoDTensor in_grad; + const size_t second_dim = 128u; + + // construct out_grad's tensor in cpu + const size_t out_first_dim = lod[0].size() - 1; + auto out_dims = paddle::framework::make_ddim( + {static_cast(out_first_dim), static_cast(second_dim)}); + + cpu_out_grad.mutable_data(out_dims, paddle::platform::CPUPlace()); + for (int64_t i = 0; i < cpu_out_grad.numel(); ++i) { + cpu_out_grad.data()[i] = static_cast(i); + } + + // copy to dst out_grad + auto* place = new Place(); + DeviceContext* context = new DeviceContext(*place); + if (paddle::platform::is_cpu_place(*place)) { + out_grad = cpu_out_grad; + } else { + TensorCopySync(cpu_out_grad, *place, &out_grad); + } + + // construct in_grad + in_grad.set_lod(lod); + auto in_dims = paddle::framework::make_ddim( + {static_cast(lod[0].back()), static_cast(second_dim)}); + in_grad.mutable_data(in_dims, context->GetPlace()); + + // check tensor contruction result + PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size()); + for (int64_t i = 1; i < out_grad.dims().size(); ++i) { + PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]); + } + + // call functor + paddle::operators::math::SequencePoolGradFunctor()( + *context, "SUM", out_grad, &in_grad); + + if (paddle::platform::is_cpu_place(*place)) { + cpu_in_grad = in_grad; + } else { + TensorCopySync(in_grad, paddle::platform::CPUPlace(), &cpu_in_grad); + cpu_in_grad.set_lod(in_grad.lod()); + } + + EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim); + EXPECT_EQ(in_grad.lod(), lod); + + if (paddle::platform::is_cpu_place(*place)) { + for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) { + int64_t begin = in_grad.lod()[0][i]; + int64_t end = in_grad.lod()[0][i + 1]; + paddle::framework::Tensor tmp = in_grad.Slice(begin, end); + for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) { + for (int64_t m = 0; m != second_dim; ++m) { + EXPECT_EQ(tmp.data()[m + j * second_dim], + out_grad.data()[m + i * second_dim]); + } + } + } + } else { + for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) { + int64_t begin = cpu_in_grad.lod()[0][i]; + int64_t end = cpu_in_grad.lod()[0][i + 1]; + paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end); + for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) { + for (int64_t m = 0; m != second_dim; ++m) { + EXPECT_EQ(tmp.data()[m + j * second_dim], + cpu_out_grad.data()[m + i * second_dim]); + } + } + } + } + + delete place; + delete context; +} + +TEST(SequencePoolingGrad, CPU_SUM) { + paddle::framework::LoD lod1; + lod1.push_back(std::vector{0, 10}); + TestSequencePoolingSum(lod1); + + paddle::framework::LoD lod2; + lod2.push_back(std::vector{0, 2, 7, 10}); + TestSequencePoolingSum(lod2); +} + +#ifdef PADDLE_WITH_CUDA +TEST(SequencePoolingGrad, CUDA_SUM) { + paddle::framework::LoD lod1; + lod1.push_back(std::vector{0, 10}); + TestSequencePoolingSum(lod1); + + paddle::framework::LoD lod2; + lod2.push_back(std::vector{0, 2, 7, 10}); + TestSequencePoolingSum(lod2); +} +#endif diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 9e0bebd17c02a3ce010b77142757b8789cfbcdd9..19426b3c204095bd415cebcd87cff18468acd564 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of mean op"); - AddOutput("Out", "(Tensor) The output of mean op").Reuse("X"); + AddOutput("Out", "(Tensor) The output of mean op"); AddComment(R"DOC( Mean Operator calculates the mean of all elements in X. diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index 12b916fcebd425bd4a03d920f947829098a924a1..7f0b51580aa2591ac7338ad7c29ee4756d909925 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -19,54 +19,6 @@ namespace operators { using Tensor = framework::Tensor; -class MomentumOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(param) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(grad) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Velocity"), - "Input(velocity) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of Momentum should not be null."); - PADDLE_ENFORCE( - ctx->GetInputsVarType("Param").front() == - framework::proto::VarType::LOD_TENSOR, - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); - - PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), - "Output(ParamOut) of Momentum should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"), - "Output(VelocityOut) of Momentum should not be null."); - - auto param_dim = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - "Param and Grad input of MomentumOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Velocity"), - "Param and Velocity of MomentumOp should have the same dimension."); - } - PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, - "Learning_rate should be a scalar"); - - ctx->SetOutputDim("ParamOut", param_dim); - ctx->SetOutputDim("VelocityOut", param_dim); - } - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); - } -}; - class MomentumOpInferVarType : public framework::VarTypeInference { public: void operator()(const framework::OpDesc& op_desc, diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 6b4d00f56ca06c402c07ecf770a390e88ae3edf1..71f079e4d97f5259359ee6572f584894551452ca 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -28,6 +28,54 @@ using framework::SelectedRows; struct NoNesterov; struct UseNesterov; +class MomentumOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(param) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(grad) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Velocity"), + "Input(velocity) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of Momentum should not be null."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("Param").front() == + framework::proto::VarType::LOD_TENSOR, + "The input var's type should be LoDTensor, but the received is %s", + ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front()); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of Momentum should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"), + "Output(VelocityOut) of Momentum should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad input of MomentumOp should have the same dimension."); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Velocity"), + "Param and Velocity of MomentumOp should have the same dimension."); + } + PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1, + "Learning_rate should be a scalar"); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("VelocityOut", param_dim); + } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + template class CPUDenseMomentumFunctor { private: diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index f8ad63690e84339da0390d4ddd2db45f25db385a..24a5346b031008531fcefff0e6f1c31da33d1c3b 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -151,8 +151,7 @@ void Pool2dOpMaker::Make() { "The format of output tensor is also NCHW, " "where N is batch size, C is the number of channels, " "H is the height of the feature, " - "and W is the width of the feature.") - .Reuse("X"); + "and W is the width of the feature."); AddAttr("pooling_type", "(string), pooling type, can be \"max\" for max-pooling " @@ -252,8 +251,7 @@ void Pool3dOpMaker::Make() { "The format of output tensor is also NCDHW, " "where N is batch size, C is " "the number of channels, and D, H and W is the depth, height and " - "width of the feature, respectively.") - .Reuse("X"); + "width of the feature, respectively."); AddAttr("pooling_type", "(string) Pooling type, can be \"max\" for max-pooling " diff --git a/paddle/fluid/operators/reader/reader_blocking_queue_test.cc b/paddle/fluid/operators/reader/reader_blocking_queue_test.cc index 8cd505806056f1af33712e2c92b7661d87485708..dc0940ac0b78d295b5088cb6ae26300da1dc883d 100644 --- a/paddle/fluid/operators/reader/reader_blocking_queue_test.cc +++ b/paddle/fluid/operators/reader/reader_blocking_queue_test.cc @@ -237,7 +237,7 @@ TEST(BlockingQueue, speed_test_mode) { } for (size_t i = 0; i < queue_size; ++i) { q2.Receive(&b); - EXPECT_EQ(b, 0); + EXPECT_EQ(b, 0UL); } EXPECT_EQ(q2.Size(), queue_size); } diff --git a/paddle/fluid/operators/sequence_reverse_op.cc b/paddle/fluid/operators/sequence_reverse_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1428cca1a6bf6150594f9cb72dbf00cd0eff7df5 --- /dev/null +++ b/paddle/fluid/operators/sequence_reverse_op.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_reverse_op.h" + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(sequence_reverse, ops::SequenceReverseOp, + ops::SequenceReverseOpMaker, + ops::SequenceReverseGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL( + sequence_reverse, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel); diff --git a/paddle/fluid/operators/sequence_reverse_op.cu b/paddle/fluid/operators/sequence_reverse_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ce65f4799e8661adca60d212eaa9c3f0f92c4c29 --- /dev/null +++ b/paddle/fluid/operators/sequence_reverse_op.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/sequence_reverse_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + sequence_reverse, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel, + ops::SequenceReverseOpKernel); diff --git a/paddle/fluid/operators/sequence_reverse_op.h b/paddle/fluid/operators/sequence_reverse_op.h new file mode 100644 index 0000000000000000000000000000000000000000..39dad2311b2bcf29f808723caf7bfaef4c88cef2 --- /dev/null +++ b/paddle/fluid/operators/sequence_reverse_op.h @@ -0,0 +1,157 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +class SequenceReverseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); + + auto x_dim = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dim.size(), 2, + "Rank of Input(X) must be not less than 2."); + + ctx->SetOutputDim("Y", x_dim); + ctx->ShareLoD("X", "Y"); + } +}; + +class SequenceReverseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input LoDTensor of sequence_reverse op."); + AddOutput("Y", "The output LoDTensor of sequence_reverse op."); + AddComment(R"DOC( +SequenceReverse Operator. + +Reverse each sequence in input X along dim 0. + +Assuming X is a LoDTensor with dims [5, 4] and lod [[0, 2, 5]], where: + +X.data() = [ + [1, 2, 3, 4], + [5, 6, 7, 8], # the 0-th sequence with length 2 + [9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20] # the 1-st sequence with length 3 +] + +The output Y would be a LoDTensor sharing the same dims and lod with input X, +and: + +Y.data() = [ + [5, 6, 7, 8], + [1, 2, 3, 4], # the reversed 0-th sequence with length 2 + [17, 18, 19, 20], + [13, 14, 15, 16], + [9, 10, 11, 12] # the reversed 1-st sequence with length 3 +] + +This Operator is useful to build a reverse dynamic RNN network. + +This Operator only supports one-level lod currently. + )DOC"); + } +}; + +template +struct SequenceReverseFunctor { + SequenceReverseFunctor(const T *x, T *y, const size_t *lod, size_t lod_count, + size_t row_numel) + : x_(x), y_(y), lod_(lod), lod_count_(lod_count), row_numel_(row_numel) {} + + HOSTDEVICE void operator()(size_t idx_x) const { + auto row_idx_x = idx_x / row_numel_; + auto lod_idx = math::UpperBound(lod_, lod_count_, row_idx_x); + auto row_idx_y = lod_[lod_idx - 1] + (lod_[lod_idx] - 1 - row_idx_x); + auto idx_y = row_idx_y * row_numel_ + idx_x % row_numel_; + y_[idx_y] = x_[idx_x]; + } + + const T *x_; + T *y_; + const size_t *lod_; + size_t lod_count_; + size_t row_numel_; +}; + +template +class SequenceReverseOpKernel : public framework::OpKernel { + using LoDTensor = framework::LoDTensor; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &x = *ctx.Input("X"); + auto *y = ctx.Output("Y"); + + PADDLE_ENFORCE_EQ(x.lod().size(), 1, + "SequenceReverse Op only support one level lod."); + + auto &dev_ctx = ctx.template device_context(); + const size_t *lod; + size_t lod_count = x.lod()[0].size(); + +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + lod = x.lod()[0].CUDAData(ctx.GetPlace()); + } else { +#endif + lod = x.lod()[0].data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + + size_t limit = static_cast(x.numel()); + size_t row_numel = static_cast(limit / x.dims()[0]); + auto *x_data = x.data(); + auto *y_data = y->mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_NE(x_data, y_data, + "SequenceReverse Op does not support in-place operation"); + + SequenceReverseFunctor functor(x_data, y_data, lod, lod_count, + row_numel); + platform::ForRange for_range(dev_ctx, limit); + for_range(functor); + } +}; + +class SequenceReverseGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("sequence_reverse"); + op->SetInput("X", OutputGrad("Y")); + op->SetOutput("Y", InputGrad("X")); + op->SetAttrMap(Attrs()); + return op; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/sgd_op.cc index 411a126bc8e2b3a8d25f436489c13970568ccae4..ea62acd08c5009556abf05c91726111870d1a462 100644 --- a/paddle/fluid/operators/sgd_op.cc +++ b/paddle/fluid/operators/sgd_op.cc @@ -77,8 +77,7 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); AddOutput("ParamOut", "(Tensor or SelectedRows, same with Param) " - "Output parameter, should share the same memory with Param") - .Reuse("Param"); + "Output parameter, should share the same memory with Param"); AddComment(R"DOC( SGD operator diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 2bdb23e999621b10799b5163f326bc4b66a437e6..f6e241af0634650f4a32be6a4547617f8ec3ee60 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -76,6 +76,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace, ops::SoftmaxCUDNNKernel, + ops::SoftmaxCUDNNKernel, ops::SoftmaxCUDNNKernel); REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace, - ops::SoftmaxGradCUDNNKernel); + ops::SoftmaxGradCUDNNKernel, + ops::SoftmaxGradCUDNNKernel); diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index bb081238820b9ee3ae095442d21cfce11f7b41e5..a4bdbe6648afa7c91a056af4737bb5d826229022 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -80,8 +80,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of softmax, " "whose last dimension is the input_feature_dimensions."); - AddOutput("Out", "The normalized values with the same shape as X.") - .Reuse("X"); + AddOutput("Out", "The normalized values with the same shape as X."); AddAttr( "use_cudnn", "(bool, default false) Only used in cudnn kernel, need install cudnn") diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index fe7c7039c7dec714e265ede1b7167fd800ddc2f7..34dbac2ab8dcc9bd2b91e2daa2f42806057f5f56 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -132,7 +132,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "(vector) The input tensors of sum operator.") .AsDuplicable(); - AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X"); + AddOutput("Out", "(Tensor) The output tensor of sum operator."); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 4a8ac441cfaf642fde58ee30865a22e83c065498..c17d1afc309c65035063348d4934ea1783b018ed 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of Topk op"); - AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X"); + AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddComment(R"DOC( Top K operator diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 8e4a07556fb51dbb15ef948fcee120e2f68e089a..0cad224ca8860b0e4bc2e3f2bc1659235aadfe2d 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, const T* src, int lds, int dim, int k, int grid_dim, int num) { __shared__ Pair sh_topk[BlockSize]; - __shared__ int maxid[BlockSize / 2]; const int tid = threadIdx.x; const int warp = threadIdx.x / 32; const int bid = blockIdx.x; for (int i = bid; i < num; i += grid_dim) { - output += i * output_stride; - indices += i * k; - + int top_num = k; + __shared__ int maxid[BlockSize / 2]; + T* out = output + i * output_stride; + int64_t* inds = indices + i * k; Pair topk[MaxLength]; int beam = MaxLength; Pair max; bool is_empty = false; bool firststep = true; - for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + for (int j = 0; j < MaxLength; j++) { + topk[j].set(-INFINITY, -1); } - while (k) { + while (top_num) { ThreadGetTopK( topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); sh_topk[tid] = topk[0]; - BlockReduce(sh_topk, maxid, topk, &output, - &indices, &beam, &k, tid, warp); + BlockReduce(sh_topk, maxid, topk, &out, &inds, + &beam, &top_num, tid, warp); } } } @@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel { size_t k = static_cast(ctx.Attr("k")); const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); // FIXME(typhoonzero): data is always converted to type T? int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - size_t input_height = input->dims()[0]; - size_t input_width = input->dims()[1]; + framework::DDim inputdims = input->dims(); + const size_t input_height = framework::product( + framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); + const size_t input_width = inputdims[inputdims.size() - 1]; + if (k > input_width) k = input_width; // NOTE: pass lds and dim same to input width. @@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel { const int kMaxHeight = 2048; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; auto& dev_ctx = ctx.cuda_device_context(); - switch (GetDesiredBlockDim(input_width)) { FIXED_BLOCK_DIM( KeMatrixTopK<<>>( - output_data, output->dims()[1], indices_data, input_data, - input_width, input_width, static_cast(k), gridx, - input_height)); + output_data, k, indices_data, input_data, input_width, + input_width, static_cast(k), gridx, input_height)); default: PADDLE_THROW("Error"); } diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index 054dd481994d03f71b0ed5dc73e103085f6c91aa..76ece57b39919148da04caecaa43ea9d2b9d95df 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { // Get the top k elements of each row of input tensor - // FIXME: only deal with matrix(2d tensor). auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); @@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel { T* output_data = output->mutable_data(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - auto eg_input = EigenMatrix::From(*input); - // reshape input to a flattern matrix(like flat_inner_dims) framework::DDim inputdims = input->dims(); const size_t row = framework::product( @@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel { const size_t col = inputdims[inputdims.size() - 1]; Eigen::DSizes flat2dims(row, col); // NOTE: eigen shape doesn't affect paddle tensor. - eg_input.reshape(flat2dims); + auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 6a9fc6611a8f8eaa6749aefac0673ccabaebbcfe..bbd71db6062107f6ba40343c84d942b54b3958e6 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -210,18 +210,21 @@ REGISTER_OPERATOR(transpose, ops::TransposeOp, ops::TransposeOpMaker, REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad); REGISTER_OP_CPU_KERNEL( - transpose, ops::TransposeKernel); + transpose, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CPU_KERNEL( transpose_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker, ops::Transpose2GradMaker); REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad); REGISTER_OP_CPU_KERNEL( - transpose2, - ops::TransposeKernel); + transpose2, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CPU_KERNEL( transpose2_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/paddle/fluid/operators/transpose_op.cu.cc b/paddle/fluid/operators/transpose_op.cu.cc index c1b5a8b31be243fab3af06a18c8e51986c953700..b4025350fa9f3610bde43eee91cd059f3063813f 100644 --- a/paddle/fluid/operators/transpose_op.cu.cc +++ b/paddle/fluid/operators/transpose_op.cu.cc @@ -16,15 +16,18 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - transpose, - ops::TransposeKernel); + transpose, ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); REGISTER_OP_CUDA_KERNEL( transpose2, - ops::TransposeKernel); + ops::TransposeKernel, + ops::TransposeKernel); REGISTER_OP_CUDA_KERNEL( transpose2_grad, - ops::TransposeGradKernel); + ops::TransposeGradKernel, + ops::TransposeGradKernel); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7d1cf57253819b34fedfb292ad1635650f53f20f..b0de636de46451c8b05546fdbff142f984c2bb43 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } #ifdef PADDLE_WITH_MKLDNN MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) - : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() { - p_blobs_.reset(new std::unordered_map>()); + : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() { + p_blobmap_.reset(new BlobMap()); + p_mutex_.reset(new std::mutex()); } +namespace { +// Current thread's id. +thread_local int cur_thread_id = 0; +} + +void set_cur_thread_id(int tid) { cur_thread_id = tid; } +int get_cur_thread_id(void) { return cur_thread_id; } + void MKLDNNDeviceContext::SetBlob(const std::string& name, std::shared_ptr data) const { - std::unordered_map>* p; - p = p_blobs_.get(); + BlobMap* pMap = p_blobmap_.get(); + std::shared_ptr pBlob = nullptr; + + int tid = platform::get_cur_thread_id(); - auto it = p->find(name); + std::lock_guard lock(*p_mutex_.get()); - if (it == p->end()) { - (*p)[name] = data; // create new blob + // Find KeyBlob for current thread + auto map_it = pMap->find(tid); + + if (map_it == pMap->end()) { + // 1st time to set blob in current thread + pBlob = std::shared_ptr(new KeyBlob()); + (*pMap)[tid] = pBlob; } else { - it->second = data; // set data to existing blob + pBlob = map_it->second; } + // Find Key in found (or newly created) KeyBlob + auto key_it = pBlob->find(name); + + if (key_it == pBlob->end()) { + (*pBlob)[name] = data; // create new blob + } else { + key_it->second = data; // set data to existing blob + } + + // lock will be automatically released when out of scope return; } std::shared_ptr MKLDNNDeviceContext::GetBlob( const std::string& name) const { - std::unordered_map>* p; - p = p_blobs_.get(); + BlobMap* pMap = p_blobmap_.get(); + std::shared_ptr pBlob = nullptr; - auto it = p->find(name); + int tid = platform::get_cur_thread_id(); - if (it != p->end()) { - return it->second; - } + std::lock_guard lock(*p_mutex_.get()); + + // Find KeyBlob for current thread firstly + auto map_it = pMap->find(tid); + if (map_it == pMap->end()) return nullptr; + pBlob = map_it->second; + + // Find Blob via name + auto key_it = pBlob->find(name); + + if (key_it == pBlob->end()) return nullptr; - return nullptr; + // lock will be automatically released when out of scope + return key_it->second; } #endif diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 999bbe00f1659881050cb0dc89570b74b201aca7..942e13a724339dc85ed1fc72c11e208ddce36dbb 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -176,6 +176,12 @@ struct DefaultDeviceContextType { #endif #ifdef PADDLE_WITH_MKLDNN +using KeyBlob = std::unordered_map>; +using BlobMap = std::unordered_map>; + +void set_cur_thread_id(int); +int get_cur_thread_id(void); + class MKLDNNDeviceContext : public CPUDeviceContext { public: explicit MKLDNNDeviceContext(CPUPlace place); @@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext { private: mkldnn::engine engine_; - std::shared_ptr>> - p_blobs_; + std::shared_ptr p_blobmap_; + std::shared_ptr p_mutex_; }; #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 339a7c98c6a2bba2cd46790cecc169ef447c63ce..5f15a29f4c3e9b1412912fe4723642d1ede60346 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -645,9 +645,13 @@ All parameter, weight, gradient are variables in Paddle. py::class_> pass(m, "Pass"); pass.def(py::init()) - .def("set_str", [](ir::Pass &self, const std::string &name, - const std::string &attr) { - self.Set(name, new std::string(attr)); + .def( + "set_str", + [](ir::Pass &self, const std::string &name, const std::string &attr) { + self.Set(name, new std::string(attr)); + }) + .def("set_int", [](ir::Pass &self, const std::string &name, int val) { + self.Set(name, new int(val)); }); py::class_> pb( diff --git a/paddle/fluid/train/demo/CMakeLists.txt b/paddle/fluid/train/demo/CMakeLists.txt index 78d6e5ff554b9cd9facae85be166a697e0b75337..eabb51d370aff709e289e1fc727aa2dbb551d82e 100644 --- a/paddle/fluid/train/demo/CMakeLists.txt +++ b/paddle/fluid/train/demo/CMakeLists.txt @@ -15,6 +15,7 @@ include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") include_directories("${PADDLE_LIB}/third_party/install/gflags/include") +include_directories("${PADDLE_LIB}/third_party/install/xxhash/include") include_directories("${PADDLE_LIB}/third_party/install/snappy/include") include_directories("${PADDLE_LIB}/third_party/install/snappystream/include") include_directories("${PADDLE_LIB}/third_party/install/zlib/include") @@ -27,6 +28,7 @@ link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib") link_directories("${PADDLE_LIB}/third_party/install/glog/lib") link_directories("${PADDLE_LIB}/third_party/install/gflags/lib") +link_directories("${PADDLE_LIB}/third_party/install/xxhash/lib") link_directories("${PADDLE_LIB}/third_party/install/zlib/lib") add_executable(demo_trainer demo_trainer.cc) @@ -62,5 +64,5 @@ target_link_libraries(demo_trainer ${ARCHIVE_END} ${MATH_LIB} ${MKLDNN_LIB} - glog gflags protobuf snappystream snappy z + glog gflags protobuf snappystream snappy z xxhash ${EXTERNAL_LIB}) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 85493c10549c290330ed09b9f28accb7a980de6a..5a71382fb14b64989502c34d8ac0aa13c62bc7d0 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -95,9 +95,9 @@ function cmake_gen() { exit 1 fi fi - else + else if [ "$1" != "" ]; then - echo "using python abi: $1" + echo "using python abi: $1" if [ "$1" == "cp27-cp27m" ]; then export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:} export PATH=/opt/python/cp27-cp27m/bin/:${PATH} @@ -119,7 +119,7 @@ function cmake_gen() { fi fi fi - + if [ "$SYSTEM" == "Darwin" ]; then WITH_DISTRIBUTE=${WITH_DISTRIBUTE:-ON} WITH_AVX=${WITH_AVX:-ON} @@ -127,7 +127,7 @@ function cmake_gen() { else INFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR:-/root/.cache/inference_demo} fi - + cat <>> with program._optimized_guard([p,g]): >>> p = p - 0.001 * g """ + tmp_role = self._current_role + tmp_var = self._op_role_var + OpRole = core.op_proto_and_checker_maker.OpRole self._current_role = OpRole.Optimize self._op_role_var = [ @@ -1503,11 +1506,11 @@ class Program(object): for var in param_and_grads ] yield - self._op_role_var = [] - self._current_role = OpRole.Forward + self._op_role_var = tmp_var + self._current_role = tmp_role @contextlib.contextmanager - def _lr_schedule_guard(self): + def _lr_schedule_guard(self, is_with_opt=False): """ A with guard to set :code:`LRSched` :code:`OpRole` and :code:`OpRoleVar` automatically. The :code:`OpRoleVar` is @@ -1515,6 +1518,10 @@ class Program(object): Notes: This is a very low level API. Users should not use it directly. + Args: + is_with_opt: Only set to true if these ops a in the middle + of a bunch of optimize ops so that it can be treated + correctly. For example, sgd->lr_op->sgd->lr_op->sgd. Examples: @@ -1528,6 +1535,8 @@ class Program(object): OpRole = core.op_proto_and_checker_maker.OpRole self._current_role = OpRole.LRSched + if is_with_opt: + self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize) # TODO(typhoonzero): how to set target learning rate var self._op_role_var = [] yield diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index b94b59631a3d8999f569acf1027c71d3019f5c56..ece22d0b7ed4cac6618c7be14939c770bcf1176d 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred, Returns: tuple: A tuple(predicted_scores, predicted_location, target_label, - target_bbox) is returned. The predicted_scores and - predicted_location is the predicted result of the RPN. + target_bbox, bbox_inside_weight) is returned. The predicted_scores + and predicted_location is the predicted result of the RPN. The target_label and target_bbox is the ground truth, respectively. The predicted_location is a 2D Tensor with shape [F, 4], and the shape of target_bbox is same as the shape of @@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred, [F + B, 1], and the shape of target_label is same as the shape of the predicted_scores, B is the number of the background anchors, the F and B is depends on the input of this operator. + Bbox_inside_weight represents whether the predicted loc is fake_fg + or not and the shape is [F, 4]. Examples: .. code-block:: python @@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred, append_batch_size=False, dtype='float32') gt_boxes = layers.data(name='gt_boxes', shape=[10, 4], append_batch_size=False, dtype='float32') - loc_pred, score_pred, loc_target, score_target = + loc_pred, score_pred, loc_target, score_target, bbox_inside_weight = fluid.layers.rpn_target_assign(bbox_pred=bbox_pred, cls_logits=cls_logits, anchor_box=anchor_box, @@ -152,6 +154,8 @@ def rpn_target_assign(bbox_pred, target_label = helper.create_variable_for_type_inference(dtype='int32') target_bbox = helper.create_variable_for_type_inference( dtype=anchor_box.dtype) + bbox_inside_weight = helper.create_variable_for_type_inference( + dtype=anchor_box.dtype) helper.append_op( type="rpn_target_assign", inputs={ @@ -164,7 +168,8 @@ def rpn_target_assign(bbox_pred, 'LocationIndex': loc_index, 'ScoreIndex': score_index, 'TargetLabel': target_label, - 'TargetBBox': target_bbox + 'TargetBBox': target_bbox, + 'BBoxInsideWeight': bbox_inside_weight }, attrs={ 'rpn_batch_size_per_im': rpn_batch_size_per_im, @@ -179,13 +184,14 @@ def rpn_target_assign(bbox_pred, score_index.stop_gradient = True target_label.stop_gradient = True target_bbox.stop_gradient = True + bbox_inside_weight.stop_gradient = True cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1)) bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4)) predicted_cls_logits = nn.gather(cls_logits, score_index) predicted_bbox_pred = nn.gather(bbox_pred, loc_index) - return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox + return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight def detection_output(loc, diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py index dfd801a098d6451dbdb20d9ba44187d1e3f8a91a..149224bb68ac869dec14ac9f953f0072bd24c7e2 100644 --- a/python/paddle/fluid/layers/learning_rate_scheduler.py +++ b/python/paddle/fluid/layers/learning_rate_scheduler.py @@ -27,7 +27,7 @@ from . import nn from . import ops from . import tensor from ..initializer import init_on_cpu -from ..framework import default_main_program, Parameter, unique_name +from ..framework import default_main_program, Parameter, unique_name, name_scope __all__ = [ 'exponential_decay', 'natural_exp_decay', 'inverse_time_decay', @@ -332,14 +332,16 @@ def append_LARS(params_grads, learning_rate, weight_decay): return grad_norm + weight_decay * param_norm for param, grad in params_grads: - param_lr = param.optimize_attr['learning_rate'] - param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param))) - grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad))) - if type(param_lr) == float and param_lr == 1.0: - decayed_lr = learning_rate * param_norm \ - / _balanced_weight(param_norm, grad_norm) - else: - decayed_lr = learning_rate * param_lr * param_norm \ - / _balanced_weight(param_norm, grad_norm) - # set back param local learning rate - param.optimize_attr['learning_rate'] = decayed_lr + with param.block.program.optimized_guard( + [param, grad]), name_scope("optimizer"): + param_lr = param.optimize_attr['learning_rate'] + param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param))) + grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad))) + if type(param_lr) == float and param_lr == 1.0: + decayed_lr = learning_rate * param_norm \ + / _balanced_weight(param_norm, grad_norm) + else: + decayed_lr = learning_rate * param_lr * param_norm \ + / _balanced_weight(param_norm, grad_norm) + # set back param local learning rate + param.optimize_attr['learning_rate'] = decayed_lr diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index cca618b9ad2fef9bf4870f0f94d17fbc529fb83c..4bfa89d9facf1d368e3018a248dc090c81c3402e 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -154,7 +154,9 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', + 'sequence_reverse', 'affine_channel', + 'hash', ] @@ -980,7 +982,12 @@ def cos_sim(X, Y): return out -def dropout(x, dropout_prob, is_test=False, seed=None, name=None): +def dropout(x, + dropout_prob, + is_test=False, + seed=None, + name=None, + dropout_implementation="downgrade_in_infer"): """ Computes dropout. @@ -1000,6 +1007,21 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): units will be dropped. DO NOT use a fixed seed in training. name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. + dropout_implementation(string): ['downgrade_in_infer'(defauld)|'upscale_in_train'] + 1. downgrade_in_infer(default), downgrade the outcome at inference + train: out = input * mask + inference: out = input * dropout_prob + (make is a tensor same shape with input, value is 0 or 1 + ratio of 0 is dropout_prob) + 2. upscale_in_train, upscale the outcome at training time + train: out = input * mask / ( 1.0 - dropout_prob ) + inference: out = input + (make is a tensor same shape with input, value is 0 or 1 + ratio of 0 is dropout_prob) + dropout op can be removed from the program. + the program will be efficient + + Returns: Variable: A tensor variable is the shape with `x`. @@ -1029,7 +1051,8 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): 'dropout_prob': dropout_prob, 'is_test': is_test, 'fix_seed': seed is not None, - 'seed': seed if seed is not None else 0 + 'seed': seed if seed is not None else 0, + 'dropout_implementation': dropout_implementation, }) return out @@ -1969,17 +1992,17 @@ def sequence_slice(input, offset, length, name=None): """ **Sequence Slice Layer** - The layer crops a subsequence from given sequence with given start + The layer crops a subsequence from given sequence with given start offset and subsequence length. It only supports sequence data (LoDTensor with lod_level equal to 1). .. code-block:: text - + - Case: Given the input Variable **input**: - + input.data = [[a1, a2], [b1, b2], [c1, c2], [d1, d2], [e1, e2]], input.lod = [[3, 2]], input.dims = (5, 2), @@ -1987,16 +2010,16 @@ def sequence_slice(input, offset, length, name=None): with offset.data = [[0], [1]] and length.data = [[2], [1]], the output Variable will be - + out.data = [[a1, a2], [b1, b2], [e1, e2]], out.lod = [[2, 1]], out.dims = (3, 2). - - NOTE: The first dimension size of **input**, **offset** and **length** + + NOTE: The first dimension size of **input**, **offset** and **length** should be equal. The **offset** should start from 0. - + Args: - input(Variable): The input Variable which consists of the complete + input(Variable): The input Variable which consists of the complete sequences. offset(Variable): The offset to slice each sequence. length(Variable): The length of each subsequence. @@ -2015,7 +2038,7 @@ def sequence_slice(input, offset, length, name=None): dtype='float32', lod_level=1) offset = fluid.layers.assign(input=np.array([[0, 1]]).astype("int32")) length = fluid.layers.assign(input=np.array([[2, 1]]).astype("int32")) - subseqs = fluid.layers.sequence_slice(input=seqs, offset=offset, + subseqs = fluid.layers.sequence_slice(input=seqs, offset=offset, length=length) """ helper = LayerHelper("sequence_slice", **locals()) @@ -2398,12 +2421,12 @@ def layer_norm(input, param_attr(ParamAttr|None): The parameter attribute for the learnable gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is omitted. If :attr:`scale` is True and :attr:`param_attr` is None, - a default :code:`ParamAttr` would be added as scale. The - :attr:`param_attr` is initialized as 1 if it is added. Default None. + a default :code:`ParamAttr` would be added as scale. The + :attr:`param_attr` is initialized as 1 if it is added. Default None. bias_attr(ParamAttr|None): The parameter attribute for the learnable bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is omitted. If :attr:`shift` is True and :attr:`param_attr` is None, - a default :code:`ParamAttr` would be added as bias. The + a default :code:`ParamAttr` would be added as bias. The :attr:`bias_attr` is initialized as 0 if it is added. Default None. act(str): Activation to be applied to the output of layer normalizaiton. Default None. @@ -3021,8 +3044,8 @@ def sequence_unpad(x, length, name=None): """ **Sequence Unpad Layer** - This layer removes the padding data in the input sequences and convert - them into sequences with actual length as output, identitied by lod + This layer removes the padding data in the input sequences and convert + them into sequences with actual length as output, identitied by lod information. .. code-block:: text @@ -3032,9 +3055,9 @@ def sequence_unpad(x, length, name=None): Given input Variable **x**: x.data = [[ 1.0, 2.0, 3.0, 4.0, 5.0], [ 6.0, 7.0, 8.0, 9.0, 10.0], - [11.0, 12.0, 13.0, 14.0, 15.0]], - - in which there are 3 sequences padded to length 5, and the acutal length + [11.0, 12.0, 13.0, 14.0, 15.0]], + + in which there are 3 sequences padded to length 5, and the acutal length specified by input Variable **length**: length.data = [[2], [3], [4]], @@ -3042,7 +3065,7 @@ def sequence_unpad(x, length, name=None): after unpadding, the output Variable will be: out.data = [[1.0, 2.0, 6.0, 7.0, 8.0, 11.0, 12.0, 13.0, 14.0]] - out.lod = [[2, 3, 4]] + out.lod = [[2, 3, 4]] Args: x(Variable): Input Variable which contains the padded sequences with @@ -4844,7 +4867,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): return counter -def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): +def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): """ Gives a new shape to the input Tensor without changing its data. @@ -4892,15 +4915,22 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): :attr:`shape` specifying shape. That is to say :attr:`actual_shape` has a higher priority than :attr:`shape`. - act (str): The non-linear activation to be applied to output variable. - inplace(bool): If this flag is set true, the output - shares data with input without copying, otherwise - a new output tensor is created - whose data is copied from input x. + act (str): The non-linear activation to be applied to the reshaped tensor + variable. + inplace(bool): Must use :attr:`False` if :attr:`x` is used in multiple + operators. If this flag is set :attr:`True`, reuse input + :attr:`x` to reshape, which will change the shape of + tensor variable :attr:`x` and might cause errors when + :attr:`x` is used in multiple operators. If :attr:`False`, + preserve the shape :attr:`x` and create a new output tensor + variable whose data is copied from input x but reshaped. name (str): The name of this layer. It is optional. Returns: - Variable: The output tensor. + Variable: The reshaped tensor variable if :attr:`act` is None. It is a \ + new tensor variable if :attr:`inplace` is :attr:`False`, \ + otherwise it is :attr:`x`. If :attr:`act` is not None, return \ + the activated tensor variable. Raises: TypeError: if actual_shape is neither Variable nor None. @@ -4911,7 +4941,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): data = fluid.layers.data( name='data', shape=[2, 4, 6], dtype='float32') reshaped = fluid.layers.reshape( - x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True) + x=data, shape=[-1, 0, 3, 2], inplace=True) """ if not (isinstance(shape, list) or isinstance(shape, tuple)): @@ -4938,7 +4968,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): "except one unknown dimension.") helper = LayerHelper("reshape2", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) + out = x if inplace else helper.create_variable_for_type_inference( + dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type="reshape2", @@ -5469,9 +5500,9 @@ def roi_align(input, Examples: .. code-block:: python - align_out = fluid.layers.roi_align(input=x, - rois=rois, - pooled_height=7, + align_out = fluid.layers.roi_align(input=x, + rois=rois, + pooled_height=7, pooled_width=7, spatial_scale=0.5, sampling_ratio=-1) @@ -7455,13 +7486,40 @@ def maxout(x, groups, name=None): return out +@templatedoc() +def sequence_reverse(x, name=None): + """ + ${comment} + + Args: + x(${x_type}): ${x_comment} + name(basestring|None): Name of the output. + + Returns: + out(${y_type}): ${y_comment} + """ + helper = LayerHelper("sequence_reverse", **locals()) + if name is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type="sequence_reverse", + inputs={"X": x}, + outputs={"Y": out}, + attrs=dict()) + return out + + def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): """ Applies a separate affine transformation to each channel of the input. Useful for replacing spatial batch norm with its equivalent fixed transformation. The input also can be 2D tensor and applies a affine transformation in second dimension. - + Args: x (Variable): Feature map input can be a 4D tensor with order NCHW or NHWC. It also can be a 2D tensor and the affine transformation @@ -7494,3 +7552,31 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): attrs={"data_layout": data_layout}, outputs={"Out": out}) return out + + +def hash(input, hash_size, num_hash=1, name=None): + """ + hash the input + Args: + input (Variable): The input variable which is a one-hot word. + hash_size (int): The space size for hash algorithm. + num_hash (int): The times of hash, default 1. + name (str, default None): The name of this layer. + Returns: + Variable: The hash result variable which is a LoDTensor. + Examples: + .. code-block:: python + word_dict = paddle.dataset.imdb.word_dict() + x = fluid.layers.data(shape[1], dtype='int32', lod_level=1) + out = fluid.layers.hash(input=x, len(word_dict)) + """ + helper = LayerHelper('hash', **locals()) + out = helper.create_variable_for_type_inference( + helper.input_dtype(), stop_gradient=True) + helper.append_op( + type='hash', + inputs={'X': input}, + outputs={'Out': out}, + attrs={'num_hash': num_hash, + 'mod_by': hash_size}) + return out diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index 0c2800dcf35ed156b71625babea2724f520575e5..a4503e75671d7d12ff84bb538776f8e6c832b9d1 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -13,8 +13,6 @@ # limitations under the License. """ Fluid Metrics - -The metrics are accomplished via Python natively. """ from __future__ import print_function @@ -24,6 +22,12 @@ import copy import warnings import six +from .layer_helper import LayerHelper +from .initializer import Constant +from . import unique_name +from .framework import Program, Variable, program_guard +from . import layers + __all__ = [ 'MetricBase', 'CompositeMetric', @@ -474,71 +478,10 @@ class EditDistance(MetricBase): "There is no data in EditDistance Metric. Please check layers.edit_distance output has been added to EditDistance." ) avg_distance = self.total_distance / self.seq_num - avg_instance_error = self.instance_error / self.seq_num + avg_instance_error = self.instance_error / float(self.seq_num) return avg_distance, avg_instance_error -class DetectionMAP(MetricBase): - """ - Calculate the detection mean average precision (mAP). - mAP is the metric to measure the accuracy of object detectors - like Faster R-CNN, SSD, etc. - It is the average of the maximum precisions at different recall values. - Please get more information from the following articles: - https://sanchom.wordpress.com/tag/average-precision/ - - https://arxiv.org/abs/1512.02325 - - The general steps are as follows: - - 1. calculate the true positive and false positive according to the input - of detection and labels. - 2. calculate mAP value, support two versions: '11 point' and 'integral'. - - Examples: - .. code-block:: python - - pred = fluid.layers.fc(input=data, size=1000, act="tanh") - batch_map = layers.detection_map( - input, - label, - class_num, - background_label, - overlap_threshold=overlap_threshold, - evaluate_difficult=evaluate_difficult, - ap_version=ap_version) - metric = fluid.metrics.DetectionMAP() - for data in train_reader(): - loss, preds, labels = exe.run(fetch_list=[cost, batch_map]) - batch_size = data[0] - metric.update(value=batch_map, weight=batch_size) - numpy_map = metric.eval() - """ - - def __init__(self, name=None): - super(DetectionMAP, self).__init__(name) - # the current map value - self.value = .0 - self.weight = .0 - - def update(self, value, weight): - if not _is_number_or_matrix_(value): - raise ValueError( - "The 'value' must be a number(int, float) or a numpy ndarray.") - if not _is_number_(weight): - raise ValueError("The 'weight' must be a number(int, float).") - self.value += value - self.weight += weight - - def eval(self): - if self.weight == 0: - raise ValueError( - "There is no data in DetectionMAP Metrics. " - "Please check layers.detection_map output has added to DetectionMAP." - ) - return self.value / self.weight - - class Auc(MetricBase): """ Auc metric adapts to the binary classification. @@ -616,3 +559,179 @@ class Auc(MetricBase): idx -= 1 return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0 + + +class DetectionMAP(object): + """ + Calculate the detection mean average precision (mAP). + + The general steps are as follows: + 1. calculate the true positive and false positive according to the input + of detection and labels. + 2. calculate mAP value, support two versions: '11 point' and 'integral'. + + Please get more information from the following articles: + https://sanchom.wordpress.com/tag/average-precision/ + https://arxiv.org/abs/1512.02325 + + Args: + input (Variable): The detection results, which is a LoDTensor with shape + [M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax]. + gt_label (Variable): The ground truth label index, which is a LoDTensor + with shape [N, 1]. + gt_box (Variable): The ground truth bounding box (bbox), which is a + LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax]. + gt_difficult (Variable|None): Whether this ground truth is a difficult + bounding bbox, which can be a LoDTensor [N, 1] or not set. If None, + it means all the ground truth labels are not difficult bbox. + class_num (int): The class number. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all categories will be + considered, 0 by defalut. + overlap_threshold (float): The threshold for deciding true/false + positive, 0.5 by defalut. + evaluate_difficult (bool): Whether to consider difficult ground truth + for evaluation, True by defalut. This argument does not work when + gt_difficult is None. + ap_version (string): The average precision calculation ways, it must be + 'integral' or '11point'. Please check + https://sanchom.wordpress.com/tag/average-precision/ for details. + - 11point: the 11-point interpolated average precision. + - integral: the natural integral of the precision-recall curve. + + Examples: + .. code-block:: python + + exe = fluid.Executor(place) + map_evaluator = fluid.Evaluator.DetectionMAP(input, + gt_label, gt_box, gt_difficult) + cur_map, accum_map = map_evaluator.get_map_var() + fetch = [cost, cur_map, accum_map] + for epoch in PASS_NUM: + map_evaluator.reset(exe) + for data in batches: + loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch) + + In the above example: + + 'cur_map_v' is the mAP of current mini-batch. + 'accum_map_v' is the accumulative mAP of one pass. + """ + + def __init__(self, + input, + gt_label, + gt_box, + gt_difficult=None, + class_num=None, + background_label=0, + overlap_threshold=0.5, + evaluate_difficult=True, + ap_version='integral'): + + self.helper = LayerHelper('map_eval') + gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype) + if gt_difficult: + gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype) + label = layers.concat([gt_label, gt_difficult, gt_box], axis=1) + else: + label = layers.concat([gt_label, gt_box], axis=1) + + # calculate mean average precision (mAP) of current mini-batch + map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + ap_version=ap_version) + + states = [] + states.append( + self._create_state( + dtype='int32', shape=None, suffix='accum_pos_count')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_true_pos')) + states.append( + self._create_state( + dtype='float32', shape=None, suffix='accum_false_pos')) + var = self._create_state(dtype='int32', shape=[1], suffix='has_state') + self.helper.set_variable_initializer( + var, initializer=Constant(value=int(0))) + self.has_state = var + + # calculate accumulative mAP + accum_map = layers.detection_map( + input, + label, + class_num, + background_label, + overlap_threshold=overlap_threshold, + evaluate_difficult=evaluate_difficult, + has_state=self.has_state, + input_states=states, + out_states=states, + ap_version=ap_version) + + layers.fill_constant( + shape=self.has_state.shape, + value=1, + dtype=self.has_state.dtype, + out=self.has_state) + + self.cur_map = map + self.accum_map = accum_map + + def _create_state(self, suffix, dtype, shape): + """ + Create state variable. + Args: + suffix(str): the state suffix. + dtype(str|core.VarDesc.VarType): the state data type + shape(tuple|list): the shape of state + Returns: State variable + """ + state = self.helper.create_variable( + name="_".join([unique_name.generate(self.helper.name), suffix]), + persistable=True, + dtype=dtype, + shape=shape) + return state + + def get_map_var(self): + """ + Returns: mAP variable of current mini-batch and + accumulative mAP variable cross mini-batches. + """ + return self.cur_map, self.accum_map + + def reset(self, executor, reset_program=None): + """ + Reset metric states at the begin of each pass/user specified batch. + + Args: + executor(Executor): a executor for executing + the reset_program. + reset_program(Program|None): a single Program for reset process. + If None, will create a Program. + """ + + def _clone_var_(block, var): + assert isinstance(var, Variable) + return block.create_var( + name=var.name, + shape=var.shape, + dtype=var.dtype, + type=var.type, + lod_level=var.lod_level, + persistable=var.persistable) + + if reset_program is None: + reset_program = Program() + with program_guard(main_program=reset_program): + var = _clone_var_(reset_program.current_block(), self.has_state) + layers.fill_constant( + shape=var.shape, value=0, dtype=var.dtype, out=var) + executor.run(reset_program) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 17af44afdde5cdbec082d473457ef01974695bc6..7e2364a5a872cdd8cf590438cc081ab070db767d 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -14,6 +14,7 @@ from __future__ import print_function import re +import sys from collections import defaultdict from paddle.fluid.framework import Program, Variable, name_scope, default_main_program from . import framework @@ -32,7 +33,8 @@ __all__ = [ 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', - 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer' + 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum', + 'LarsMomentumOptimizer' ] @@ -105,13 +107,14 @@ class Optimizer(object): param = param_and_grad[0] param_lr = param.optimize_attr['learning_rate'] if type(param_lr) == Variable: - print("returns updated param lr ", param_lr) return param_lr else: if param_lr == 1.0: return self._global_learning_rate() else: - with default_main_program()._lr_schedule_guard(): + with default_main_program()._lr_schedule_guard( + is_with_opt=True), framework.name_scope( + 'scale_with_param_lr'): return self._global_learning_rate() * param_lr def _create_accumulators(self, block, parameters): @@ -398,6 +401,91 @@ class MomentumOptimizer(Optimizer): return momentum_op +class LarsMomentumOptimizer(Optimizer): + """ + Momentum optimizer with LARS support + + The update equations are as follows: + + .. math:: + + & local\_learning\_rate = learning\_rate * lars\_coeff * \\ + \\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||} + + & velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param) + + & param = param - velocity + + Args: + learning_rate (float|Variable): the learning rate used to update parameters. \ + Can be a float value or a Variable with one float value as data element. + momentum (float): momentum factor + lars_coeff (float): defines how much we trust the layer to change its weights. + lars_weight_decay (float): weight decay coefficient for decaying using LARS. + regularization: A Regularizer, such as + fluid.regularizer.L2DecayRegularizer. + name: A optional name prefix. + + + Examples: + .. code-block:: python + + optimizer = fluid.optimizer.LarsMomentum(learning_rate=0.2, momentum=0.1, lars_weight_decay=0.001) + optimizer.minimize(cost) + """ + _velocity_acc_str = "velocity" + + def __init__(self, + learning_rate, + momentum, + lars_coeff=0.001, + lars_weight_decay=0.0005, + regularization=None, + name=None): + assert learning_rate is not None + assert momentum is not None + super(LarsMomentumOptimizer, self).__init__( + learning_rate=learning_rate, + regularization=regularization, + name=name) + self.type = "lars_momentum" + self._momentum = momentum + self._lars_coeff = float(lars_coeff) + self._lars_weight_decay = float(lars_weight_decay) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + + for p in parameters: + self._add_accumulator(self._velocity_acc_str, p) + + def _append_optimize_op(self, block, param_and_grad): + assert isinstance(block, framework.Block) + + velocity_acc = self._get_accumulator(self._velocity_acc_str, + param_and_grad[0]) + # create the momentum optimize op + momentum_op = block.append_op( + type=self.type, + inputs={ + "Param": param_and_grad[0], + "Grad": param_and_grad[1], + "Velocity": velocity_acc, + "LearningRate": self._create_param_lr(param_and_grad) + }, + outputs={ + "ParamOut": param_and_grad[0], + "VelocityOut": velocity_acc + }, + attrs={ + "mu": self._momentum, + "lars_coeff": self._lars_coeff, + "lars_weight_decay": self._lars_weight_decay + }) + + return momentum_op + + class AdagradOptimizer(Optimizer): """ **Adaptive Gradient Algorithm (Adagrad)** @@ -602,7 +690,8 @@ class AdamOptimizer(Optimizer): for param, grad in param_and_grads: if grad is None: continue - with param.block.program._optimized_guard([param, grad]): + with param.block.program._optimized_guard( + [param, grad]), name_scope("optimizer"): beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param) beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, @@ -740,7 +829,8 @@ class AdamaxOptimizer(Optimizer): for param, grad in parameters_and_grads: if grad is None: continue - with param.block.program._optimized_guard([param, grad]): + with param.block.program._optimized_guard( + [param, grad]), name_scope('adamx'): beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, param) main_block.append_op( @@ -1217,6 +1307,7 @@ DecayedAdagrad = DecayedAdagradOptimizer Adadelta = AdadeltaOptimizer RMSProp = RMSPropOptimizer Ftrl = FtrlOptimizer +LarsMomentum = LarsMomentumOptimizer class ModelAverage(Optimizer): @@ -1279,7 +1370,8 @@ class ModelAverage(Optimizer): for param, grad in self.params_grads: if grad is None: continue - with param.block.program._optimized_guard([param, grad]): + with param.block.program._optimized_guard( + [param, grad]), name_scope('move_average'): self._append_average_accumulate_op(param) self.apply_program = Program() diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index c151fbd17208bb6e3104e8d0f6590392c6095987..57185da4d1d38f3848994aae105411cf2844843a 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -47,7 +47,8 @@ def append_regularization_ops(parameters_and_grads, regularization=None): if grad is None: params_and_grads.append((param, grad)) continue - with param.block.program._optimized_guard([param, grad]): + with param.block.program._optimized_guard( + [param, grad]), framework.name_scope('regularization'): regularization_term = None if param.regularizer is not None: # Add variable for regularization term in grad block diff --git a/python/paddle/fluid/tests/CMakeLists.txt b/python/paddle/fluid/tests/CMakeLists.txt index d6568cd38e714bf9eb9d34da8a1c6a5cdb6677e3..7ad923d3321ec8a88b60d7f4f7777e12fad8faa6 100644 --- a/python/paddle/fluid/tests/CMakeLists.txt +++ b/python/paddle/fluid/tests/CMakeLists.txt @@ -1,8 +1,4 @@ -if(NOT APPLE) - set(PYTHON_TESTS_DIR ${CMAKE_CURRENT_BINARY_DIR} CACHE PATH "python tests directory") -else() - set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) -endif(NOT APPLE) +set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 56129641ce5900d82aedf243d2fa1eadfd6b8d86..28dc7519571d8b5464e92fddf634ba46691ceaa9 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase): dtype='float32', lod_level=1, append_batch_size=False) - pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( + pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign( bbox_pred=bbox_pred, cls_logits=cls_logits, anchor_box=anchor_box, @@ -313,15 +313,18 @@ class TestRpnTargetAssign(unittest.TestCase): rpn_straddle_thresh=0.0, rpn_fg_fraction=0.5, rpn_positive_overlap=0.7, - rpn_negative_overlap=0.3) + rpn_negative_overlap=0.3, + use_random=False) self.assertIsNotNone(pred_scores) self.assertIsNotNone(pred_loc) self.assertIsNotNone(tgt_lbl) self.assertIsNotNone(tgt_bbox) + self.assertIsNotNone(bbox_inside_weight) assert pred_scores.shape[1] == 1 assert pred_loc.shape[1] == 4 assert pred_loc.shape[1] == tgt_bbox.shape[1] + print(str(program)) class TestGenerateProposals(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7de0ebce06e9de439d3570bee9ac7dbce33ee868..cf54bc2dbe788f3757a7ef93f26156d118a0cd02 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) - # TODO: fix this test - #py_test_modules(test_dist_transformer MODULES test_dist_transformer) - #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) + # FIXME(typhoonzero): add this back + #py_test_modules(test_dist_transformer MODULES test_dist_transformer) + #set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) endif(NOT APPLE) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) endif() diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 877d21ae882ab4efb49beb6a846ab71a22c2aab7..01e9795d8b1beb67270f45fe7ba2819bf8c3be3e 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -95,7 +95,7 @@ class TestDistMnist2x2(TestDistRunnerBase): # Reader train_reader = paddle.batch( - paddle.dataset.mnist.train(), batch_size=batch_size) + paddle.dataset.mnist.test(), batch_size=batch_size) test_reader = paddle.batch( paddle.dataset.mnist.test(), batch_size=batch_size) opt.minimize(avg_cost) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..d386e75fd887a898f5a13e48e378e08ff6c99ea0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_batch_merge.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + +DTYPE = "float32" + + +def test_merge_reader(repeat_batch_size=8): + orig_reader = paddle.dataset.mnist.test() + record_batch = [] + b = 0 + for d in orig_reader(): + if b >= repeat_batch_size: + break + record_batch.append(d) + b += 1 + while True: + for d in record_batch: + yield d + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) + + # Reader + train_reader = paddle.batch(test_merge_reader, batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist_lars.py b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py new file mode 100644 index 0000000000000000000000000000000000000000..977e17c37f7676ae81d9ab29b6b36089ccbeeacf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_mnist_lars.py @@ -0,0 +1,73 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +from dist_mnist import cnn_model + +DTYPE = "float32" +paddle.dataset.mnist.fetch() + +# Fix seed for test +fluid.default_startup_program().random_seed = 1 +fluid.default_main_program().random_seed = 1 + + +class TestDistMnist2x2(TestDistRunnerBase): + def get_model(self, batch_size=2): + # Input data + images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE) + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + # Train program + predict = cnn_model(images) + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + # Evaluator + batch_size_tensor = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size_tensor) + + inference_program = fluid.default_main_program().clone() + # Optimization + opt = fluid.optimizer.LarsMomentumOptimizer( + learning_rate=0.001, momentum=0.9) + + # Reader + train_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + opt.minimize(avg_cost) + return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestDistMnist2x2) diff --git a/python/paddle/fluid/tests/unittests/dist_transformer.py b/python/paddle/fluid/tests/unittests/dist_transformer.py index a2cc57425841100a2b61279d1b447b88ed4b9a54..27c67edf4f62dd3c5d396826348f8da4513667ba 100644 --- a/python/paddle/fluid/tests/unittests/dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/dist_transformer.py @@ -35,7 +35,7 @@ import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers from paddle.fluid import core -from test_dist_base import TestDistRunnerBase, runtime_main +from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP import paddle.compat as cpt from paddle.compat import long_type @@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, for pass_id in six.moves.xrange(TrainTaskConfig.pass_num): pass_start_time = time.time() for batch_id, data in enumerate(train_data()): - if batch_id >= 5: + if batch_id >= RUN_STEP: break feed_list = [] total_num_token = 0 - #if TrainTaskConfig.local: - # lr_rate = lr_scheduler.update_learning_rate() - #for place_id, data_buffer in enumerate( - # split_data( - # data, num_part=dev_count)): - if TrainTaskConfig.local: lr_rate = lr_scheduler.update_learning_rate() @@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, init = True # Validate and save the model for inference. - if batch_id == 0 or batch_id == 4: - if TrainTaskConfig.val_file_pattern is not None: - val_avg_cost, val_ppl = test() - print("[%f]" % val_avg_cost) - else: - assert (False) + if TrainTaskConfig.val_file_pattern is not None: + val_avg_cost, val_ppl = test() + print("[%f]" % val_avg_cost) + else: + assert (False) #import transformer_reader as reader @@ -1166,6 +1159,7 @@ def prepare_encoder(src_word, name=pos_enc_param_name, trainable=False, initializer=fluid.initializer.ConstantInitializer(0.001))) + src_pos_enc.stop_gradient = True enc_input = src_word_emb + src_pos_enc return layers.dropout( enc_input, @@ -1701,7 +1695,7 @@ class DistTransformer2x2(TestDistRunnerBase): def run_trainer(self, args): TrainTaskConfig.use_gpu = args.use_cuda - sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( + sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model( args.is_dist, not args.sync_mode) if args.is_dist: diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 04924bec057e301bfb342a62bb4c1e0b3c3aff4c..87fd03ca61d33a53b9323edb2ec7e1c71655816b 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -26,10 +26,11 @@ import argparse import paddle.fluid as fluid RUN_STEP = 10 +DEFAULT_BATCH_SIZE = 2 class TestDistRunnerBase(object): - def get_model(self, batch_size=2): + def get_model(self, batch_size=DEFAULT_BATCH_SIZE): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -48,8 +49,7 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): - - self.get_model(batch_size=2) + self.get_model(batch_size=args.batch_size) # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, fluid.default_main_program(), args.endpoints, @@ -65,7 +65,7 @@ class TestDistRunnerBase(object): def run_trainer(self, args): test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ - self.get_model(batch_size=2) + self.get_model(batch_size=args.batch_size) if args.mem_opt: fluid.memory_optimize(fluid.default_main_program(), skip_grads=True) @@ -92,6 +92,11 @@ class TestDistRunnerBase(object): strategy.allow_op_delay = False build_stra = fluid.BuildStrategy() + if args.batch_merge_repeat > 1: + pass_builder = build_stra._create_passes_from_strategy() + mypass = pass_builder.insert_pass( + len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass") + mypass.set_int("num_repeats", args.batch_merge_repeat) if args.use_reduce: build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce @@ -145,6 +150,9 @@ def runtime_main(test_class): parser.add_argument('--use_reduce', action='store_true') parser.add_argument( '--use_reader_alloc', action='store_true', required=False, default=True) + parser.add_argument('--batch_size', required=False, type=int, default=2) + parser.add_argument( + '--batch_merge_repeat', required=False, type=int, default=1) args = parser.parse_args() @@ -244,9 +252,18 @@ class TestDistBase(unittest.TestCase): (e, retry_times)) retry_times -= 1 - def _run_local(self, model, envs, check_error_log): + def _run_local(self, + model, + envs, + check_error_log=False, + batch_size=DEFAULT_BATCH_SIZE, + batch_merge_repeat=1): cmd = "%s %s --role trainer" % (self._python_interp, model) + if batch_size != DEFAULT_BATCH_SIZE: + cmd += " --batch_size %d" % batch_size + if batch_merge_repeat > 1: + cmd += " --batch_merge_repeat %d" % batch_merge_repeat if self.__use_cuda: cmd += " --use_cuda" diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py index 3575fd07fc727bd6c6b07a19a60b1df6656ae9e2..390393e04f8a1ff7b994da66cf1fa104ccb61793 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py @@ -23,9 +23,8 @@ class TestDistCTR2x2(TestDistBase): self._sync_mode = True self._enforce_place = "CPU" - -def test_dist_ctr(self): - self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) + def test_dist_ctr(self): + self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index f65dd7e2a28c4ace3988c0cc1267ebe981fbd9cb..922dd838f8996adfc15afffcd44c1acca2bc14a9 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -26,6 +26,15 @@ class TestDistMnist2x2(TestDistBase): self.check_with_place("dist_mnist.py", delta=1e-5) +class TestDistMnist2x2Lars(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + + def test_se_resnext(self): + self.check_with_place("dist_mnist_lars.py", delta=1e-5) + + class TestDistMnist2x2WithMemopt(TestDistBase): def _setup_config(self): self._sync_mode = True diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..22d4b7929033529c5cea60064e6d9de57eddeb8e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_batch_merge.py @@ -0,0 +1,67 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase +import os + + +class TestDistMnist2x2(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + + def test_dist_train(self): + self.check_with_place("dist_mnist_batch_merge.py", delta=1e-5) + + def check_with_place(self, + model_file, + delta=1e-3, + check_error_log=False, + need_envs={}): + # TODO(typhoonzero): should auto adapt GPU count on the machine. + required_envs = { + "PATH": os.getenv("PATH", ""), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_cudnn_deterministic": "1", + } + + required_envs.update(need_envs) + + if check_error_log: + required_envs["GLOG_v"] = "7" + required_envs["GLOG_logtostderr"] = "1" + + no_merge_losses = self._run_local( + model_file, + required_envs, + check_error_log=check_error_log, + batch_size=4) + + batch_merge_losses = self._run_local( + model_file, + required_envs, + check_error_log=check_error_log, + batch_size=2, + batch_merge_repeat=2) + # Ensure both result have values. + self.assertGreater(len(no_merge_losses), 1) + self.assertEqual(len(no_merge_losses), len(batch_merge_losses)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py index a0b6879f99e80a9710ee76f981769299a066b85b..fcf793da07eb8985fb19c9e36ff9b7d5a8f51212 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py +++ b/python/paddle/fluid/tests/unittests/test_dist_simnet_bow.py @@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): self._sync_mode = False self._enforce_place = "CPU" - def test_simnet_bow(self): + #FIXME(typhoonzero): fix async tests later + def no_test_simnet_bow(self): need_envs = { "IS_DISTRIBUTED": '0', "IS_SPARSE": '0', diff --git a/python/paddle/fluid/tests/unittests/test_dist_transformer.py b/python/paddle/fluid/tests/unittests/test_dist_transformer.py index 47e8dfaf03ceb27a74f5e48d662d2b534d2d152b..25dcccc28d710695d4c5e08c17816669d0fae5d8 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transformer.py @@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase): def test_dist_train(self): download_files() - self.check_with_place("dist_transformer.py", delta=1e-5) + self.check_with_place( + "dist_transformer.py", delta=1e-5, check_error_log=False) class TestDistTransformer2x2Async(TestDistBase): @@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase): def test_dist_train(self): download_files() - self.check_with_place("dist_transformer.py", delta=1.0) + self.check_with_place( + "dist_transformer.py", delta=1.0, check_error_log=False) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index 0296bc2af4e0b79478c34b4cceab32b5a8a50f2f..be3c5f3b9558ec522803ed9a5acedea75cda6ccc 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -85,6 +85,69 @@ class TestDropoutOp5(OpTest): self.check_output() +class TestDropoutOp6(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = { + 'dropout_prob': 1.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': np.zeros((32, 64)).astype('float32'), + 'Mask': np.zeros((32, 64)).astype('float32') + } + + +class TestDropoutOp7(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((32, 64, 2)).astype('float32') + } + + +class TestDropoutOp8(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.35, + 'fix_seed': True, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + +class TestDropoutOp9(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} + self.attrs = { + 'dropout_prob': 0.75, + 'is_test': True, + 'dropout_implementation': 'upscale_in_train' + } + self.outputs = {'Out': self.inputs['X']} + + def test_check_output(self): + self.check_output() + + class TestFP16DropoutOp(OpTest): def setUp(self): self.op_type = "dropout" diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py index 36ebc8fb6ea9efdcd1807f5c8917ab1428b3381e..377454e7802e40f90c371987adfe50cce922c764 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py @@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): self.D = 8 +class TestFusionGRUOpMD3(TestFusionGRUOp): + def set_confs(self): + self.M = 17 + self.D = 15 + + class TestFusionGRUOpBS1(TestFusionGRUOp): def set_confs(self): self.lod = [[3]] diff --git a/python/paddle/fluid/tests/unittests/test_hash_op.py b/python/paddle/fluid/tests/unittests/test_hash_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1130ea39c42204283885ab1072a52db8c22f8b2e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_hash_op.py @@ -0,0 +1,57 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestScaleOp(OpTest): + def setUp(self): + self.op_type = "hash" + self.init_test_case() + self.inputs = {'X': (self.in_seq, self.lod)} + self.attrs = {'num_hash': 4, 'mod_by': 10000} + self.outputs = {'Out': (self.out_seq, self.lod)} + + def init_test_case(self): + np.random.seed = 1 + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + # self.out_seq = np.ones([30, 4, 1], dtype=np.int32) + self.out_seq = [ + [[9662], [9217], [1129], [8487]], [[9662], [9217], [1129], [8487]], + [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], + [[9407], [6715], [6949], [8094]], [[8473], [694], [5142], [2479]], + [[8310], [1327], [1654], [4567]], [[6897], [3218], [2013], [1241]], + [[4372], [9456], [8204], [6695]], [[6897], [3218], [2013], [1241]], + [[8473], [694], [5142], [2479]], [[4372], [9456], [8204], [6695]], + [[4372], [9456], [8204], [6695]], [[8473], [694], [5142], [2479]], + [[9407], [6715], [6949], [8094]], [[9369], [4525], [8935], [9210]], + [[4372], [9456], [8204], [6695]], [[4372], [9456], [8204], [6695]], + [[9369], [4525], [8935], [9210]], [[6897], [3218], [2013], [1241]], + [[9038], [7951], [5953], [8657]], [[9407], [6715], [6949], [8094]], + [[9662], [9217], [1129], [8487]], [[9369], [4525], [8935], [9210]], + [[9038], [7951], [5953], [8657]], [[9662], [9217], [1129], [8487]], + [[9369], [4525], [8935], [9210]], [[1719], [5986], [9919], [3421]], + [[4372], [9456], [8204], [6695]], [[9038], [7951], [5953], [8657]] + ] + self.out_seq = np.array(self.out_seq) + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_metrics.py b/python/paddle/fluid/tests/unittests/test_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..ec27884cae2b0462951f6597b1b83e58d1c8af5d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_metrics.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +from paddle.fluid.framework import Program, program_guard + + +class TestMetricsDetectionMap(unittest.TestCase): + def test_detection_map(self): + program = fluid.Program() + with program_guard(program): + detect_res = fluid.layers.data( + name='detect_res', + shape=[10, 6], + append_batch_size=False, + dtype='float32') + label = fluid.layers.data( + name='label', + shape=[10, 1], + append_batch_size=False, + dtype='float32') + box = fluid.layers.data( + name='bbox', + shape=[10, 4], + append_batch_size=False, + dtype='float32') + map_eval = fluid.metrics.DetectionMAP( + detect_res, label, box, class_num=21) + cur_map, accm_map = map_eval.get_map_var() + self.assertIsNotNone(cur_map) + self.assertIsNotNone(accm_map) + print(str(program)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index a3d89610b40ff9bd5002e843f8667ada87e67981..cf4346cf2e7a099334ec273546901a91d0ad925d 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -90,6 +90,45 @@ class TestMomentumOp2(OpTest): self.check_output() +class TestLarsMomentumOp(OpTest): + def setUp(self): + self.op_type = "lars_momentum" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + velocity = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.001]).astype("float32") + mu = 0.0001 + lars_coeff = 0.001 + lars_weight_decay = 0.0005 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Velocity': velocity, + 'LearningRate': learning_rate + } + + self.attrs = { + 'mu': mu, + 'lars_coeff': lars_coeff, + 'lars_weight_decay': lars_weight_decay + } + + pnorm = np.sqrt(np.square(param).sum()) + gnorm = np.sqrt(np.square(grad).sum()) + local_lr = learning_rate * lars_coeff * pnorm / ( + gnorm + lars_weight_decay * param) + velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay * + param) + param_out = param - velocity_out + + self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} + + def test_check_output(self): + self.check_output() + + class TestSparseMomentumOp(unittest.TestCase): def setUp(self): self.use_nesterov = False diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py index f63dbcd3d7f6bfce3ccc1c42ae41afe42bfad003..1a2c9bb5f43d55d8e6183de0d55bfcc2b9ac3f08 100644 --- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py @@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap, fg_inds, size=(len(fg_inds) - num_fg), replace=False) else: disable_inds = fg_inds[num_fg:] + labels[disable_inds] = -1 fg_inds = np.where(labels == 1)[0] + bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32) num_bg = rpn_batch_size_per_im - np.sum(labels == 1) bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] @@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap, enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] else: enable_inds = bg_inds[:num_bg] + + fg_fake_inds = np.array([], np.int32) + fg_value = np.array([fg_inds[0]], np.int32) + fake_num = 0 + for bg_id in enable_inds: + if bg_id in fg_inds: + fake_num += 1 + fg_fake_inds = np.hstack([fg_fake_inds, fg_value]) labels[enable_inds] = 0 + + bbox_inside_weight[fake_num:, :] = 1 fg_inds = np.where(labels == 1)[0] bg_inds = np.where(labels == 0)[0] - - loc_index = fg_inds - score_index = np.hstack((fg_inds, bg_inds)) + loc_index = np.hstack([fg_fake_inds, fg_inds]) + score_index = np.hstack([fg_inds, bg_inds]) labels = labels[score_index] assert not np.any(labels == -1), "Wrong labels with -1" - gt_inds = anchor_to_gt_argmax[fg_inds] + gt_inds = anchor_to_gt_argmax[loc_index] - return loc_index, score_index, labels, gt_inds + return loc_index, score_index, labels, gt_inds, bbox_inside_weight def get_anchor(n, c, h, w): @@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors, gt_boxes_slice = gt_boxes_slice[not_crowd_inds] iou = _bbox_overlaps(inside_anchors, gt_boxes_slice) - loc_inds, score_inds, labels, gt_inds = rpn_target_assign( - iou, rpn_batch_size_per_im, rpn_positive_overlap, - rpn_negative_overlap, rpn_fg_fraction, use_random) + loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = \ + rpn_target_assign(iou, rpn_batch_size_per_im, + rpn_positive_overlap, + rpn_negative_overlap, + rpn_fg_fraction, + use_random) # unmap to all anchor loc_inds = inds_inside[loc_inds] score_inds = inds_inside[score_inds] @@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors, score_indexes = score_inds tgt_labels = labels tgt_bboxes = box_deltas + bbox_inside_weights = bbox_inside_weight else: loc_indexes = np.concatenate( [loc_indexes, loc_inds + i * anchor_num]) @@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors, [score_indexes, score_inds + i * anchor_num]) tgt_labels = np.concatenate([tgt_labels, labels]) tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) + bbox_inside_weights = np.vstack([bbox_inside_weights, \ + bbox_inside_weight]) - return loc_indexes, score_indexes, tgt_bboxes, tgt_labels + return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights class TestRpnTargetAssignOp(OpTest): @@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest): rpn_fg_fraction = 0.5 use_random = False - loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python( - all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh, - rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap, - rpn_fg_fraction, use_random) + loc_index, score_index, tgt_bbox, labels, bbox_inside_weights = \ + rpn_target_assign_in_python(all_anchors, gt_boxes, is_crowd, + im_info, lod, rpn_straddle_thresh, + rpn_batch_size_per_im, rpn_positive_overlap, + rpn_negative_overlap, + rpn_fg_fraction, use_random) labels = labels[:, np.newaxis] self.op_type = "rpn_target_assign" @@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest): 'LocationIndex': loc_index.astype('int32'), 'ScoreIndex': score_index.astype('int32'), 'TargetBBox': tgt_bbox.astype('float32'), - 'TargetLabel': labels.astype('int32') + 'TargetLabel': labels.astype('int32'), + 'BBoxInsideWeight': bbox_inside_weights.astype('float32') } def test_check_output(self): diff --git a/python/paddle/fluid/tests/unittests/test_sequence_reverse.py b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py new file mode 100644 index 0000000000000000000000000000000000000000..eebd25e0975f1711ea86093f007212cadc6334f5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_reverse.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +from op_test import OpTest +import numpy as np + + +class TestSequenceReverseBase(OpTest): + def initParameters(self): + pass + + def setUp(self): + self.size = (10, 3, 4) + self.lod = [2, 3, 5] + self.dtype = 'float32' + self.initParameters() + self.op_type = 'sequence_reverse' + self.x = np.random.random(self.size).astype(self.dtype) + self.y = self.get_output() + + self.inputs = {'X': (self.x, [self.lod, ]), } + self.outputs = {'Y': (self.y, [self.lod, ]), } + + def get_output(self): + tmp_x = np.reshape(self.x, newshape=[self.x.shape[0], -1]) + tmp_y = np.ndarray(tmp_x.shape).astype(self.dtype) + prev_idx = 0 + for cur_len in self.lod: + idx_range = range(prev_idx, prev_idx + cur_len) + tmp_y[idx_range, :] = np.flip(tmp_x[idx_range, :], 0) + prev_idx += cur_len + + return np.reshape(tmp_y, newshape=self.x.shape).astype(self.dtype) + + def test_output(self): + self.check_output(0) + + def test_grad(self): + self.check_grad(['X'], 'Y') + + +class TestSequenceReserve1(TestSequenceReverseBase): + def initParameters(self): + self.size = (12, 10) + self.lod = [4, 5, 3] + + +class TestSequenceReverse2(TestSequenceReverseBase): + def initParameters(self): + self.size = (12, 10) + self.lod = [12] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index e54e170f7f1e03db4b63db72edb7395d18130f68..69b29db83a43d18c0825b610642009a0377b9901 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -21,22 +21,27 @@ from op_test import OpTest class TestTopkOp(OpTest): def setUp(self): + self.set_args() self.op_type = "top_k" - k = 1 - input = np.random.random((32, 84)).astype("float32") - output = np.ndarray((32, k)) - indices = np.ndarray((32, k)).astype("int64") + k = self.top_k + input = np.random.random((self.row, k)).astype("float32") + output = np.ndarray((self.row, k)) + indices = np.ndarray((self.row, k)).astype("int64") self.inputs = {'X': input} self.attrs = {'k': k} - for rowid in range(32): + for rowid in range(self.row): row = input[rowid] - output[rowid] = np.sort(row)[-k:] - indices[rowid] = row.argsort()[-k:] + output[rowid] = np.sort(row)[::-1][:k] + indices[rowid] = row.argsort()[::-1][:k] self.outputs = {'Out': output, 'Indices': indices} + def set_args(self): + self.row = 32 + self.top_k = 1 + def test_check_output(self): self.check_output() @@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest): output = np.ndarray((64, k)) indices = np.ndarray((64, k)).astype("int64") - # FIXME: should use 'X': input for a 3d input - self.inputs = {'X': input_flat_2d} + self.inputs = {'X': input} self.attrs = {'k': k} for rowid in range(64): row = input_flat_2d[rowid] - output[rowid] = np.sort(row)[-k:] - indices[rowid] = row.argsort()[-k:] + output[rowid] = np.sort(row)[::-1][:k] + indices[rowid] = row.argsort()[::-1][:k] + + self.outputs = { + 'Out': output.reshape((32, 2, k)), + 'Indices': indices.reshape((32, 2, k)) + } + + def test_check_output(self): + self.check_output() + + +class TestTopkOp2(OpTest): + def setUp(self): + self.op_type = "top_k" + k = 1 + m = 2056 + input = np.random.random((m, 84)).astype("float32") + output = np.ndarray((m, k)) + indices = np.ndarray((m, k)).astype("int64") + + self.inputs = {'X': input} + self.attrs = {'k': k} + + for rowid in range(m): + row = input[rowid] + output[rowid] = -np.sort(-row)[:k] + indices[rowid] = (-row).argsort()[:k] self.outputs = {'Out': output, 'Indices': indices} @@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest): self.check_output() +class TestTopkOp3(TestTopkOp): + def set_args(self): + self.row = 2056 + self.top_k = 3 + + +class TestTopkOp4(TestTopkOp): + def set_args(self): + self.row = 40000 + self.top_k = 1 + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2192139f8d5950286691a77333dd8ec35505b033..28ad8443673492b31a6228bc85939549749541e9 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -49,6 +49,7 @@ LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( ) +OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched @@ -1430,7 +1431,7 @@ to transpile() call.") elif op_type == "adamax": if varkey in ["Moment", "InfNorm"]: return param_shape - elif op_type == "momentum": + elif op_type in ["momentum", "lars_momentum"]: if varkey == "Velocity": return param_shape elif op_type == "rmsprop": @@ -1441,6 +1442,10 @@ to transpile() call.") return param_shape elif op_type == "sgd": pass + else: + raise ValueError( + "Not supported optimizer for distributed training: %s" % + op_type) return orig_shape def _get_varname_parts(self, varname): @@ -1717,8 +1722,10 @@ to transpile() call.") lr_ops = [] block = self.origin_program.global_block() for op in block.ops: - if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) == int( - LR_SCHED_OP_ROLE_ATTR_VALUE): + role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) + if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ + role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ + int(OPT_OP_ROLE_ATTR_VALUE): lr_ops.append(op) log("append lr op: ", op.type) return lr_ops