diff --git a/.gitignore b/.gitignore index b92bb9cc129659fa502b4a9b55548992412e5429..90138f996cf9cacc3c1cbff0cf2600eefca3f305 100644 --- a/.gitignore +++ b/.gitignore @@ -25,5 +25,6 @@ third_party/ bazel-* third_party/ +build_* # clion workspace. cmake-build-* diff --git a/CMakeLists.txt b/CMakeLists.txt index d43df124bdee2d568a0c09d5acd35d5ff96f4654..df00e977ebb547980e69ee421779c57717d771a9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,6 +72,7 @@ option(WITH_INFERENCE "Compile fluid inference library" ON) option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface" OFF) option(WITH_SYSTEM_BLAS "Use system blas library" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) +option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON) # PY_VERSION if(NOT PY_VERSION) diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 6ed51c648478efb9784d0c43b169c285e740e0f3..24de8d9d7ced5f8111cc5d65f761b7506bde048e 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -40,7 +40,7 @@ set(OPENBLAS_LIB_SEARCH_PATHS /usr/local/opt/openblas/lib) find_path(OPENBLAS_INC_DIR NAMES cblas.h - PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) + PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS} NO_DEFAULT_PATH) find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) find_library(OPENBLAS_LIB NAMES openblas diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 03c73786a6c31868b1893bfcb319e43e37db1a3d..f507bb41a1103c093e9569176ee868cfaac6bf7b 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -175,7 +175,10 @@ list(APPEND CUDA_NVCC_FLAGS "-std=c++11") list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") endif(NOT WIN32) -list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") +if(WITH_FAST_MATH) + # Make use of fast math library. https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html + list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") +endif() # in cuda9, suppress cuda warning on eigen list(APPEND CUDA_NVCC_FLAGS "-w") # Set :expt-relaxed-constexpr to suppress Eigen warnings diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index e029300eee9b99582f085f6b650e03f7dacc091a..573ad5e5f06a93f38f24c6a8af3b45767e93a1a4 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -3,6 +3,14 @@ INCLUDE(ExternalProject) SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3) INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR}) +if(NOT WITH_FAST_MATH) + # EIGEN_FAST_MATH: https://eigen.tuxfamily.org/dox/TopicPreprocessorDirectives.html + # enables some optimizations which might affect the accuracy of the result. + # This currently enables the SSE vectorization of sin() and cos(), + # and speedups sqrt() for single precision. + # Defined to 1 by default. Define it to 0 to disable. + add_definitions(-DEIGEN_FAST_MATH=0) +endif() if(WITH_AMD_GPU) ExternalProject_Add( diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index c3fbe4dbdb28f1008bb274ee18293db348bfc6ed..755dbd610c40c2d9b85d3017b6f000a869b0f39a 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -27,7 +27,7 @@ IF(NOT ${CBLAS_FOUND}) SET(CBLAS_SOURCES_DIR ${THIRD_PARTY_PATH}/openblas) SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas) - SET(CBLAS_INCLUDE_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." 
FORCE) + SET(CBLAS_INC_DIR "${CBLAS_INSTALL_DIR}/include" CACHE PATH "openblas include directory." FORCE) SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}openblas${CMAKE_STATIC_LIBRARY_SUFFIX}" @@ -96,7 +96,7 @@ IF(NOT ${CBLAS_FOUND}) ENDIF(NOT WIN32) SET(CBLAS_PROVIDER openblas) IF(WITH_C_API) - INSTALL(DIRECTORY ${CBLAS_INCLUDE_DIR} DESTINATION third_party/openblas) + INSTALL(DIRECTORY ${CBLAS_INC_DIR} DESTINATION third_party/openblas) # Because libopenblas.a is a symbolic link of another library, thus need to # install the whole directory. IF(ANDROID) @@ -117,8 +117,8 @@ IF(NOT ${CBLAS_FOUND}) ENDIF(NOT ${CBLAS_FOUND}) MESSAGE(STATUS "BLAS library: ${CBLAS_LIBRARIES}") -MESSAGE(STATUS "BLAS Include: ${CBLAS_INCLUDE_DIR}") -INCLUDE_DIRECTORIES(${CBLAS_INCLUDE_DIR}) +MESSAGE(STATUS "BLAS Include: ${CBLAS_INC_DIR}") +INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) # FIXME(gangliao): generate cblas target to track all high performance # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index 331b2af367bdf261ffbf96fb88f61cc6958ee647..343e44ab4bc21c1a656048b675062f1b897bbc77 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -157,6 +157,8 @@ if (APPLE) # On Mac OS X build fat binaries with x86_64 architectures by default. set (CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architectures for OSX" FORCE) endif() + # On Mac OS X register class specifier is deprecated and will cause warning error on latest clang 10.0 + set (COMMON_FLAGS -Wno-deprecated-register) endif(APPLE) if(LINUX) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6418da2a7e51c51575ff56aeabedff5452458fbc..c6dd919a93d119723b389d3a695f0af82d711a06 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -198,6 +198,9 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.has_inf ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.has_nan ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) +paddle.fluid.layers.isfinite ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index de960dba8f79b7efb1d6948ef9ec647ac8530c84..844291140602a7a0aac9d9d40256deaf9d8a4c60 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -1,3 +1,4 @@ + # windows treat symbolic file as a real file, which is different with unix # We create a hidden file and compile it instead of origin source file. 
function(windows_symbolic TARGET) @@ -9,11 +10,23 @@ function(windows_symbolic TARGET) if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc OR NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cu) message(FATAL " ${src}.cc and ${src}.cu must exist, and ${src}.cu must be symbolic file.") endif() - add_custom_command(OUTPUT .${src}.cu + + # only copy the xx.cu to .xx.cu when the content is modified + set(copy_flag 1) + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu) + file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc SOURCE_STR) + file(READ ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu TARGET_STR) + if (SOURCE_STR STREQUAL TARGET_STR) + set(copy_flag 0) + endif() + endif() + if (copy_flag) + add_custom_command(OUTPUT .${src}.cu COMMAND ${CMAKE_COMMAND} -E remove ${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/${src}.cc" "${CMAKE_CURRENT_SOURCE_DIR}/.${src}.cu" COMMENT "create hidden file of ${src}.cu") - add_custom_target(${TARGET} ALL DEPENDS .${src}.cu) + endif(copy_flag) + add_custom_target(${TARGET} ALL DEPENDS .${src}.cu) endforeach() endfunction() @@ -81,6 +94,8 @@ nv_test(data_device_transform_test SRCS data_device_transform_test.cu if(WITH_GPU) if (WIN32) + # windows treat symbolic file as a real file, which is different with unix + # We create a hidden file and compile it instead of origin source file. windows_symbolic(hidden_file SRCS data_type_transform.cu) nv_library(data_type_transform SRCS .data_type_transform.cu DEPS tensor) add_dependencies(data_type_transform hidden_file) @@ -149,7 +164,7 @@ if(WITH_DISTRIBUTE) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) - cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass elementwise_add_op) + cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() if (NOT WIN32) diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 8ad2fb5f3ffd9641932bbbb024a31e81d31dc9bb..d5be43b33edab7871e1bba930a4fc6cd1e293825 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -17,7 +17,6 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/platform/enforce.h" - #include "paddle/fluid/platform/float16.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 0076a8bece31f9a977b375717c25688fc0c95819..796ce1f91ce6f3e21dc6f0af8fca4960d43f6e2b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -38,6 +38,7 @@ pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) +pass_library(conv_bn_fuse_pass inference) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..95d7138381baec17d4969ef8b5287a1e70f2ac81 --- /dev/null +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -0,0 +1,327 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_CONV_BN_NODES(pattern_name) \ + /* OPERATORS */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name); \ + /* CONV inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \ + /* CONV outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \ + /* BN inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name); \ + /* BN outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */ \ + GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name) + +template +LoDTensor tensor_apply(const LoDTensor& vec, UnaryOperation f) { + LoDTensor vec_y; + vec_y.Resize(vec.dims()); + const float* x = vec.data(); + float* y = vec_y.mutable_data(platform::CPUPlace()); + for (int64_t i = 0; i < vec.numel(); i++) { + y[i] = f(x[i]); + } + return vec_y; +} + +void tensor_apply_inplace(LoDTensor* vec, float (*f)(float)) { + float* data = vec->mutable_data(platform::CPUPlace()); + for (int64_t i = 0; i < vec->numel(); i++) { + 
data[i] = f(data[i]); + } +} + +template +LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, + BinaryOperation f) { + PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims()); + LoDTensor vec_y; + vec_y.Resize(vec_a.dims()); + const float* a = vec_a.data(); + const float* b = vec_b.data(); + float* y = vec_y.mutable_data(platform::CPUPlace()); + for (int64_t i = 0; i < vec_a.numel(); i++) { + y[i] = f(a[i], b[i]); + } + return vec_y; +} + +template +LoDTensor tensor_apply_eltwise_broadcast(const LoDTensor& vec_a, + const LoDTensor& vec_b, + BinaryOperation f) { + PADDLE_ENFORCE_EQ(vec_a.dims().size(), 2); + PADDLE_ENFORCE_EQ(vec_b.dims().size(), 2); + PADDLE_ENFORCE_EQ(vec_a.dims()[0], vec_b.dims()[0]); + PADDLE_ENFORCE_EQ(vec_b.dims()[1], 1); + LoDTensor vec_y; + vec_y.Resize(vec_a.dims()); + const float* a = vec_a.data(); + const float* b = vec_b.data(); + float* y = vec_y.mutable_data(platform::CPUPlace()); + size_t a_height = vec_a.dims()[0]; + size_t a_width = vec_a.dims()[1]; + for (size_t h = 0; h < a_height; h++) { + for (size_t w = 0; w < a_width; ++w) { + *(y++) = f(*(a++), b[h]); + } + } + return vec_y; +} + +// reshape to two dimensions {A, B * C * ...} +void make_tensor_2d(LoDTensor* tensor_to_reshape) { + auto dims_count = tensor_to_reshape->dims().size(); + PADDLE_ENFORCE_GT(dims_count, 0); + + int size2 = 1; + for (int i = 1; i < dims_count; i++) { + size2 *= tensor_to_reshape->dims()[i]; + } + tensor_to_reshape->Resize(make_ddim({tensor_to_reshape->dims()[0], size2})); +} + +void recompute_conv_weights(LoDTensor* weights, LoDTensor* tmp) { + // remember the weights tensor shape {A, B, C, ...} + auto weights_shape = weights->dims(); + // reduce the weights to 2d {A, B * C * ...} + make_tensor_2d(weights); + // make tmp tensor 2d by adding 1 as second dim {A, 1} + make_tensor_2d(tmp); + + *weights = + tensor_apply_eltwise_broadcast(*weights, *tmp, std::multiplies()); + // reshape weights to the original dims {A, B, C, ...} + weights->Resize(weights_shape); +} + +void recompute_bias_and_weights(const Scope* scope, + ir::Node* conv_weight, // + const ir::Node& bn_scale, // + const LoDTensor& bn_bias_tensor, // + const ir::Node& bn_mean, // + const ir::Node& bn_variance, // + LoDTensor* eltwise_y_in_tensor, // + float epsilon) { + // Re-compute bias of conv2d from BN + PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims()); + + auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable(); + auto* variance_tensor = + scope->FindVar(bn_variance.Name())->GetMutable(); + auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable(); + + auto std_tensor = LoDTensor(); + std_tensor.Resize(bn_bias_tensor.dims()); + std_tensor = + tensor_apply(*variance_tensor, [&](float x) { return x + epsilon; }); + + using EigenVectorArrayMap = + Eigen::Map>; + + EigenVectorArrayMap std_vec( + std_tensor.mutable_data(platform::CPUPlace()), std_tensor.numel(), + 1); + std_vec = std_vec.sqrt(); + auto tmp_tensor = + tensor_apply_eltwise(*scale_tensor, std_tensor, std::divides()); + auto tensor_minus = tensor_apply_eltwise(*eltwise_y_in_tensor, *mean_tensor, + std::minus()); + auto tensor_mul = + tensor_apply_eltwise(tensor_minus, tmp_tensor, std::multiplies()); + *eltwise_y_in_tensor = + tensor_apply_eltwise(tensor_mul, bn_bias_tensor, std::plus()); + + // Re-compute weight of conv2d from BN + auto* current_param = + scope->FindVar(conv_weight->Name())->GetMutable(); + recompute_conv_weights(current_param, &tmp_tensor); +} + +std::unique_ptr 
ConvBNFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_); + conv_bn_pattern(conv_input, false /*with_eltwise_add*/); + + int found_conv_bn_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvBN fuse"; + + // conv, batch_norm, + // conv_weight, conv_out, + // bn_scale, bn_bias, bn_mean, bn_variance, + // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance + GET_CONV_BN_NODES(conv_bn_pattern); + + // Create eltwise_y (conv bias) variable + VarDesc eltwise_y_in_desc( + patterns::PDNodeName(name_scope_, "eltwise_y_in")); + auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); + auto* eltwise_y_in_tensor = + scope->Var(eltwise_y_in_node->Name())->GetMutable(); + + // Get batch norm bias + auto* bn_bias_tensor = + scope->FindVar(bn_bias->Name())->GetMutable(); + + // Initialize eltwise_y + eltwise_y_in_tensor->Resize(bn_bias_tensor->dims()); + std::fill_n(eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 0.0f); + + // update weights and biases + float epsilon = boost::get(batch_norm->Op()->GetAttr("epsilon")); + recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor, + *bn_mean, *bn_variance, eltwise_y_in_tensor, + epsilon); + + // Create an elementwise add node + OpDesc desc; + desc.SetInput("X", std::vector({conv_out->Name()})); + desc.SetInput("Y", std::vector({eltwise_y_in_node->Name()})); + desc.SetOutput("Out", std::vector({bn_out->Name()})); + desc.SetType("elementwise_add"); + desc.SetAttr("axis", 1); + bool a = boost::get(conv->Op()->GetAttr("use_mkldnn")); + desc.SetAttr("use_mkldnn", a); + auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. 
+ + GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance, + batch_norm, bn_mean_out, bn_variance_out, + bn_saved_mean, bn_saved_variance}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(conv_out, eltwise_op); + IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); + IR_NODE_LINK_TO(eltwise_op, bn_out); + + found_conv_bn_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_bn_count); + return graph; +} + +std::unique_ptr ConvEltwiseAddBNFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_); + conv_bn_pattern(conv_input, true /*with_eltwise_add*/); + + int found_conv_bn_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvBN fuse"; + + // conv, batch_norm, + // conv_weight, conv_out, + // bn_scale, bn_bias, bn_mean, bn_variance, + // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance + GET_CONV_BN_NODES(conv_bn_pattern); + // OPERATORS + GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern); + // BIAS inputs + GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern); + // BIAS outputs + GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern); + + // Get eltwise_y (conv bias) variable + auto* eltwise_y_in_tensor = + scope->FindVar(eltwise_y_in->Name())->GetMutable(); + + // Get batch norm bias + auto* bn_bias_tensor = + scope->FindVar(bn_bias->Name())->GetMutable(); + + // update weights and biases + float epsilon = boost::get(batch_norm->Op()->GetAttr("epsilon")); + recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor, + *bn_mean, *bn_variance, eltwise_y_in_tensor, + epsilon); + + // Update the elementwise_add node + eltwise->Op()->SetAttr("axis", 1); + eltwise->Op()->SetOutput("Out", std::vector({bn_out->Name()})); + + GraphSafeRemoveNodes( + graph.get(), + {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, + bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(eltwise, bn_out); + + found_conv_bn_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_bn_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass); +REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass, + paddle::framework::ir::ConvEltwiseAddBNFusePass); diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..2c9eb574fe8e054e0ae221f08f664b91f05d95c9 --- /dev/null +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h @@ -0,0 +1,49 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp. + */ +class ConvBNFusePass : public FusePassBase { + public: + virtual ~ConvBNFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_bn_fuse"}; +}; + +class ConvEltwiseAddBNFusePass : public FusePassBase { + public: + virtual ~ConvEltwiseAddBNFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_eltwiseadd_bn_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 46c6a52c09e896596aa6d8e1e901955a68a4957d..8625b562e7dfab5a65692863cdc22b62ce15d758 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -626,6 +626,112 @@ bool VarLinksFromOp(Node *node, const std::string &op_type) { return false; } +PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input, + bool with_eltwise_add) { + // Create Operators + conv_input->assert_is_op_input("conv2d", "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + + PDNode *eltwise_op = nullptr; + if (with_eltwise_add) { + eltwise_op = + pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); + } + auto *batch_norm_op = + pattern->NewNode(batch_norm_repr())->assert_is_op("batch_norm"); + // Create variables + // Conv Filter + auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Filter"); + + auto *conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("conv2d"); + + PDNode *eltwise_y_in_var = nullptr; + PDNode *eltwise_out_var = nullptr; + if (with_eltwise_add) { + // Conv output as Bias input + conv_out_var->assert_is_op_input("elementwise_add", "X"); + // Bias + eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->AsInput(); + eltwise_out_var = pattern->NewNode(eltwise_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("elementwise_add"); + } else { + // Conv output as BN input + conv_out_var->assert_is_op_input("batch_norm", "X"); + } + + // BN Scale + auto *bn_scale_var = pattern->NewNode(bn_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Scale"); + // BN Bias + auto *bn_bias_var = pattern->NewNode(bn_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Bias"); + // BN Mean + auto *bn_mean_var = pattern->NewNode(bn_mean_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Mean"); + 
// BN Variance + auto *bn_variance_var = pattern->NewNode(bn_variance_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("batch_norm", "Variance"); + + // BN output + auto *bn_out_var = pattern->NewNode(bn_out_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm"); + + auto *bn_mean_out_var = pattern->NewNode(bn_mean_out_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "MeanOut"); + + auto *bn_variance_out_var = + pattern->NewNode(bn_variance_out_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "VarianceOut"); + + auto *bn_saved_mean_var = + pattern->NewNode(bn_saved_mean_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedMean"); + + auto *bn_saved_variance_var = + pattern->NewNode(bn_saved_variance_repr()) + ->AsOutput() + ->assert_is_op_output("batch_norm", "SavedVariance"); + + conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); + + if (with_eltwise_add) { + eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var}) + .LinksTo({eltwise_out_var}); + batch_norm_op + ->LinksFrom({eltwise_out_var, bn_scale_var, bn_bias_var, bn_mean_var, + bn_variance_var}) + .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var, + bn_saved_mean_var, bn_saved_variance_var}); + } else { + batch_norm_op + ->LinksFrom({conv_out_var, bn_scale_var, bn_bias_var, bn_mean_var, + bn_variance_var}) + .LinksTo({bn_out_var, bn_mean_out_var, bn_variance_out_var, + bn_saved_mean_var, bn_saved_variance_var}); + } + return bn_out_var; +} + PDNode *patterns::ConvReLU::operator()( paddle::framework::ir::PDNode *conv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 508113bf4fcab274394f2705c36eddbf4ba3c77a..cdd6413d968b065453177ff78b0aad641a09f6e7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -375,6 +375,44 @@ struct PatternBase { size_t id_; }; +// Conv with batch norm +// op: conv + (elementwise_add +) batch_norm +// named nodes: +// conv_weight, conv_out, conv, +// bn_x, bn_scale, bn_bias, bn_mean, bn_variance, +// bn_batch_norm, bn_y, bn_mean_out, bn_variance_out, +// bn_saved_mean, bn_saved_variance +struct ConvBN : public PatternBase { + ConvBN(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_bn") {} + + PDNode* operator()(PDNode* conv_input, bool with_eltwise_add); + + // declare operator node's name + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(batch_norm); + PATTERN_DECL_NODE(eltwise); // ELEMENTWISE_ADD + // CONV inputs + PATTERN_DECL_NODE(conv_weight); // Filter + // CONV outputs + PATTERN_DECL_NODE(conv_out); // tmp + // ELTWISE inputs + PATTERN_DECL_NODE(eltwise_y_in); + // ELTWISE outputs + PATTERN_DECL_NODE(eltwise_out); // tmp + // BN inputs + PATTERN_DECL_NODE(bn_scale); + PATTERN_DECL_NODE(bn_bias); + PATTERN_DECL_NODE(bn_mean); + PATTERN_DECL_NODE(bn_variance); + // BN outputs + PATTERN_DECL_NODE(bn_out); // Out + PATTERN_DECL_NODE(bn_mean_out); + PATTERN_DECL_NODE(bn_variance_out); + PATTERN_DECL_NODE(bn_saved_mean); + PATTERN_DECL_NODE(bn_saved_variance); +}; + // CONV with ReLU // op: conv + relu // named nodes: diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 53d39513f3686cea59e2d56ff62eec9869f3b2de..ba10687d65cfbbac89cfc76879c8b202ebd03229 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ 
b/paddle/fluid/framework/naive_executor.cc @@ -146,5 +146,22 @@ void NaiveExecutor::CleanFeedFetchOps() { ops_.swap(ops); } +void NaiveExecutor::EnableMKLDNN(const ProgramDesc &program) { +#ifdef PADDLE_WITH_MKLDNN + VLOG(3) << "use_mkldnn=True"; + for (size_t block_id = 0; block_id < program.Size(); ++block_id) { + auto *block = const_cast(program).MutableBlock(block_id); + for (auto *op : block->AllOps()) { + if (op->HasAttr("use_mkldnn")) { + op->SetAttr("use_mkldnn", true); + } + } + } +#else + LOG(WARNING) + << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index 9355e9e36a6358aa91553dca35aaf1b658516a0a..9374f3f4a35cc0f90e5b2d6e8b397784b8eae123 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -14,6 +14,8 @@ #pragma once +#include +#include #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -46,6 +48,8 @@ class NaiveExecutor { void CleanFeedFetchOps(); + void EnableMKLDNN(const ProgramDesc& program); + protected: void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id); diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 17f942571d0141537e992be9ab73847d2a794698..b29ac44699463312a1fdcea55e003daa75997302 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -50,6 +50,27 @@ class CompileTimeInferShapeContext : public InferShapeContext { const std::vector &Outputs( const std::string &name) const override; + void ShareDim(const std::string &in, const std::string &out, size_t i = 0, + size_t j = 0) override { + PADDLE_ENFORCE_LT(i, Inputs(in).size()); + PADDLE_ENFORCE_LT(j, Outputs(out).size()); + const std::string &input_n = Inputs(in)[i]; + const std::string &output_n = Outputs(out)[j]; + + PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@", + in, i); + PADDLE_ENFORCE(output_n != framework::kEmptyVarName, + "The %s[%d] is @EMPTY@", out, j); + + auto *in_var = block_.FindVarRecursive(input_n); + auto *out_var = block_.FindVarRecursive(output_n); + + PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(), + "The type of %s and %s is not the same.", input_n, output_n); + + SetDim(output_n, GetDim(input_n)); + } + void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, size_t j = 0) const override { PADDLE_ENFORCE_LT(i, Inputs(in).size()); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a103be7191d02a96ee97d76f786f9364938c1c65..9f930065324f13f5aa79c214e820fb6fc2f3a166 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -542,13 +542,45 @@ class RuntimeInferShapeContext : public InferShapeContext { return op_.Outputs(name); } - void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const override { + void ShareDim(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) override { PADDLE_ENFORCE_LT(i, Inputs(in).size()); PADDLE_ENFORCE_LT(j, Outputs(out).size()); - Variable* in_var = scope_.FindVar(Inputs(in)[i]); - Variable* out_var = scope_.FindVar(Outputs(out)[j]); + const std::string& input_n = Inputs(in)[i]; + const std::string& output_n = Outputs(out)[j]; + + Variable* in_var = scope_.FindVar(input_n); + 
Variable* out_var = scope_.FindVar(output_n); + PADDLE_ENFORCE(in_var->Type() == out_var->Type(), + "The type of %s and %s is not the same.", output_n, + GetDim(input_n)); + + if (in_var->IsType()) { + auto& in_sele_rows = in_var->Get(); + auto out_sele_rows = out_var->GetMutable(); + out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); + out_sele_rows->set_rows(in_sele_rows.rows()); + out_sele_rows->set_height(in_sele_rows.height()); + } else if (in_var->IsType()) { + auto& in_lod_tensor = in_var->Get(); + auto* out_lod_tensor = out_var->GetMutable(); + out_lod_tensor->Resize(in_lod_tensor.dims()); + } else { + PADDLE_THROW( + "Currently, the input type of ShareDim only can be LoDTensor " + "or SelectedRows."); + } + } + + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + const std::vector& inputs = Inputs(in); + const std::vector& outputs = Outputs(out); + PADDLE_ENFORCE_LT(i, inputs.size()); + PADDLE_ENFORCE_LT(j, outputs.size()); + Variable* in_var = scope_.FindVar(inputs.at(i)); if (!in_var->IsType()) return; + Variable* out_var = scope_.FindVar(outputs.at(j)); PADDLE_ENFORCE(out_var->IsType(), "The %d-th output of Output(%s) must be LoDTensor.", j, out); auto in_tensor = in_var->Get(); @@ -576,20 +608,6 @@ class RuntimeInferShapeContext : public InferShapeContext { out_tensor->set_layout(in_tensor.layout()); } - void ShareLayout(const std::string& in, const std::string& out, size_t i = 0, - size_t j = 0) const { - PADDLE_ENFORCE_LT(i, Inputs(in).size()); - PADDLE_ENFORCE_LT(j, Outputs(out).size()); - Variable* in_var = scope_.FindVar(Inputs(in)[i]); - Variable* out_var = scope_.FindVar(Outputs(out)[j]); - if (!in_var->IsType()) return; - PADDLE_ENFORCE(out_var->IsType(), - "The %d-th output of Output(%s) must be LoDTensor.", j, out); - auto in_tensor = in_var->Get(); - auto* out_tensor = out_var->GetMutable(); - out_tensor->set_layout(in_tensor.layout()); - } - bool IsRuntime() const override { return true; } protected: diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h index da163835e8652ae479121bd67f2eed77332b2740..dbf00f3a79f7d1dcf97b346fccfdb68f119d4aa3 100644 --- a/paddle/fluid/framework/rw_lock.h +++ b/paddle/fluid/framework/rw_lock.h @@ -46,6 +46,7 @@ struct RWLock { private: pthread_rwlock_t lock_; }; +// TODO(paddle-dev): Support RWLock for WIN32 for correctness. #else // https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive // In windows, rw_lock seems like a hack. Use empty object and do nothing. 
diff --git a/paddle/fluid/framework/shape_inference.cc b/paddle/fluid/framework/shape_inference.cc index 89eb00ff65598eff5f4ba541df107e8da04e1a89..ddff2c7c261746ac9986e79cff3da7e0a9654adc 100644 --- a/paddle/fluid/framework/shape_inference.cc +++ b/paddle/fluid/framework/shape_inference.cc @@ -46,16 +46,6 @@ std::vector InferShapeContext::GetReaderDims( return this->GetRepeatedDims(arg_names[0]); } -void InferShapeContext::ShareLoDs(const std::string &in, - const std::string &out) const { - PADDLE_ENFORCE_EQ(Inputs(in).size(), Outputs(out).size(), - "The number of arguments in %s and %s is not equal.", in, - out); - for (size_t i = 0; i < in.size(); ++i) { - ShareLoD(in, out, i, i); - } -} - DDim InferShapeContext::GetInputsElementDim(const std::string &name, int idx) const { const std::vector &names = Inputs(name); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index fd220d961af85dd55fe2031409180823d8f178fc..280bc19dce7b604d67aefdc572de96b479b8d2d7 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -56,7 +56,8 @@ class InferShapeContext { virtual const std::vector &Outputs( const std::string &name) const = 0; - void ShareLoDs(const std::string &in, const std::string &out) const; + virtual void ShareDim(const std::string &in, const std::string &out, + size_t i = 0, size_t j = 0) = 0; virtual void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, size_t j = 0) const = 0; diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 05c4a17a01c6fabe48f3fe18544c13153feb0673..1d7a2eb5b38255531880fe3d2e5321024caf0c6b 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -165,10 +165,12 @@ inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, } template -struct AnyVisitor : public boost::static_visitor { +class AnyVisitor : public boost::static_visitor { + private: const framework::Tensor& tensor_; Predicate predicate_; + public: AnyVisitor(const framework::Tensor& tensor, Predicate predicate) : tensor_(tensor), predicate_(std::move(predicate)) {} @@ -206,6 +208,27 @@ struct AnyVisitor : public boost::static_visitor { } }; +template +class AnyOutVisitor : public boost::static_visitor<> { + private: + const framework::Tensor& tensor_; + mutable framework::Tensor* out_; + Predicate predicate_; + + public: + AnyOutVisitor(const framework::Tensor& tensor, Predicate predicate, + framework::Tensor* out) + : tensor_(tensor), out_(out), predicate_(std::move(predicate)) {} + + template + void operator()(const Place& place) const { + auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + out_->Resize({1}); + out_->mutable_data(place); + AnyImpl(predicate_, tensor_, *ctx, out_); + } +}; + template inline bool Any(const framework::Tensor& tensor, Predicate predicate) { AnyVisitor visitor(tensor, predicate); @@ -213,6 +236,14 @@ inline bool Any(const framework::Tensor& tensor, Predicate predicate) { return platform::VisitPlace(place, visitor); } +template +inline void Any(const framework::Tensor& tensor, Predicate predicate, + framework::Tensor* out) { + AnyOutVisitor visitor(tensor, predicate, out); + auto place = tensor.place(); + platform::VisitPlace(place, visitor); +} + struct ContainsNANPredicate { template auto operator()(const T& eigen_vec) const @@ -227,6 +258,12 @@ bool TensorContainsNAN(const framework::Tensor& tensor) { return Any(tensor, predicate); } +void 
TensorContainsNAN(const framework::Tensor& tensor, + framework::Tensor* out) { + ContainsNANPredicate predicate; + Any(tensor, predicate, out); +} + struct ContainsInfPredicate { template auto operator()(const T& eigen_vec) const @@ -241,6 +278,71 @@ bool TensorContainsInf(const framework::Tensor& tensor) { return Any(tensor, predicate); } +void TensorContainsInf(const framework::Tensor& tensor, + framework::Tensor* out) { + ContainsInfPredicate predicate; + Any(tensor, predicate, out); +} + +// NOTE(dzhwinter): +// Isfinite need a AllVisitor to loop through all the elements. +// We choose two cuda call instead of one allvisitor. The AllVisitor +// should be implemented if the performance hurts. +bool TensorIsfinite(const framework::Tensor& tensor) { + ContainsInfPredicate pred_inf; + ContainsNANPredicate pred_nan; + return !Any(tensor, pred_inf) && !Any(tensor, pred_nan); +} + +#ifdef PADDLE_WITH_CUDA +template +static inline void __global__ BothFalse(const T* cmp, T* out) { + out[0] = (!cmp[0]) && (!out[0]); +} +#endif + +struct BothFalseVisitor : public boost::static_visitor<> { + const framework::Tensor& in_; + mutable framework::Tensor* out_; + BothFalseVisitor(const framework::Tensor& in, framework::Tensor* out) + : in_(in), out_(out) {} + + template + void operator()(const Place& place) const { + VisitorImpl(place); + } + + void VisitorImpl(const platform::CUDAPlace& gpu) const { +#ifdef PADDLE_WITH_CUDA + auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(gpu); + BothFalse<<<1, 1, 0, ctx->stream()>>>(in_.data(), + out_->mutable_data(gpu)); +#endif + } + + void VisitorImpl(const platform::CPUPlace& cpu) const { + bool lhs = !in_.data()[0]; + bool rhs = !out_->mutable_data(cpu)[0]; + out_->mutable_data(cpu)[0] = lhs && rhs; + } + + void VisitorImpl( + const platform::CUDAPinnedPlace& cpu /* equals to cpu*/) const { + bool lhs = !in_.data()[0]; + bool rhs = !out_->mutable_data(cpu)[0]; + out_->mutable_data(cpu)[0] = lhs && rhs; + } +}; + +void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) { + framework::Tensor tmp; + TensorContainsInf(tensor, &tmp); + TensorContainsNAN(tensor, out); + BothFalseVisitor visitor(tmp, out); + auto place = tensor.place(); + platform::VisitPlace(place, visitor); +} + void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { { // the 1st field, uint32_t version diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h index 4457382ade37a12f5f3613fc4113fbf1f6f91124..cab6d9b67e4e64335be0a386bfffb7ebe4373b3e 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -57,8 +57,15 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, template void TesnorToVector(const Tensor& src, std::vector* dst); +// copy the result bool to cpu bool TensorContainsNAN(const framework::Tensor& tensor); bool TensorContainsInf(const framework::Tensor& tensor); +bool TensorIsfinite(const framework::Tensor& tensor); + +// store the result bool in gpu tensor, async operation. Faster than above ones. 
+void TensorContainsNAN(const framework::Tensor& tensor, framework::Tensor* out); +void TensorContainsInf(const framework::Tensor& tensor, framework::Tensor* out); +void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out); void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx); diff --git a/paddle/fluid/framework/tensor_util_test.cc b/paddle/fluid/framework/tensor_util_test.cc index 6e10885890cd2d4a0d77834944b37e291197b637..a1e5b967a86d10f3439db662af54bb82888027b9 100644 --- a/paddle/fluid/framework/tensor_util_test.cc +++ b/paddle/fluid/framework/tensor_util_test.cc @@ -36,7 +36,7 @@ TEST(TensorCopy, Tensor) { TensorCopy(src_tensor, *cpu_place, &dst_tensor); const int* dst_ptr = dst_tensor.data(); - ASSERT_NE(src_ptr, dst_ptr); + EXPECT_NE(src_ptr, dst_ptr); for (size_t i = 0; i < 9; ++i) { EXPECT_EQ(src_ptr[i], dst_ptr[i]); } @@ -47,7 +47,7 @@ TEST(TensorCopy, Tensor) { TensorCopy(slice_tensor, *cpu_place, &dst_tensor); const int* slice_ptr = slice_tensor.data(); dst_ptr = dst_tensor.data(); - ASSERT_NE(dst_ptr, slice_ptr); + EXPECT_NE(dst_ptr, slice_ptr); for (size_t i = 0; i < 3; ++i) { EXPECT_EQ(dst_ptr[i], slice_ptr[i]); } @@ -77,7 +77,7 @@ TEST(TensorCopy, Tensor) { // Sync before Compare Tensors gpu_ctx.Wait(); const int* dst_ptr = dst_tensor.data(); - ASSERT_NE(src_ptr, dst_ptr); + EXPECT_NE(src_ptr, dst_ptr); for (size_t i = 0; i < 9; ++i) { EXPECT_EQ(src_ptr[i], dst_ptr[i]); } @@ -94,7 +94,7 @@ TEST(TensorCopy, Tensor) { gpu_ctx.Wait(); const int* slice_ptr = slice_tensor.data(); dst_ptr = dst_tensor.data(); - ASSERT_NE(dst_ptr, slice_ptr); + EXPECT_NE(dst_ptr, slice_ptr); for (size_t i = 0; i < 3; ++i) { EXPECT_EQ(dst_ptr[i], slice_ptr[i]); } @@ -117,7 +117,7 @@ TEST(TensorFromVector, Tensor) { // Compare Tensors const int* cpu_ptr = cpu_tensor.data(); const int* src_ptr = src_vec.data(); - ASSERT_NE(src_ptr, cpu_ptr); + EXPECT_NE(src_ptr, cpu_ptr); for (size_t i = 0; i < 9; ++i) { EXPECT_EQ(src_ptr[i], cpu_ptr[i]); } @@ -127,7 +127,7 @@ TEST(TensorFromVector, Tensor) { paddle::framework::TensorFromVector(src_vec, &cpu_tensor); cpu_ptr = cpu_tensor.data(); src_ptr = src_vec.data(); - ASSERT_NE(src_ptr, cpu_ptr); + EXPECT_NE(src_ptr, cpu_ptr); for (size_t i = 0; i < 5; ++i) { EXPECT_EQ(src_ptr[i], cpu_ptr[i]); } @@ -161,8 +161,8 @@ TEST(TensorFromVector, Tensor) { const int* src_ptr = src_vec.data(); const int* cpu_ptr = cpu_tensor.data(); const int* dst_ptr = dst_tensor.data(); - ASSERT_NE(src_ptr, cpu_ptr); - ASSERT_NE(src_ptr, dst_ptr); + EXPECT_NE(src_ptr, cpu_ptr); + EXPECT_NE(src_ptr, dst_ptr); for (size_t i = 0; i < 9; ++i) { EXPECT_EQ(src_ptr[i], cpu_ptr[i]); EXPECT_EQ(src_ptr[i], dst_ptr[i]); @@ -181,8 +181,8 @@ TEST(TensorFromVector, Tensor) { src_ptr = src_vec.data(); cpu_ptr = cpu_tensor.data(); dst_ptr = dst_tensor.data(); - ASSERT_NE(src_ptr, cpu_ptr); - ASSERT_NE(src_ptr, dst_ptr); + EXPECT_NE(src_ptr, cpu_ptr); + EXPECT_NE(src_ptr, dst_ptr); for (size_t i = 0; i < 5; ++i) { EXPECT_EQ(src_ptr[i], cpu_ptr[i]); EXPECT_EQ(src_ptr[i], dst_ptr[i]); @@ -235,9 +235,9 @@ TEST(TensorContainsNAN, CPU) { buf[0] = 0.0; buf[1] = NAN; buf[2] = 0.0; - ASSERT_TRUE(paddle::framework::TensorContainsNAN(src)); + EXPECT_TRUE(paddle::framework::TensorContainsNAN(src)); buf[1] = 0.0; - ASSERT_FALSE(paddle::framework::TensorContainsNAN(src)); + EXPECT_FALSE(paddle::framework::TensorContainsNAN(src)); } { @@ -248,9 +248,9 @@ TEST(TensorContainsNAN, CPU) { buf[0] = 0.0; buf[1].x = 0x7fff; buf[2] = 0.0; - 
ASSERT_TRUE(paddle::framework::TensorContainsNAN(src)); + EXPECT_TRUE(paddle::framework::TensorContainsNAN(src)); buf[1] = 0.0; - ASSERT_FALSE(paddle::framework::TensorContainsNAN(src)); + EXPECT_FALSE(paddle::framework::TensorContainsNAN(src)); } } @@ -261,9 +261,9 @@ TEST(TensorContainsInf, CPU) { buf[0] = 1.0; buf[1] = INFINITY; buf[2] = 0.0; - ASSERT_TRUE(paddle::framework::TensorContainsInf(src)); + EXPECT_TRUE(paddle::framework::TensorContainsInf(src)); buf[1] = 1.0; - ASSERT_FALSE(paddle::framework::TensorContainsInf(src)); + EXPECT_FALSE(paddle::framework::TensorContainsInf(src)); } { @@ -274,9 +274,55 @@ TEST(TensorContainsInf, CPU) { buf[0] = 1.0; buf[1].x = 0x7c00; buf[2] = 0.0; - ASSERT_TRUE(paddle::framework::TensorContainsInf(src)); + EXPECT_TRUE(paddle::framework::TensorContainsInf(src)); buf[1] = 1.0; - ASSERT_FALSE(paddle::framework::TensorContainsInf(src)); + EXPECT_FALSE(paddle::framework::TensorContainsInf(src)); + } +} + +TEST(TensorIsfinite, CPU) { + { + paddle::framework::Tensor src, out; + double* buf = src.mutable_data({3}, paddle::platform::CPUPlace()); + buf[0] = 1.0; + buf[1] = INFINITY; + buf[2] = 0.0; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], false); + buf[1] = 1.0; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], true); + } + + { + paddle::framework::Tensor src, out; + double* buf = src.mutable_data({3}, paddle::platform::CPUPlace()); + buf[0] = 1.0; + buf[1] = NAN; + buf[2] = 0.0; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], false); + buf[1] = 1.0; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], true); + } + + { + paddle::framework::Tensor src, out; + paddle::platform::float16* buf = + src.mutable_data( + {3}, paddle::platform::CPUPlace()); + buf[0] = 1.0; + buf[1].x = 0x7c00; + buf[2] = 0.0; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], false); + buf[1] = 1.0; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], true); + buf[1].x = 0x7fff; + paddle::framework::TensorIsfinite(src, &out); + EXPECT_EQ(out.data()[0], false); } } @@ -299,9 +345,9 @@ TEST(Tensor, FromAndToStream) { TensorFromStream(iss, &dst_tensor, cpu_ctx); int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < 5; ++i) { - ASSERT_EQ(dst_ptr[i], array[i]); + EXPECT_EQ(dst_ptr[i], array[i]); } - ASSERT_EQ(dst_tensor.dims(), src_tensor.dims()); + EXPECT_EQ(dst_tensor.dims(), src_tensor.dims()); delete place; } #ifdef PADDLE_WITH_CUDA @@ -323,7 +369,7 @@ TEST(Tensor, FromAndToStream) { int* dst_ptr = dst_tensor.mutable_data(platform::CPUPlace()); for (int i = 0; i < 6; ++i) { - ASSERT_EQ(dst_ptr[i], array[i]); + EXPECT_EQ(dst_ptr[i], array[i]); } delete gpu_place; } diff --git a/paddle/fluid/framework/tensor_util_test.cu b/paddle/fluid/framework/tensor_util_test.cu index b4cff1e6c2293fa44f0fd0bb398a538c08dd4fb1..a51f74199e714b8606c9766c57bc6b1dc4c73c65 100644 --- a/paddle/fluid/framework/tensor_util_test.cu +++ b/paddle/fluid/framework/tensor_util_test.cu @@ -27,9 +27,9 @@ static __global__ void FillNAN(float* buf) { } static __global__ void FillInf(float* buf) { - buf[0] = 0.0; - buf[1] = INFINITY; - buf[2] = 0.5; + buf[0] = INFINITY; + buf[1] = 0.1; + buf[2] = 0.2; } static __global__ void FillNAN(platform::float16* buf) { @@ -44,6 +44,18 @@ static __global__ void FillInf(platform::float16* buf) { buf[2] = 0.5; } +static __global__ void FillFinite(float* buf) { + buf[0] = 0.0; + buf[1] = 0.1; + buf[2] = 
0.2; +} + +static __global__ void FillFinite(platform::float16* buf) { + buf[0] = 0.0; + buf[1] = 0.1; + buf[2] = 0.2; +} + TEST(TensorContainsNAN, GPU) { paddle::platform::CUDAPlace gpu(0); auto& pool = paddle::platform::DeviceContextPool::Instance(); @@ -86,5 +98,163 @@ TEST(TensorContainsInf, GPU) { } } +TEST(TensorIsfinite, GPU) { + paddle::platform::CUDAPlace gpu(0); + using paddle::platform::float16; + auto& pool = paddle::platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + // contains inf + { + Tensor tensor; + float* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + EXPECT_TRUE(!TensorIsfinite(tensor)); + } + { + Tensor tensor; + float16* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + EXPECT_TRUE(!TensorIsfinite(tensor)); + } + + // contains nan + { + Tensor tensor; + float* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + EXPECT_TRUE(!TensorIsfinite(tensor)); + } + { + Tensor tensor; + float16* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + EXPECT_TRUE(!TensorIsfinite(tensor)); + } + + // all element are finite + { + Tensor tensor; + float* buf = tensor.mutable_data({3}, gpu); + FillFinite<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + EXPECT_TRUE(TensorIsfinite(tensor)); + } + { + Tensor tensor; + float16* buf = tensor.mutable_data({3}, gpu); + FillFinite<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + EXPECT_TRUE(TensorIsfinite(tensor)); + } +} + +TEST(TensorContainsInf, GPUWithoutWait) { + paddle::platform::CUDAPlace gpu(0); + auto& pool = paddle::platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + { + Tensor tensor, out; + float* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorContainsInf(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + ASSERT_EQ(tmp.data()[0], true); + } + { + Tensor tensor, out; + paddle::platform::float16* buf = + tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorContainsInf(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + ASSERT_EQ(tmp.data()[0], true); + } +} + +TEST(TensorContainsNAN, GPUWithoutWait) { + paddle::platform::CUDAPlace gpu(0); + auto& pool = paddle::platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + { + Tensor tensor, out; + float* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorContainsNAN(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + ASSERT_EQ(tmp.data()[0], true); + } + { + Tensor tensor, out; + paddle::platform::float16* buf = + tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorContainsNAN(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + ASSERT_EQ(tmp.data()[0], true); + } +} + +TEST(TensorIsfinite, GPUWithoutWait) { + paddle::platform::CUDAPlace gpu(0); + auto& pool = paddle::platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + { + 
Tensor tensor, out; + float* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorIsfinite(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + EXPECT_EQ(tmp.data()[0], false); + } + { + Tensor tensor, out; + float* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorIsfinite(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + EXPECT_EQ(tmp.data()[0], false); + } + { + Tensor tensor, out; + float* buf = tensor.mutable_data({3}, gpu); + FillFinite<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + TensorIsfinite(tensor, &out); + platform::CPUPlace cpu; + Tensor tmp; + TensorCopy(out, cpu, *cuda_ctx, &tmp); + cuda_ctx->Wait(); + EXPECT_EQ(tmp.data()[0], true); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index ec1bc7825dd21628f5c37ea44a154abe7b7e8c73..9794a193bcfaae19552b1f6fbdf2dab2898033d5 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -19,9 +19,19 @@ cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api) add_subdirectory(api) +set(STATIC_INFERENCE_APIS paddle_fluid_api paddle_inference_api analysis_predictor) +set(SHARED_INFERENCE_SRCS + io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc + ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc) +if (WITH_GPU AND TENSORRT_FOUND) + set(STATIC_INFERENCE_APIS ${STATIC_INFERENCE_APIS} paddle_inference_tensorrt_subgraph_engine) + set(SHARED_INFERENCE_SRCS ${SHARED_INFERENCE_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/api/api_tensorrt_subgraph_engine.cc) +endif() + # Create static library -cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api - analysis_predictor zero_copy_tensor) +cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor) + if(NOT APPLE) # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. 
set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym") @@ -29,10 +39,7 @@ if(NOT APPLE) endif() # Create shared library -cc_library(paddle_fluid_shared SHARED - SRCS io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc - ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc - ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc +cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} DEPS ${fluid_modules} paddle_fluid_api) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index c740ea009f6cfc2ea250d8f1abdd7d442c2a0bb0..d4d2fd4634f9e11f3f002e11e177c332ced49885 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -20,8 +20,6 @@ cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid) -set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) - function (inference_analysis_test TARGET) if(WITH_TESTING) set(options "") diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 0aa9367bf5692e53e2a1f1247523cf9a4f0b3a1d..765145cb7da44ca13c5394ad1dc2e879e69d69d1 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -64,15 +64,17 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // - "attention_lstm_fuse_pass", // - "embedding_fc_lstm_fuse_pass", // - "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "fc_fuse_pass", // + "infer_clean_graph_pass", // + "attention_lstm_fuse_pass", // + "embedding_fc_lstm_fuse_pass", // + "fc_lstm_fuse_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "fc_fuse_pass", // + "conv_bn_fuse_pass", // + "conv_eltwiseadd_bn_fuse_pass", // #ifdef PADDLE_WITH_MKLDNN "conv_relu_mkldnn_fuse_pass", // #endif diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 32d58b87413c95908644ffba31bbec22d8e23201..0ddd5d53f836131fe37d412fc867cb38f11ee2b5 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -31,7 +31,6 @@ function(inference_api_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) cc_test(${TARGET_NAME} SRCS ${inference_test_SRC} DEPS "${inference_deps}" diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a153433d29b6fef7abdbf7b7b446bad40c1d71e6..3bc6af5241c41bd805699121d614d431d46d863f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -71,6 +71,11 @@ bool AnalysisPredictor::Init( } else { inference_program_ = program; } + + if (config_._use_mkldnn) { + executor_->EnableMKLDNN(*inference_program_); + } + executor_->Prepare(scope_.get(), *inference_program_, 0, config_.use_feed_fetch_ops); @@ -92,6 +97,7 
@@ bool AnalysisPredictor::Run(const std::vector &inputs, LOG(ERROR) << "fail to set feed"; return false; } + // Run the inference program // if share variables, we need not create variables executor_->Run(); diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index d4e6bb3e4a4ceb361ccd35121d0ecf84a764243e..ec8471ef960a2fc44af23c52be09cd678fab3f70 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -3,6 +3,7 @@ project(cpp_inference_demo CXX C) option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON) option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF) option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON) +option(USE_TENSORRT "Compile demo with TensorRT." OFF) macro(safe_set_static_flag) foreach(flag_var @@ -60,6 +61,13 @@ endif(NOT WIN32) include_directories("${PADDLE_LIB}/third_party/boost") include_directories("${PADDLE_LIB}/third_party/eigen3") +if (NOT WIN32) + if (USE_TENSORRT AND WITH_GPU) + include_directories("${TENSORRT_INCLUDE_DIR}") + link_directories("${TENSORRT_LIB_DIR}") + endif() +endif(NOT WIN32) + if (NOT WIN32) link_directories("${PADDLE_LIB}/third_party/install/snappy/lib") link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib") @@ -112,6 +120,10 @@ endif(NOT WIN32) if(WITH_GPU) if(NOT WIN32) + if (USE_TENSORRT) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX}) + endif() set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 44335a872f2e00b34e29a9e7601cb390a460362c..65c95f0834a9356fc14faed8342f5d1e474edf8f 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -3,6 +3,9 @@ PADDLE_ROOT=$1 TURN_ON_MKL=$2 # use MKL or Openblas TEST_GPU_CPU=$3 # test both GPU/CPU mode or only CPU mode DATA_DIR=$4 # dataset +TENSORRT_INCLUDE_DIR=$5 # TensorRT header file dir, defalut to /usr/local/TensorRT/include +TENSORRT_LIB_DIR=$6 # TensorRT lib file dir, default to /usr/local/TensorRT/lib + cd `dirname $0` current_dir=`pwd` if [ $2 == ON ]; then @@ -16,6 +19,11 @@ else use_gpu_list='false' fi +USE_TENSORRT=OFF +if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then + USE_TENSORRT=ON +fi + PREFIX=inference-vis-demos%2F URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX} @@ -86,5 +94,23 @@ for WITH_STATIC_LIB in ON OFF; do fi done done + + # --------tensorrt mobilenet------ + if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then + rm -rf * + cmake .. 
-DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \ + -DWITH_MKL=$TURN_ON_MKL \ + -DDEMO_NAME=trt_mobilenet_demo \ + -DWITH_GPU=$TEST_GPU_CPU \ + -DWITH_STATIC_LIB=$WITH_STATIC_LIB \ + -DUSE_TENSORRT=$USE_TENSORRT \ + -DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \ + -DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR + make -j + ./trt_mobilenet_demo \ + --modeldir=$DATA_DIR/mobilenet/model \ + --data=$DATA_DIR/mobilenet/data.txt \ + --refer=$DATA_DIR/mobilenet/result.txt + fi done set +x diff --git a/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc b/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc index 360f924810a570422db5a00b13939813fa73e2fa..8058d7e881025b1d806efe187d4523adadff367d 100644 --- a/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc +++ b/paddle/fluid/inference/api/demo_ci/simple_on_word2vec.cc @@ -22,8 +22,8 @@ limitations under the License. */ #include #include #include //NOLINT + #include "paddle/fluid/inference/paddle_inference_api.h" -#include "paddle/fluid/platform/enforce.h" DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_bool(use_gpu, false, "Whether use gpu."); @@ -62,17 +62,17 @@ void Main(bool use_gpu) { CHECK(predictor->Run(slots, &outputs)); //# 4. Get output. - PADDLE_ENFORCE(outputs.size(), 1UL); + CHECK_EQ(outputs.size(), 1UL); // Check the output buffer size and result of each tid. - PADDLE_ENFORCE(outputs.front().data.length(), 33168UL); + CHECK_EQ(outputs.front().data.length(), 33168UL); float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815, 0.000932706}; const size_t num_elements = outputs.front().data.length() / sizeof(float); // The outputs' buffers are in CPU memory. for (size_t i = 0; i < std::min(static_cast(5), num_elements); i++) { - PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], - result[i]); + CHECK_NEAR(static_cast(outputs.front().data.data())[i], result[i], + 0.001); } } } @@ -108,9 +108,9 @@ void MainThreads(int num_threads, bool use_gpu) { CHECK(predictor->Run(inputs, &outputs)); // 4. Get output. - PADDLE_ENFORCE(outputs.size(), 1UL); + CHECK_EQ(outputs.size(), 1UL); // Check the output buffer size and result of each tid. - PADDLE_ENFORCE(outputs.front().data.length(), 33168UL); + CHECK_EQ(outputs.front().data.length(), 33168UL); float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815, 0.000932706}; const size_t num_elements = @@ -118,8 +118,8 @@ void MainThreads(int num_threads, bool use_gpu) { // The outputs' buffers are in CPU memory. for (size_t i = 0; i < std::min(static_cast(5), num_elements); i++) { - PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], - result[i]); + CHECK_NEAR(static_cast(outputs.front().data.data())[i], + result[i], 0.001); } } }); diff --git a/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc b/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffb12b5871f088f15e43a1b0ff7e2a8b2f5fd079 --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/trt_mobilenet_demo.cc @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
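
A note on the simple_on_word2vec.cc change above: the demo now relies only on glog-style CHECK macros instead of PADDLE_ENFORCE, so it no longer needs paddle/fluid/platform/enforce.h, and the floating-point outputs are compared against the reference values with an absolute tolerance (CHECK_NEAR with 0.001) rather than exact equality. A minimal standalone sketch of that tolerance check, assuming nothing beyond the standard library (ExpectNear is an illustrative name, not part of the patch):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>

// Compare two float buffers element-wise with an absolute tolerance,
// mirroring the CHECK_NEAR(value, reference, 0.001) calls in the demo.
bool ExpectNear(const float* got, const float* want, size_t n, float tol) {
  for (size_t i = 0; i < n; ++i) {
    if (std::fabs(got[i] - want[i]) > tol) {
      std::fprintf(stderr, "mismatch at %zu: got %f, want %f\n", i, got[i], want[i]);
      return false;
    }
  }
  return true;
}

int main() {
  const float reference[5] = {0.00129761f, 0.00151112f, 0.000423564f,
                              0.00108815f, 0.000932706f};
  float predicted[5] = {0.0013f, 0.00151f, 0.00042f, 0.00109f, 0.00093f};
  assert(ExpectNear(predicted, reference, 5, 0.001f));  // passes within 1e-3
  return 0;
}
```

Exact float equality is brittle across compilers and math libraries, which is presumably why the demo now tolerates a 1e-3 deviation per element.
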
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file contains demo of mobilenet for tensorrt. + */ + +#include +#include // use glog instead of CHECK to avoid importing other paddle header files. +#include "paddle/fluid/inference/demo_ci/utils.h" + +DECLARE_double(fraction_of_gpu_memory_to_use); +DEFINE_string(modeldir, "", "Directory of the inference model."); +DEFINE_string(refer, "", "path to reference result for comparison."); +DEFINE_string( + data, "", + "path of data; each line is a record, format is " + "'\t predictor; + paddle::contrib::MixedRTConfig config; + config.param_file = FLAGS_modeldir + "/__params__"; + config.prog_file = FLAGS_modeldir + "/__model__"; + config.use_gpu = true; + config.device = 0; + config.max_batch_size = 1; + config.fraction_of_gpu_memory = 0.1; // set by yourself + predictor = CreatePaddlePredictor(config); + + VLOG(3) << "begin to process data"; + // Just a single batch of data. + std::string line; + std::ifstream file(FLAGS_data); + std::getline(file, line); + auto record = ProcessALine(line); + file.close(); + + // Inference. + PaddleTensor input; + input.shape = record.shape; + input.data = + PaddleBuf(record.data.data(), record.data.size() * sizeof(float)); + input.dtype = PaddleDType::FLOAT32; + + VLOG(3) << "run executor"; + std::vector output; + predictor->Run({input}, &output, 1); + + VLOG(3) << "output.size " << output.size(); + auto& tensor = output.front(); + VLOG(3) << "output: " << SummaryTensor(tensor); + + // compare with reference result + CheckOutput(FLAGS_refer, tensor); +} + +} // namespace demo +} // namespace paddle + +int main(int argc, char** argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + paddle::demo::Main(); + return 0; +} diff --git a/paddle/fluid/inference/api/demo_ci/utils.h b/paddle/fluid/inference/api/demo_ci/utils.h index cb8990671162dff47228736e69617229528cc093..4792c97fe7d0a3f9c904774ad4a8e580cefcf237 100644 --- a/paddle/fluid/inference/api/demo_ci/utils.h +++ b/paddle/fluid/inference/api/demo_ci/utils.h @@ -14,6 +14,8 @@ #pragma once #include +#include +#include #include #include #include "paddle/fluid/inference/paddle_inference_api.h" @@ -21,6 +23,11 @@ namespace paddle { namespace demo { +struct Record { + std::vector data; + std::vector shape; +}; + static void split(const std::string& str, char sep, std::vector* pieces) { pieces->clear(); @@ -39,6 +46,58 @@ static void split(const std::string& str, char sep, } } +Record ProcessALine(const std::string& line) { + VLOG(3) << "process a line"; + std::vector columns; + split(line, '\t', &columns); + CHECK_EQ(columns.size(), 2UL) + << "data format error, should be \t"; + + Record record; + std::vector data_strs; + split(columns[0], ' ', &data_strs); + for (auto& d : data_strs) { + record.data.push_back(std::stof(d)); + } + + std::vector shape_strs; + split(columns[1], ' ', &shape_strs); + for (auto& s : shape_strs) { + record.shape.push_back(std::stoi(s)); + } + VLOG(3) << "data size " << record.data.size(); + VLOG(3) << "data shape size " << record.shape.size(); + return record; +} + +void CheckOutput(const std::string& referfile, const PaddleTensor& output) { + std::string line; + 
std::ifstream file(referfile); + std::getline(file, line); + auto refer = ProcessALine(line); + file.close(); + + size_t numel = output.data.length() / PaddleDtypeSize(output.dtype); + VLOG(3) << "predictor output numel " << numel; + VLOG(3) << "reference output numel " << refer.data.size(); + CHECK_EQ(numel, refer.data.size()); + switch (output.dtype) { + case PaddleDType::INT64: { + for (size_t i = 0; i < numel; ++i) { + CHECK_EQ(static_cast(output.data.data())[i], refer.data[i]); + } + break; + } + case PaddleDType::FLOAT32: + for (size_t i = 0; i < numel; ++i) { + CHECK_LT( + fabs(static_cast(output.data.data())[i] - refer.data[i]), + 1e-5); + } + break; + } +} + /* * Get a summary of a PaddleTensor content. */ diff --git a/paddle/fluid/inference/api/demo_ci/vis_demo.cc b/paddle/fluid/inference/api/demo_ci/vis_demo.cc index 3800d49b34738d5a272033d75cb415ae9ad1fb8f..db61786e2fefda29256d84b5357028ec9c39b014 100644 --- a/paddle/fluid/inference/api/demo_ci/vis_demo.cc +++ b/paddle/fluid/inference/api/demo_ci/vis_demo.cc @@ -17,11 +17,8 @@ limitations under the License. */ */ #include -#include // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files. -#include -#include +#include // use glog instead of CHECK to avoid importing other paddle header files. #include "paddle/fluid/inference/demo_ci/utils.h" -#include "paddle/fluid/platform/enforce.h" #ifdef PADDLE_WITH_CUDA DECLARE_double(fraction_of_gpu_memory_to_use); @@ -37,70 +34,11 @@ DEFINE_bool(use_gpu, false, "Whether use gpu."); namespace paddle { namespace demo { -struct Record { - std::vector data; - std::vector shape; -}; - -void split(const std::string& str, char sep, std::vector* pieces); - -Record ProcessALine(const std::string& line) { - VLOG(3) << "process a line"; - std::vector columns; - split(line, '\t', &columns); - CHECK_EQ(columns.size(), 2UL) - << "data format error, should be \t"; - - Record record; - std::vector data_strs; - split(columns[0], ' ', &data_strs); - for (auto& d : data_strs) { - record.data.push_back(std::stof(d)); - } - - std::vector shape_strs; - split(columns[1], ' ', &shape_strs); - for (auto& s : shape_strs) { - record.shape.push_back(std::stoi(s)); - } - VLOG(3) << "data size " << record.data.size(); - VLOG(3) << "data shape size " << record.shape.size(); - return record; -} - -void CheckOutput(const std::string& referfile, const PaddleTensor& output) { - std::string line; - std::ifstream file(referfile); - std::getline(file, line); - auto refer = ProcessALine(line); - file.close(); - - size_t numel = output.data.length() / PaddleDtypeSize(output.dtype); - VLOG(3) << "predictor output numel " << numel; - VLOG(3) << "reference output numel " << refer.data.size(); - PADDLE_ENFORCE_EQ(numel, refer.data.size()); - switch (output.dtype) { - case PaddleDType::INT64: { - for (size_t i = 0; i < numel; ++i) { - PADDLE_ENFORCE_EQ(static_cast(output.data.data())[i], - refer.data[i]); - } - break; - } - case PaddleDType::FLOAT32: - for (size_t i = 0; i < numel; ++i) { - PADDLE_ENFORCE_LT( - fabs(static_cast(output.data.data())[i] - refer.data[i]), - 1e-5); - } - break; - } -} - /* * Use the native fluid engine to inference the demo. 
*/ void Main(bool use_gpu) { + std::unique_ptr predictor; NativeConfig config; config.param_file = FLAGS_modeldir + "/__params__"; config.prog_file = FLAGS_modeldir + "/__model__"; @@ -111,7 +49,7 @@ void Main(bool use_gpu) { } VLOG(3) << "init predictor"; - auto predictor = + predictor = CreatePaddlePredictor(config); VLOG(3) << "begin to process data"; @@ -131,7 +69,7 @@ void Main(bool use_gpu) { VLOG(3) << "run executor"; std::vector output; - predictor->Run({input}, &output); + predictor->Run({input}, &output, 1); VLOG(3) << "output.size " << output.size(); auto& tensor = output.front(); @@ -146,9 +84,10 @@ void Main(bool use_gpu) { int main(int argc, char** argv) { google::ParseCommandLineFlags(&argc, &argv, true); - paddle::demo::Main(false /* use_gpu*/); if (FLAGS_use_gpu) { paddle::demo::Main(true /*use_gpu*/); + } else { + paddle::demo::Main(false /*use_gpu*/); } return 0; } diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 70f9e397c96cf3fe92779778950f3df71b5a67c9..c3dd1f433691e1c96e9f38ef7b595befad26408f 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -70,6 +70,14 @@ if (NOT EXISTS ${OCR_INSTALL_DIR}) endif() inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) +# resnet50 +set(RESNET50_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/resnet50") +if (NOT EXISTS ${RESNET50_INSTALL_DIR}) + inference_download_and_uncompress(${RESNET50_INSTALL_DIR} ${INFERENCE_URL} "resnet50_model.tar.gz") +endif() +inference_analysis_test(test_analyzer_resnet50 SRCS analyzer_resnet50_tester.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${RESNET50_INSTALL_DIR}/model) + # anakin if (WITH_ANAKIN AND WITH_MKL) # only needed in CI # anakin rnn1 diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..290fb007d8ba94a2d121947fe67c6474586ac0e0 --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void SetConfig(AnalysisConfig *cfg) { + cfg->param_file = FLAGS_infer_model + "/params"; + cfg->prog_file = FLAGS_infer_model + "/model"; + cfg->use_gpu = false; + cfg->device = 0; + cfg->enable_ir_optim = true; + cfg->specify_input_name = true; +} + +void SetInput(std::vector> *inputs) { + PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); + + PaddleTensor input; + // channel=3, height/width=318 + std::vector shape({FLAGS_batch_size, 3, 318, 318}); + input.shape = shape; + input.dtype = PaddleDType::FLOAT32; + + // fill input data, for profile easily, do not use random data here. 
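
The SetInput helper that continues below sizes the input buffer by hand from the hard-coded NCHW shape (FLAGS_batch_size x 3 x 318 x 318) and fills it with a deterministic ramp so profiling runs stay reproducible. A small standalone sketch of that sizing step, assuming only the standard library (NumElements is an illustrative helper, not part of the patch):

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Number of elements implied by a dense NCHW shape vector.
int64_t NumElements(const std::vector<int>& shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         [](int64_t a, int d) { return a * d; });
}

int main() {
  const int batch_size = 1;
  std::vector<int> shape = {batch_size, 3, 318, 318};
  int64_t numel = NumElements(shape);   // 1 * 3 * 318 * 318 = 303372
  size_t bytes = numel * sizeof(float); // byte size handed to the buffer
  // A reproducible ramp in [0, 1), matching the "do not use random data" note.
  std::vector<float> data(numel);
  for (int64_t i = 0; i < numel; ++i) data[i] = static_cast<float>(i) / numel;
  return bytes > 0 ? 0 : 1;
}
```

The byte size given to the buffer is simply the element count times sizeof(float), which is what the Resize call below computes inline.
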
+ size_t size = FLAGS_batch_size * 3 * 318 * 318; + input.data.Resize(size * sizeof(float)); + float *input_data = static_cast(input.data.data()); + for (size_t i = 0; i < size; i++) { + *(input_data + i) = static_cast(i) / size; + } + + std::vector input_slots; + input_slots.assign({input}); + (*inputs).emplace_back(input_slots); +} + +// Easy for profiling independently. +TEST(Analyzer_resnet50, profile) { + AnalysisConfig cfg; + SetConfig(&cfg); + std::vector outputs; + + std::vector> input_slots_all; + SetInput(&input_slots_all); + TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads); + + if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { + PADDLE_ENFORCE_EQ(outputs.size(), 1UL); + size_t size = GetSize(outputs[0]); + // output is a 512-dimension feature + EXPECT_EQ(size, 512 * FLAGS_batch_size); + } +} + +// Check the fuse status +TEST(Analyzer_resnet50, fuse_statis) { + AnalysisConfig cfg; + SetConfig(&cfg); + int num_ops; + auto predictor = CreatePaddlePredictor(cfg); + auto fuse_statis = GetFuseStatis( + static_cast(predictor.get()), &num_ops); + ASSERT_TRUE(fuse_statis.count("fc_fuse")); + EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); +} + +// Compare result of NativeConfig and AnalysisConfig +TEST(Analyzer_resnet50, compare) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + SetInput(&input_slots_all); + CompareNativeAndAnalysis(cfg, input_slots_all); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc index 5a68b0b25db4230dfa666f7773f6c278b7ab2455..c76d72ccd99649913aefcb2aa57fe6061db8ca6d 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc @@ -270,10 +270,11 @@ TEST(Analyzer_rnn1, multi_thread) { std::vector> input_slots_all; SetInput(&input_slots_all); - TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads); + TestPrediction(cfg, input_slots_all, &outputs, 4 /* multi_thread */); } -bool CompareTensors(framework::Scope &a_scope, framework::Scope &b_scope, +bool CompareTensors(const framework::Scope &a_scope, + const framework::Scope &b_scope, const std::vector &tensors) { for (auto &x : tensors) { auto *a_var = a_scope.FindVar(x); diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc index a2e86305b85dd893f578e97e0105fec828916fb4..305b8bfe158150d5dfd8bdaee2c0a89afe264de4 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc @@ -61,8 +61,6 @@ void SetConfig(AnalysisConfig *cfg) { cfg->ir_passes.push_back("fc_gru_fuse_pass"); #ifdef PADDLE_WITH_MKLDNN cfg->_use_mkldnn = true; - // disable mkldnn fuse since it should have some bugs - cfg->ir_passes.push_back("conv_relu_mkldnn_fuse_pass"); #endif } diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt index 017fc4cd7b11c150cb941fffca2606a4d707330f..977155440df5294216382cff1c67c2aaca1f546d 100644 --- a/paddle/fluid/inference/tests/book/CMakeLists.txt +++ b/paddle/fluid/inference/tests/book/CMakeLists.txt @@ -4,7 +4,6 @@ function(inference_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(inference_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) 
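
Back to the demo utilities consolidated into utils.h earlier in this patch: ProcessALine expects each input record to be one text line holding the space-separated float data, a tab, and then the space-separated integer shape. A minimal standalone sketch of parsing that format, using illustrative names (SimpleRecord, ParseRecordLine) rather than the patch's Record/ProcessALine:

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// One record line is "<space-separated floats>\t<space-separated dims>".
struct SimpleRecord {
  std::vector<float> data;
  std::vector<int> shape;
};

SimpleRecord ParseRecordLine(const std::string& line) {
  SimpleRecord record;
  const size_t tab = line.find('\t');
  assert(tab != std::string::npos && "expected '<data>\\t<shape>'");
  std::istringstream data_stream(line.substr(0, tab));
  for (float v; data_stream >> v;) record.data.push_back(v);
  std::istringstream shape_stream(line.substr(tab + 1));
  for (int d; shape_stream >> d;) record.shape.push_back(d);
  return record;
}

int main() {
  SimpleRecord r = ParseRecordLine("0.5 1.0 1.5 2.0 2.5 3.0\t1 2 3");
  assert(r.data.size() == 6 && r.shape.size() == 3);
  return 0;
}
```

Keeping the data and the shape on one tab-separated line is what lets both vis_demo and the new TensorRT demo read their single batch with one std::getline.
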
set(arg_list "") if(inference_test_ARGS) foreach(arg ${inference_test_ARGS}) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 2ef13b72ed3ff6ae8ad8748ddea977e693615ac6..031109398d8e21ad95f19f65fdd814e1782889e6 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -230,7 +230,7 @@ if(WITH_DISTRIBUTE) op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) endforeach() - + #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op # listen_and_serv_op sum_op executor SERIAL) @@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() +op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(print_op DEPS lod_tensor) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index c091476d6d132db17a656d5c8dee65e3a88d9ac2..bbf52bea1358c32596ab6f14eeaa419735d19fc6 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -80,7 +80,7 @@ class ActivationOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -91,12 +91,26 @@ class ActivationOp : public framework::OperatorWithKernel { } }; +class ActivationOpInferVarType : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc& op_desc, + framework::BlockDesc* block) const override { + auto x_name = op_desc.Input("X")[0]; + auto out_name = op_desc.Output("Out")[0]; + auto& x = block->FindRecursiveOrCreateVar(x_name); + auto& out = block->FindRecursiveOrCreateVar(out_name); + out.SetType(x.GetType()); + out.SetDataType(x.GetDataType()); + } +}; + class ActivationOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); + ctx->ShareDim("Out", framework::GradVarName("X")); + ctx->ShareLoD("Out", framework::GradVarName("X")); } protected: @@ -525,12 +539,14 @@ namespace ops = paddle::operators; #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ ::paddle::operators::OP_NAME##OpMaker, \ + ::paddle::operators::ActivationOpInferVarType, \ ::paddle::operators::OP_NAME##GradMaker); \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ ::paddle::operators::OP_NAME##OpMaker, \ + ::paddle::operators::ActivationOpInferVarType, \ ::paddle::framework::DefaultGradOpDescMaker); \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index a2f5a2545701991263c1ef842e9275b1edbfd2ca..d25160f4232b5a621d16b9f469f56bd5aa7c88e3 100644 --- 
a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -42,8 +42,8 @@ class ArgsortOp : public framework::OperatorWithKernel { "-rank(Input(X)) (%d).", axis, num_dims); - ctx->SetOutputDim("Out", in_dims); - ctx->SetOutputDim("Indices", in_dims); + ctx->ShareDim("X", "Out"); + ctx->ShareDim("X", "Indices"); ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Indices"); } diff --git a/paddle/fluid/operators/clip_by_norm_op.h b/paddle/fluid/operators/clip_by_norm_op.h index 5af0eb0b2ada66d5ae7d521d80e213f9e61f826f..855c4d70677395992e2bf685c910cbea2d37b20b 100644 --- a/paddle/fluid/operators/clip_by_norm_op.h +++ b/paddle/fluid/operators/clip_by_norm_op.h @@ -16,12 +16,15 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +using SelectedRows = framework::SelectedRows; template using EigenVector = framework::EigenVector; @@ -31,9 +34,40 @@ class ClipByNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max_norm = context.Attr("max_norm"); - auto* input = context.Input("X"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); + auto in_var = context.InputVar("X"); + + Tensor* output = nullptr; + const Tensor* input = nullptr; + if (in_var->IsType()) { + input = context.Input("X"); + + output = context.Output("Out"); + output->mutable_data(context.GetPlace()); + } else if (in_var->IsType()) { + auto* x = context.Input("X"); + + // merge ids in selected rows first + math::scatter::MergeAdd merge_func; + SelectedRows* merged_input = + const_cast(context.scope()) + .Var() + ->GetMutable(); + merge_func(context.template device_context(), *x, + merged_input); + input = &(merged_input->value()); + + SelectedRows* output_selected_rows = context.Output("Out"); + output_selected_rows->set_rows(merged_input->rows()); + output_selected_rows->set_height(merged_input->height()); + output = output_selected_rows->mutable_value(); + output->Resize(merged_input->value().dims()); + output->mutable_data(context.GetPlace()); + } else { + PADDLE_THROW("Unexpected branch, input variable type is %s", + in_var->Type().name()); + } + + PADDLE_ENFORCE_NOT_NULL(input); auto x = EigenVector::Flatten(*input); auto out = EigenVector::Flatten(*output); diff --git a/paddle/fluid/operators/conv_shift_op.cc b/paddle/fluid/operators/conv_shift_op.cc index f2549e814d6f3b5674fe2eec1139f1c3dc6fa0b4..08506ddd18ed35831702814e70962cb36ec958b1 100644 --- a/paddle/fluid/operators/conv_shift_op.cc +++ b/paddle/fluid/operators/conv_shift_op.cc @@ -44,7 +44,7 @@ class ConvShiftOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], "The 2nd dimension of Input(Y) should be less than or " "equal to the 2nd dimension of Input(X)."); - ctx->SetOutputDim("Out", x_dims); + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/fluid/operators/cub_reduce.h b/paddle/fluid/operators/cub_reduce.h index 16fdad775f7befaac04b1ac59a601f04e0ab2bdc..afd3922b8d6537ee16dc5041a838858089adbdb1 100644 --- a/paddle/fluid/operators/cub_reduce.h +++ b/paddle/fluid/operators/cub_reduce.h @@ -22,6 +22,7 @@ #include // NOLINT #include 
"paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" namespace paddle { namespace operators { @@ -293,7 +294,12 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y, } auto x_data = x.data(); auto y_data = y->mutable_data(x.place()); - if (reduce_num == 1) return; + if (reduce_num == 1) { + auto out_dims = y->dims(); + framework::TensorCopy(x, y->place(), y); + y->Resize(out_dims); + return; + } #define CUB_BLOCK_DIM_CASE(block_dim) \ case block_dim: { \ diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index 94df11bee70dec44f19ee9ffff04ca92d5990ee8..7e5975ead64ab39a9c618a33e300c4fce55a5b22 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -41,7 +41,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto y_dim = ctx->GetInputDim("Y"); PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), "Rank of first input must >= rank of second input."); - ctx->SetOutputDim("Out", x_dim); + + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -70,6 +71,7 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference { auto& x = block->FindRecursiveOrCreateVar(x_name); auto& out = block->FindRecursiveOrCreateVar(out_name); out.SetType(x.GetType()); + out.SetDataType(x.GetDataType()); } }; @@ -157,10 +159,12 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { auto x_grad_name = framework::GradVarName("X"); auto y_grad_name = framework::GradVarName("Y"); if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); + ctx->ShareDim("X", /*->*/ x_grad_name); + ctx->ShareLoD("X", /*->*/ x_grad_name); } if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, y_dims); + ctx->ShareDim("Y", /*->*/ y_grad_name); + ctx->ShareLoD("Y", /*->*/ y_grad_name); } } @@ -193,14 +197,15 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { auto x_grad_name = framework::GradVarName("X"); if (ctx->HasOutput(x_grad_name)) { - auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(x_grad_name, out_dims); + ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name); + ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name); } auto y_grad_name = framework::GradVarName("Y"); if (ctx->HasOutput(y_grad_name)) { PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); - auto y_dims = ctx->GetInputDim("Y"); - ctx->SetOutputDim(y_grad_name, y_dims); + + ctx->ShareDim("Y", /*->*/ y_grad_name); + ctx->ShareLoD("Y", /*->*/ y_grad_name); } } }; diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 2008e7027524ffd1f80a6eede015801b8a0b0254..5d6488c67e0db440c8d4609736523643dd666dcc 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -48,7 +48,8 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel { "Input(X) of FakeDequantizeMaxAbsOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of FakeDequantizeMaxAbsOp should not be null."); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc index 0b917a403620e2ffb2cbb4ca7856cce9584e1eef..fdc9cb4888b3468b85abfa0c693ed8ac5b0d450b 100644 --- 
a/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused_embedding_fc_lstm_op.cc @@ -93,11 +93,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ctx->SetOutputDim("Cell", out_dims); ctx->ShareLoD("Ids", "Hidden"); ctx->ShareLoD("Ids", "Cell"); - int xx_width; - if (ctx->Attrs().Get("use_seq")) { - xx_width = wh_dims[1]; - } else { - xx_width = x_dims[1] > wh_dims[1] ? wh_dims[1] : x_dims[1]; + if (!ctx->Attrs().Get("use_seq")) { PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), "Assert only one Output(BatchedInput) of LSTM."); PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), @@ -112,7 +108,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedCell", out_dims); } - ctx->SetOutputDim("XX", {x_dims[0], xx_width}); + ctx->SetOutputDim("XX", {x_dims[0], wh_dims[1]}); ctx->ShareLoD("Ids", "XX"); } @@ -435,8 +431,6 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel { INIT_VEC_FUNC INIT_BASE_INPUT_DATAS - // std::cout << "===> Batch Compute" << std::endl; - auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* reordered_c0 = ctx.Output("ReorderedC0"); auto* batched_input = ctx.Output("BatchedInput"); diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..248c7793560db99c0af06421bf74808422016061 --- /dev/null +++ b/paddle/fluid/operators/isfinite_op.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
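
The new isfinite_op.cc that starts here (together with the .cu and .h files further below) registers three reduction-style operators, isinf, isnan and isfinite, each producing a single-element indicator output with shape {1}. The real kernels delegate to framework::TensorContainsInf, TensorContainsNAN and TensorIsfinite; the following is only a rough plain-C++ sketch of the intended semantics on a raw float buffer, not the Paddle implementation:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>

// Conceptual equivalents of the three checks: each reduces a whole buffer
// to one boolean indicator, which the op stores in a {1}-shaped output.
bool ContainsInf(const float* x, size_t n) {
  for (size_t i = 0; i < n; ++i)
    if (std::isinf(x[i])) return true;
  return false;
}
bool ContainsNan(const float* x, size_t n) {
  for (size_t i = 0; i < n; ++i)
    if (std::isnan(x[i])) return true;
  return false;
}
bool AllFinite(const float* x, size_t n) {
  for (size_t i = 0; i < n; ++i)
    if (!std::isfinite(x[i])) return false;
  return true;
}

int main() {
  float ok[3] = {1.0f, -2.0f, 0.5f};
  float bad[3] = {1.0f, std::numeric_limits<float>::infinity(), 0.5f};
  assert(!ContainsInf(ok, 3) && AllFinite(ok, 3));
  assert(ContainsInf(bad, 3) && !AllFinite(bad, 3));
  return 0;
}
```

As the op comment below puts it, the output is an indicator: non-zero if any element overflows, zero otherwise; isfinite presumably reports the complementary "all elements finite" reading, which the sketch mirrors.
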
+ +#include "paddle/fluid/operators/isfinite_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class OverflowOp : public framework::OperatorWithKernel { + public: + OverflowOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of OverflowOp should not be null."); + + ctx->SetOutputDim("Out", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + int dtype = -1; + auto *x_var = ctx.InputVar("X"); + if (x_var->IsType()) { + dtype = framework::ToDataType(x_var->Get().type()); + } else if (x_var->IsType()) { + dtype = framework::ToDataType( + x_var->Get().value().type()); + } else { + PADDLE_THROW("Cannot find the input data type by all input data"); + } + return framework::OpKernelType(framework::proto::VarType::Type(dtype), + ctx.GetPlace()); + } +}; + +class OverflowOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The input tensors of overflow operator."); + AddOutput("Out", + "(Tensor) 1-dim tensor, contains a bool scalar. The output " + "tensor of overflow operator."); + AddComment(string::Sprintf(R"DOC( +Overflow operator. + +$$Out = any(X)$$ + +If any X contains Inf or Nan, the Out will generate a indicator. +Out = Inf if any X contains Inf, +Out = Nan if any X contains Nan, +Out = 0 if no Inf/Nan detected. +If X contains both Inf/Nan, it will return the first indicator it meeted. +)DOC", + GetName(), GetComments())); + } + + protected: + virtual std::string GetName() const = 0; + virtual std::string GetComments() const = 0; +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +#define REGISTER_OP_MAKER(op_type, comment) \ + namespace paddle { \ + namespace operators { \ + class _##op_type##OverflowOpMaker \ + : public ::paddle::operators::OverflowOpMaker { \ + protected: \ + std::string GetName() const { return #op_type; } \ + std::string GetComments() const { return comment; } \ + }; \ + } \ + } \ + REGISTER_OPERATOR(op_type, ops::OverflowOp, \ + ops::_##op_type##OverflowOpMaker, \ + paddle::framework::EmptyGradOpMaker) + +#define REGISTER_OVERFLOW_CPU_KERNEL(op_type, functor) \ + REGISTER_OP_CPU_KERNEL( \ + op_type, ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel); + +REGISTER_OP_MAKER(isinf, "isinf(X)"); +REGISTER_OP_MAKER(isnan, "isnan(X)"); +REGISTER_OP_MAKER(isfinite, "isfinite(X)"); +FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CPU_KERNEL); diff --git a/paddle/fluid/operators/isfinite_op.cu b/paddle/fluid/operators/isfinite_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8d1268b18c6fec03063051f545075209a6fcde27 --- /dev/null +++ b/paddle/fluid/operators/isfinite_op.cu @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define EIGEN_USE_GPU +#include "paddle/fluid/operators/isfinite_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \ + REGISTER_OP_CUDA_KERNEL( \ + op_type, ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel, \ + ops::OverflowKernel); + +FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CUDA_KERNEL); diff --git a/paddle/fluid/operators/isfinite_op.h b/paddle/fluid/operators/isfinite_op.h new file mode 100644 index 0000000000000000000000000000000000000000..83b080856366ac3332c5856a19b721893bb80eb3 --- /dev/null +++ b/paddle/fluid/operators/isfinite_op.h @@ -0,0 +1,71 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
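
The header that follows (isfinite_op.h) defines one functor per check plus a FOR_EACH_KERNEL_FUNCTOR X-macro, which is what lets the .cc and .cu files above register CPU and CUDA kernels for all three operators with a single line each. A minimal standalone sketch of that X-macro pattern, with illustrative names rather than Paddle's:

```cpp
#include <cstdio>

// X-macro: list every check once, then expand the list with whatever
// per-entry macro a call site supplies -- here, generating one function each.
#define FOR_EACH_CHECK(APPLY) \
  APPLY(isinf);               \
  APPLY(isnan);               \
  APPLY(isfinite)

#define DEFINE_REPORTER(name) \
  void report_##name() { std::printf("registered kernel for %s\n", #name); }

FOR_EACH_CHECK(DEFINE_REPORTER);

int main() {
  // Each expansion above produced a distinct function, analogous to how the
  // real header expands registration macros for isinf/isnan/isfinite kernels.
  report_isinf();
  report_isnan();
  report_isfinite();
  return 0;
}
```

The benefit is that adding a fourth check would only mean extending the single list in the header, not editing every registration site.
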
+ +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { + +struct InfinityFunctor { + void operator()(const framework::Tensor& tensor, framework::Tensor* out) { + framework::TensorContainsInf(tensor, out); + } +}; + +struct NANFunctor { + void operator()(const framework::Tensor& tensor, framework::Tensor* out) { + framework::TensorContainsNAN(tensor, out); + } +}; + +struct IsfiniteFunctor { + void operator()(const framework::Tensor& tensor, framework::Tensor* out) { + framework::TensorIsfinite(tensor, out); + } +}; + +template +class OverflowKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto* x = ctx.InputVar("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + Functor functor; + if (x->IsType()) { + auto* in = ctx.Input("X"); + functor(*in, out); + } else if (x->IsType()) { + auto& in = ctx.Input("X")->value(); + functor(in, out); + } else { + PADDLE_THROW("Unsupported input type."); + } + } +}; + +} // namespace operators +} // namespace paddle + +#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ + __macro(isinf, InfinityFunctor); \ + __macro(isnan, NANFunctor); \ + __macro(isfinite, IsfiniteFunctor); diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index d77b095c5d783a2a9fab87eb8b458117a6a3d225..b9ac54e446811889b647397ae1fbb11c28f46777 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -137,6 +137,7 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference { << " is set to LoDTensor"; block->Var(out_var_name)->SetType(framework::proto::VarType::LOD_TENSOR); } + block->Var(out_var_name)->SetDataType(block->Var("W")->GetDataType()); } }; diff --git a/paddle/fluid/operators/math/depthwise_conv.cu b/paddle/fluid/operators/math/depthwise_conv.cu index 3be389912307f7aac6dda6d1018943eb8f08696d..66d37c3bf31ffa420cc527cb576dcdc5505a0960 100644 --- a/paddle/fluid/operators/math/depthwise_conv.cu +++ b/paddle/fluid/operators/math/depthwise_conv.cu @@ -46,17 +46,20 @@ __forceinline__ __device__ unsigned warp_id() { return ret; } +#define ARG_DEFINE_KernelDepthwiseConv \ + const T *const input_data, const T *const filter_data, const int batch_size, \ + const int output_channels, const int output_height, \ + const int output_width, const int input_channels, \ + const int input_height, const int input_width, \ + const int filter_multiplier, const int filter_height, \ + const int filter_width, const int stride_height, const int stride_width, \ + const int padding_height, const int padding_width, \ + const int dilate_height, const int dilate_width, T *const output_data + // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. 
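
The depthwise convolution kernels below address NCHW tensors through manually flattened offsets; the pattern ((n * C + c) * H + h) * W + w appears in both the forward and the input-gradient paths. A small host-side sketch of the same indexing arithmetic (NchwIndex is an illustrative helper, not part of the patch):

```cpp
#include <cassert>

// Flat offset of element (n, c, h, w) in a dense NCHW tensor of shape
// [N, C, H, W]; this is the same arithmetic the CUDA kernels inline.
inline int NchwIndex(int n, int c, int h, int w, int C, int H, int W) {
  return ((n * C + c) * H + h) * W + w;
}

int main() {
  const int N = 2, C = 3, H = 4, W = 5;
  // Walking w fastest, then h, c, n visits offsets 0, 1, 2, ... in order.
  int expected = 0;
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          assert(NchwIndex(n, c, h, w, C, H, W) == expected++);
  return 0;
}
```

With this layout the width index varies fastest in memory, which is why the kernels map threadIdx.x to the width dimension and keep neighbouring threads on neighbouring elements.
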
template -__device__ __inline__ void KernelDepthwiseConv( - const T* const input_data, const T* const filter_data, const int batch_size, - const int output_channels, const int output_height, const int output_width, - const int input_channels, const int input_height, const int input_width, - const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const output_data) { +__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { const int batch = blockIdx.y; @@ -97,42 +100,105 @@ __device__ __inline__ void KernelDepthwiseConv( } } -template -__global__ void KernelDepthwiseConvSp( - const T* const input_data, const T* const filter_data, const int batch_size, - const int output_channels, const int output_height, const int output_width, - const int input_channels, const int input_height, const int input_width, - const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const output_data) { - if (c_filter_multiplier == 0) - KernelDepthwiseConv(input_data, filter_data, batch_size, output_channels, - output_height, output_width, input_channels, - input_height, input_width, filter_multiplier, - filter_height, filter_width, stride_height, - stride_width, padding_height, padding_width, - dilate_height, dilate_width, output_data); +template +__device__ __inline__ void KernelDepthwiseConvCFilter( + ARG_DEFINE_KernelDepthwiseConv) { + const int kWeghtSize = c_filter * c_filter; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + const T* weight = filter_data + c_out * c_filter * c_filter; + for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i]; - else - KernelDepthwiseConv(input_data, filter_data, batch_size, output_channels, - output_height, output_width, input_channels, - input_height, input_width, c_filter_multiplier, - filter_height, filter_height, c_stride, c_stride, - padding_height, padding_width, dilate_height, - dilate_width, output_data); + for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { + for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + + const int c_in = c_out / filter_multiplier; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + c_filter * dilate_height; + const int w_in_end = w_in_start + c_filter * dilate_width; + + const int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? 
w_in_start : 0; + + for (int h_in = h_in_start, h_f = 0; h_f < c_filter; + h_in += dilate_height, h_f++) { + for (int w_in = w_in_start, w_f = 0; w_f < c_filter; + w_in += dilate_width, w_f++) { + if (h_in >= 0 && h_in < input_height && w_in >= 0 && + w_in < input_width) { + const int offset = in_offset + h_in * input_width + w_in; + value += r_weight[h_f * c_filter + w_f] * input_data[offset]; + } + } + } + int index = + ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + + w_out; + output_data[index] = value; + } + } +} + +template +__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { + if (c_filter_multiplier == 0) { + if (c_filter == -1) + KernelDepthwiseConv( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, output_data); + else + KernelDepthwiseConvCFilter( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + filter_multiplier, filter_height, filter_width, stride_height, + stride_width, padding_height, padding_width, dilate_height, + dilate_width, output_data); + } else { + if (c_filter == -1) + KernelDepthwiseConv(input_data, filter_data, batch_size, + output_channels, output_height, output_width, + input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_height, + c_stride, c_stride, padding_height, padding_width, + dilate_height, dilate_width, output_data); + else + KernelDepthwiseConvCFilter( + input_data, filter_data, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_height, c_stride, c_stride, + padding_height, padding_width, dilate_height, dilate_width, + output_data); + } } // CUDA kernel to compute the depthwise convolution backprop w.r.t input. 
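
The backprop-to-input path below mirrors the forward path above: a CFilter variant caches the c_filter x c_filter weights in a fixed-size register array, and the check_case dispatch further down tries the specialized instantiations (filter size 3 or 5, known multiplier and stride) before falling back to the fully generic kernel selected by c_filter == -1. A minimal CPU-side sketch of that compile-time dispatch idea, with illustrative names:

```cpp
#include <cstdio>

// Worker specialized on the filter size: kFilter > 0 means the size is a
// compile-time constant (so a local weight cache can be a fixed array),
// kFilter == -1 means "generic, read the size at run time".
template <int kFilter>
void DepthwiseWorker(int runtime_filter) {
  if (kFilter == -1) {
    std::printf("generic kernel, filter=%d\n", runtime_filter);
  } else {
    std::printf("specialized kernel, filter=%d\n", kFilter);
  }
}

// Host-side dispatch mirroring the check_case(multiplier, stride, filter)
// macro: try the specialized instantiations first, then fall back to -1.
void Dispatch(int filter) {
  switch (filter) {
    case 3:
      DepthwiseWorker<3>(filter);
      break;
    case 5:
      DepthwiseWorker<5>(filter);
      break;
    default:
      DepthwiseWorker<-1>(filter);
      break;
  }
}

int main() {
  Dispatch(3);  // hits the 3x3 specialization
  Dispatch(5);  // hits the 5x5 specialization
  Dispatch(7);  // falls back to the generic kernel
  return 0;
}
```

The point of the -1 fallback is that unusual filter sizes still work; they simply take the slower generic path instead of failing the dispatch.
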
+#define ARG_DEFINE_KernelDepthwiseConvInputGrad \ + const T *const output_grad_data, const T *const filter_data, \ + const int batch_size, const int output_channels, \ + const int output_height, const int output_width, \ + const int input_channels, const int input_height, const int input_width, \ + const int filter_multiplier, const int filter_height, \ + const int filter_width, const int stride_height, const int stride_width, \ + const int padding_height, const int padding_width, \ + const int dilate_height, const int dilate_width, \ + T *const input_grad_data + template __device__ __inline__ void KernelDepthwiseConvInputGrad( - const T* const output_grad_data, const T* const filter_data, - const int batch_size, const int output_channels, const int output_height, - const int output_width, const int input_channels, const int input_height, - const int input_width, const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const input_grad_data) { + ARG_DEFINE_KernelDepthwiseConvInputGrad) { for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { const int batch = blockIdx.y; @@ -184,15 +250,67 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad( } } -template +template +__device__ __inline__ void KernelDepthwiseConvInputGradCFilter( + ARG_DEFINE_KernelDepthwiseConvInputGrad) { + const int kWeghtSize = c_filter * c_filter * c_filter_multiplier + 1; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + for (int c_i = 0; c_i < filter_multiplier; c_i++) { + int c_out = c_in * filter_multiplier + c_i; + const T* weight = filter_data + c_out * c_filter * c_filter; + for (int i = 0; i < c_filter * c_filter; i++) + r_weight[i + c_i * c_filter * c_filter] = + weight[c_filter * c_filter - i - 1]; + } + + for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { + for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height; + + int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width; + + T value = 0; + + for (int c_i = 0; c_i < filter_multiplier; c_i++) { + int c_out = c_in * filter_multiplier + c_i; + for (int h_out = h_out_start, h_f = 0; h_f < c_filter; + h_out += dilate_height, h_f++) { + for (int w_out = w_out_start, w_f = 0; w_f < c_filter; + w_out += dilate_width, w_f++) { + int s_h_out = h_out / stride_height; + int s_w_out = w_out / stride_width; + if (h_out % stride_height == 0 && w_out % stride_width == 0 && + s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && + s_w_out < output_width) { + const int output_grad_offset = + ((batch * output_channels + c_out) * output_height + + s_h_out) * + output_width + + s_w_out; + value += + output_grad_data[output_grad_offset] * + r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter]; + } + } + } + } + int index = + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + + w_in; + input_grad_data[index] = value; + } + } +} + +template __global__ void KernelDepthwiseConvInputGradSp( - const T* const output_grad_data, const T* const filter_data, - const int batch_size, const int output_channels, const int output_height, - const int output_width, const 
int input_channels, const int input_height, - const int input_width, const int filter_multiplier, const int filter_height, - const int filter_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width, const int dilate_height, - const int dilate_width, T* const input_grad_data) { + ARG_DEFINE_KernelDepthwiseConvInputGrad) { if (c_filter_multiplier == 0) KernelDepthwiseConvInputGrad( output_grad_data, filter_data, batch_size, output_channels, @@ -200,13 +318,20 @@ __global__ void KernelDepthwiseConvInputGradSp( filter_multiplier, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, dilate_height, dilate_width, input_grad_data); - else + else if (c_filter == -1) KernelDepthwiseConvInputGrad( output_grad_data, filter_data, batch_size, output_channels, output_height, output_width, input_channels, input_height, input_width, c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, padding_height, padding_width, dilate_height, dilate_width, input_grad_data); + else + KernelDepthwiseConvInputGradCFilter( + output_grad_data, filter_data, batch_size, output_channels, + output_height, output_width, input_channels, input_height, input_width, + c_filter_multiplier, filter_height, filter_width, c_stride, c_stride, + padding_height, padding_width, dilate_height, dilate_width, + input_grad_data); } // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. @@ -325,12 +450,14 @@ class DepthwiseConvFunctor { dim3 threads(std::min(output_width, thread), blocks, 1); dim3 grid(output_channels, batch_size, 1); int filter_multiplier = output_channels / input_channels; -#define check_case(c_filter_multiplier, c_stride) \ +#define check_case(c_filter_multiplier, c_stride, c_filter) \ if (c_filter_multiplier == 0 || \ filter_multiplier == c_filter_multiplier && \ - stride_height == stride_width && stride_height == c_stride) { \ - KernelDepthwiseConvSp<<>>( \ + stride_height == stride_width && stride_height == c_stride && \ + (ksize_height == ksize_width && ksize_height == c_filter || \ + c_filter == -1)) { \ + KernelDepthwiseConvSp<<>>( \ input_data, filter_data, batch_size, output_channels, output_height, \ output_width, input_channels, input_height, input_width, \ filter_multiplier, ksize_height, ksize_width, stride_height, \ @@ -338,11 +465,17 @@ class DepthwiseConvFunctor { dilate_width, output_data); \ return; \ } - check_case(1, 1); - check_case(1, 2); - // NOTE(liangdun): 0,0 for other case - // add other case if needed, e.g. check_case(2^n,1) - check_case(0, 0); + check_case(1, 1, 3); + check_case(1, 1, 5); + check_case(1, 1, -1); + check_case(1, 2, 3); + check_case(1, 2, 5); + check_case(1, 2, -1); + check_case(0, 0, 3); + check_case(0, 0, 5); + check_case(0, 0, -1); +// NOTE(liangdun): 0,0 for other case +// add other case if needed, e.g. 
check_case(2^n,1) #undef check_case } }; @@ -384,13 +517,15 @@ class DepthwiseConvInputGradFunctor { dim3 grid(input_channels, batch_size, 1); int filter_multiplier = output_channels / input_channels; -#define check_case(c_filter_multiplier, c_stride) \ +#define check_case(c_filter_multiplier, c_stride, c_filter) \ if (c_filter_multiplier == 0 || \ filter_multiplier == c_filter_multiplier && \ - stride_height == stride_width && stride_height == c_stride) { \ + stride_height == stride_width && stride_height == c_stride && \ + (ksize_height == ksize_width && ksize_height == c_filter || \ + c_filter == -1)) { \ KernelDepthwiseConvInputGradSp< \ - T, c_filter_multiplier, \ - c_stride><<>>( \ + T, c_filter_multiplier, c_stride, \ + c_filter><<>>( \ output_grad_data, filter_data, batch_size, output_channels, \ output_height, output_width, input_channels, input_height, \ input_width, filter_multiplier, ksize_height, ksize_width, \ @@ -398,11 +533,21 @@ class DepthwiseConvInputGradFunctor { dilate_height, dilate_width, input_grad_data); \ return; \ } - check_case(1, 1); - check_case(1, 2); - // NOTE(liangdun): 0,0 for other case - // add other case if needed, e.g. check_case(2^n,1) - check_case(0, 0); + check_case(1, 1, 3); + check_case(1, 1, 5); + check_case(1, 1, -1); + check_case(1, 2, 3); + check_case(1, 2, 5); + check_case(1, 2, -1); + check_case(2, 1, 3); + check_case(2, 1, 5); + check_case(2, 1, -1); + check_case(2, 2, 3); + check_case(2, 2, 5); + check_case(2, 2, -1); + check_case(0, 0, -1); +// NOTE(liangdun): 0,0 for other case +// add other case if needed, e.g. check_case(2^n,1) #undef check_case } }; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index e0c4c81bdd5b5d0af3bafe632a2fa033efd08050..58cfbb76e93a1c15c9b7cf9f9e596066c29b7ebb 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -49,7 +49,7 @@ class PReluOp : public framework::OperatorWithKernel { } else { PADDLE_THROW("Unkown mode %s", mode); } - ctx->SetOutputDim("Out", x_dim); + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index 13df1d4b4bb6c240610f96ccc8f223fc984d63f7..0fb7776fd9dbf437673820c7cf9411644272626c 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -54,7 +54,7 @@ class RNNMemoryHelperOpShapeInference : public framework::InferShapeBase { "Input(X) of rnn_memory_helper op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output of rnn_memory_helper op should not be null."); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/fluid/operators/sequence_conv_op.cc b/paddle/fluid/operators/sequence_conv_op.cc index ec6cb24350ae276724aae339590d40be1e9ea400..95a21a5d3ee6d8037431083edc25d1cddf05dedb 100644 --- a/paddle/fluid/operators/sequence_conv_op.cc +++ b/paddle/fluid/operators/sequence_conv_op.cc @@ -90,8 +90,8 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { ctx->GetInputDim("PaddingData")); } if (ctx->HasOutput(framework::GradVarName("X"))) { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", framework::GradVarName("X")); + ctx->ShareDim("X", /*->*/ framework::GradVarName("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); } if 
(ctx->HasOutput(framework::GradVarName("Filter"))) { ctx->SetOutputDim(framework::GradVarName("Filter"), diff --git a/paddle/fluid/operators/sequence_pool_op.cc b/paddle/fluid/operators/sequence_pool_op.cc index 5c6fd13d42e43e3502a1cab85a56e019420c708d..15d3f064eb7b025dc9a85b2aabad24186061cbd4 100644 --- a/paddle/fluid/operators/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_pool_op.cc @@ -102,8 +102,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { for (int64_t i = 1; i < og_dims.size(); ++i) { PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch."); } - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - ctx->ShareLoD("X", framework::GradVarName("X")); + + ctx->ShareDim("X", /*->*/ framework::GradVarName("X")); + ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); } protected: diff --git a/paddle/fluid/operators/sequence_reshape_op.cc b/paddle/fluid/operators/sequence_reshape_op.cc index ef5e6f3210234d59298fcf04c812390643c693d0..31d28d723498892f287246ba228df757d5b9f6c8 100644 --- a/paddle/fluid/operators/sequence_reshape_op.cc +++ b/paddle/fluid/operators/sequence_reshape_op.cc @@ -92,7 +92,7 @@ class SequenceReshapeGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequenceReshapeGradOp should not be null."); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareDim("X", /*->*/ framework::GradVarName("X")); ctx->ShareLoD("X", /*->*/ framework::GradVarName("X")); } }; diff --git a/paddle/fluid/operators/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_softmax_op.cc index c44f8206eb5079fef969e3e527552512eebd0f1a..ada3e0c8dbba38729c2b9c8b02335327835f2ef4 100644 --- a/paddle/fluid/operators/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_softmax_op.cc @@ -27,7 +27,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { "Input(X) of SequenceSoftmaxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SequenceSoftmaxOp should not be null."); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } diff --git a/paddle/fluid/operators/shrink_rnn_memory_op.cc b/paddle/fluid/operators/shrink_rnn_memory_op.cc index 29d2fb989754f5621222768a279a1c898ea1c355..e1c74c3a2f89235ba92c396d1a548271bb7d939d 100644 --- a/paddle/fluid/operators/shrink_rnn_memory_op.cc +++ b/paddle/fluid/operators/shrink_rnn_memory_op.cc @@ -151,9 +151,9 @@ class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase { void operator()(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInput("X")); PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X"))); - context->SetOutputDim(framework::GradVarName("X"), - context->GetInputDim("X")); - context->ShareLoD("X", framework::GradVarName("X")); + + context->ShareDim("X", /*->*/ framework::GradVarName("X")); + context->ShareLoD("X", /*->*/ framework::GradVarName("X")); } }; diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc index c3b0fe32098cb4b41ccc155db58809ef9f1bf46b..193de05422bb78572c0e5eaf4cd46744c3bcb113 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -40,7 +40,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { "The 2nd dimension of Input(X) and 
Input(Label) should " "be equal."); - ctx->SetOutputDim("Out", x_dims); + ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } }; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2d7ea66ad434c42b54522432095ce29..8e4a07556fb51dbb15ef948fcee120e2f68e089a 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, * 3. go to the second setp, until one thread's topk value is null; * 4. go to the first setp, until get the topk value. */ + template __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, - const T* src, int lds, int dim, int k) { + const T* src, int lds, int dim, int k, + int grid_dim, int num) { __shared__ Pair sh_topk[BlockSize]; __shared__ int maxid[BlockSize / 2]; const int tid = threadIdx.x; const int warp = threadIdx.x / 32; - output += blockIdx.x * output_stride; - indices += blockIdx.x * k; - Pair topk[MaxLength]; - int beam = MaxLength; - Pair max; - bool is_empty = false; - bool firststep = true; + const int bid = blockIdx.x; + for (int i = bid; i < num; i += grid_dim) { + output += i * output_stride; + indices += i * k; + + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; + + for (int k = 0; k < MaxLength; k++) { + topk[k].set(-INFINITY, -1); + } + while (k) { + ThreadGetTopK( + topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); - for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + sh_topk[tid] = topk[0]; + BlockReduce(sh_topk, maxid, topk, &output, + &indices, &beam, &k, tid, warp); + } } - while (k) { - ThreadGetTopK(topk, &beam, k, - src + blockIdx.x * lds, &firststep, - &is_empty, &max, dim, tid); - - sh_topk[tid] = topk[0]; - BlockReduce(sh_topk, maxid, topk, &output, - &indices, &beam, &k, tid, warp); +} + +inline static int GetDesiredBlockDim(int dim) { + if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; } } +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + template class TopkOpCUDAKernel : public framework::OpKernel { public: @@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel { // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. // TODO(typhoonzero): refine this kernel. - dim3 threads(256, 1); - dim3 grid(input_height, 1); - - KeMatrixTopK<<< - grid, threads, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>( - output_data, output->dims()[1], indices_data, input_data, input_width, - input_width, static_cast(k)); + const int kMaxHeight = 2048; + int gridx = input_height < kMaxHeight ? 
input_height : kMaxHeight; + auto& dev_ctx = ctx.cuda_device_context(); + + switch (GetDesiredBlockDim(input_width)) { + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + output_data, output->dims()[1], indices_data, input_data, + input_width, input_width, static_cast(k), gridx, + input_height)); + default: + PADDLE_THROW("Error"); + } } }; +#undef FIXED_BLOCK_DIM_BASE +#undef FIXED_BLOCK_DIM + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc index d854e2803975543b51c50ea2bc173322d3c3ca5e..1e8708f2648d7dd3c10319bd0a4be193d2458d53 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc @@ -148,7 +148,7 @@ struct TruncatedNormal { T operator()(T value) const { auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; - return (std::sqrt(2.0) * Erfinv(2 * p - 1) + mean) * std; + return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean; } }; diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cu b/paddle/fluid/operators/truncated_gaussian_random_op.cu index ad2a9021bfe344d838dff2040b3fb9371274e218..5a3510babe4d57b9e80f0e7898df98033834ca15 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cu +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cu @@ -42,7 +42,7 @@ struct TruncatedNormal { rng.discard(n); T value = dist(rng); auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value; - return (std::sqrt(2.0) * erfinvf(2 * p - 1) + mean) * std; + return std::sqrt(2.0) * erfinvf(2 * p - 1) * std + mean; } }; @@ -52,6 +52,7 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); + unsigned int seed = static_cast(context.Attr("seed")); if (seed == 0) { std::random_device rd; diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 763bb403588d13c15271d26b09813dddf3a5dd8c..aa907595cb7cf165974caa69fe8eb0370471732d 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -23,14 +23,14 @@ namespace operators { template class CPUUniformRandomKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - framework::Tensor* tensor = nullptr; + void Compute(const framework::ExecutionContext &ctx) const override { + framework::Tensor *tensor = nullptr; auto out_var = ctx.OutputVar("Out"); if (out_var->IsType()) { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { auto shape = ctx.Attr>("shape"); - auto* selected_rows = out_var->GetMutable(); + auto *selected_rows = out_var->GetMutable(); tensor = selected_rows->mutable_value(); tensor->Resize(framework::make_ddim(shape)); selected_rows->mutable_rows()->reserve(shape[0]); @@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { "uniform_random_op's output only" "supports SelectedRows and LoDTensor"); } - T* data = tensor->mutable_data(ctx.GetPlace()); + T *data = tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; if (seed == 0) { @@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void 
InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UniformRandomOp should not be null."); PADDLE_ENFORCE( ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), "uniform_random's min must less then max"); - auto& shape = ctx->Attrs().Get>("shape"); + auto &shape = ctx->Attrs().Get>("shape"); std::vector temp; temp.reserve(shape.size()); for (auto dim : shape) { @@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( static_cast(ctx.Attr("dtype")), ctx.GetPlace()); @@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max]. class UniformRandomOpVarTypeInference : public framework::VarTypeInference { public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const override { auto out_var_name = op_desc.Output("Out").front(); - if (block->FindRecursiveOrCreateVar(out_var_name).GetType() == - framework::proto::VarType::SELECTED_ROWS) { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(framework::proto::VarType::SELECTED_ROWS); - } else { - block->FindRecursiveOrCreateVar(out_var_name) - .SetType(framework::proto::VarType::LOD_TENSOR); + auto var_data_type = static_cast( + boost::get(op_desc.GetAttr("dtype"))); + + auto out_var = block->FindRecursiveOrCreateVar(out_var_name); + if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) { + out_var.SetType(framework::proto::VarType::LOD_TENSOR); } + out_var.SetDataType(var_data_type); } }; diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 16eac1ec2406c147fa765bc014038ae03a1416b2..3c8a01b6e47459760b05b5ca7fa4fa5e1d37d112 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -224,10 +224,12 @@ class WhileGradOp : public framework::OperatorBase { if (cur_scope_iter == step_scopes->rbegin()) { auto *var = (*cur_scope_iter)->FindVar(inside_grad_name); PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name); - PADDLE_ENFORCE(var->IsType() || - var->IsType(), - "Currently the type of var only can be LoDTensorArray " - "or LoDTensor."); + PADDLE_ENFORCE( + var->IsType() || + var->IsType(), + "Currently the type of var only can be LoDTensorArray, " + "or LoDTensor, but the received var[%s] is %s.", + inside_grad_name, var->Type().name()); if (var->IsType()) { auto &inside_tensor = var->Get(); diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 126636d879213b1c8f242db8fbdf6a358a1d2da9..f599e7fbc886a60394ae4690e4160275b55b8596 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -20,8 +20,11 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" DEFINE_double(fraction_of_gpu_memory_to_use, 0.92, - "Default use 92% of GPU memory for PaddlePaddle," - "reserve the rest for page tables, etc"); + "Allocate a trunk of gpu memory that is this fraction of the " + "total gpu memory size. Future memory usage will be allocated " + "from the trunk. 
If the trunk doesn't have enough gpu memory, " + "additional trunks of the same size will be requested from gpu " + "until the gpu has no memory left for another trunk."); namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 652a6ec7a4e2e823b28f39b449570cd375e88e18..612f3bc0e7940663f84a55b2c4395a7b5119d5bb 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -276,7 +276,7 @@ struct EventItem { // Print results void PrintProfiler(const std::vector>& events_table, const std::string& sorted_domain, const size_t name_width, - const size_t data_width, double total) { + const size_t data_width, bool merge_thread) { // Output header information std::cout << "\n------------------------->" << " Profiling Report " @@ -292,6 +292,10 @@ void PrintProfiler(const std::vector>& events_table, PADDLE_THROW("Invalid profiler state", g_state); } + if (merge_thread) { + std::cout << "Note! This Report merge all thread info into one." + << std::endl; + } std::cout << "Place: " << place << std::endl; std::cout << "Time unit: ms" << std::endl; std::cout << "Sorted by " << sorted_domain @@ -312,8 +316,7 @@ void PrintProfiler(const std::vector>& events_table, << std::setw(data_width) << event_item.min_time << std::setw(data_width) << event_item.max_time << std::setw(data_width) << event_item.ave_time - << std::setw(data_width) << event_item.total_time / total - << std::endl; + << std::setw(data_width) << event_item.ratio << std::endl; } } std::cout << std::endl; @@ -321,8 +324,10 @@ void PrintProfiler(const std::vector>& events_table, // Parse the event list and output the profiling report void ParseEvents(const std::vector>& events, + bool merge_thread, EventSortingKey sorted_by = EventSortingKey::kDefault) { if (g_state == ProfilerState::kDisabled) return; + if (merge_thread && events.size() < 2) return; std::string sorted_domain; std::function sorted_func; @@ -361,34 +366,55 @@ void ParseEvents(const std::vector>& events, sorted_domain = "event first end time"; } + const std::vector>* analyze_events; + std::vector> merged_events_list; + if (merge_thread) { + std::vector merged_events; + for (int i = 0; i < events.size(); ++i) { + for (int j = 0; j < events[i].size(); ++j) { + merged_events.push_back(events[i][j]); + } + } + merged_events_list.push_back(merged_events); + analyze_events = &merged_events_list; + } else { + analyze_events = &events; + } + std::vector> events_table; size_t max_name_width = 0; - double total = 0.; // the total time - for (size_t i = 0; i < events.size(); i++) { + for (size_t i = 0; i < (*analyze_events).size(); i++) { + double total = 0.; // the total time in one thread std::list pushed_events; std::vector event_items; std::unordered_map event_idx; - for (size_t j = 0; j < events[i].size(); j++) { - if (events[i][j].type() == EventType::kPushRange) { - pushed_events.push_back(events[i][j]); - } else if (events[i][j].type() == EventType::kPopRange) { + for (size_t j = 0; j < (*analyze_events)[i].size(); j++) { + if ((*analyze_events)[i][j].type() == EventType::kPushRange) { + pushed_events.push_back((*analyze_events)[i][j]); + } else if ((*analyze_events)[i][j].type() == EventType::kPopRange) { std::list::reverse_iterator rit = pushed_events.rbegin(); while (rit != pushed_events.rend() && - rit->name() != events[i][j].name()) { + rit->name() != (*analyze_events)[i][j].name()) { ++rit; } if (rit != pushed_events.rend()) { double event_time = (g_state == ProfilerState::kCUDA 
|| g_state == ProfilerState::kAll) - ? rit->CudaElapsedMs(events[i][j]) - : rit->CpuElapsedMs(events[i][j]); + ? rit->CudaElapsedMs((*analyze_events)[i][j]) + : rit->CpuElapsedMs((*analyze_events)[i][j]); total += event_time; - std::string event_name = - "thread" + std::to_string(rit->thread_id()) + "::" + rit->name(); - max_name_width = std::max(max_name_width, event_name.size()); + std::string event_name; + if (merge_thread) { + event_name = rit->name(); + max_name_width = std::max(max_name_width, event_name.size()); + } else { + event_name = "thread" + std::to_string(rit->thread_id()) + "::" + + rit->name(); + max_name_width = std::max(max_name_width, event_name.size()); + } if (event_idx.find(event_name) == event_idx.end()) { event_idx[event_name] = event_items.size(); @@ -413,7 +439,7 @@ void ParseEvents(const std::vector>& events, pushed_events.erase((++rit).base()); } else { LOG(WARNING) << "Cannot find the push marker of event \'" - << events[i][j].name() + << (*analyze_events)[i][j].name() << "\', which will be ignored in profiling report."; } } @@ -421,6 +447,7 @@ void ParseEvents(const std::vector>& events, // average time for (auto& item : event_items) { item.ave_time = item.total_time / item.calls; + item.ratio = item.total_time / total; } // sort if (sorted_by != EventSortingKey::kDefault) { @@ -438,7 +465,8 @@ void ParseEvents(const std::vector>& events, } // Print report - PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12, total); + PrintProfiler(events_table, sorted_domain, max_name_width + 4, 12, + merge_thread); } void DisableProfiler(EventSortingKey sorted_key, @@ -449,7 +477,8 @@ void DisableProfiler(EventSortingKey sorted_key, Mark("_stop_profiler_", nullptr); std::vector> all_events = GetAllEvents(); - ParseEvents(all_events, sorted_key); + ParseEvents(all_events, true, sorted_key); + ParseEvents(all_events, false, sorted_key); ResetProfiler(); DeviceTracer* tracer = GetDeviceTracer(); if (tracer->IsEnabled()) { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 295af1c5837d70c32b522cc47c8c3e12d5bd61c7..a91894ba8935076748aa3a8ded8d8829b88ebb33 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -157,7 +157,50 @@ PYBIND11_PLUGIN(core) { .def("_get_double_element", TensorGetElement) .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); }); - py::class_(m, "LoDTensor") + py::class_(m, "LoDTensor", R"DOC( + LoDTensor is a Tensor with optional LoD information. + + np.array(lod_tensor) can convert LoDTensor to numpy array. + lod_tensor.lod() can retrieve the LoD information. + + LoD is short for Level of Details and is usually used for varied sequence + lengths. You can skip the following comment if you don't need optional LoD. + + For example: + A LoDTensor X can look like the example below. It contains 2 sequences. + The first has length 2 and the second has length 3, as described by x.lod. + + The first tensor dimension 6=2+3 is calculated from LoD if it's available. + It means the total number of sequence elements. In X, each element has 2 + columns, hence [6, 2]. + + x.lod = [[2, 3]] + x.data = [[1, 2], [3, 4], + [5, 6], [7, 8], [9, 10], [11, 12]] + x.shape = [6, 2] + + LoD can have multiple levels (for example, a paragraph can have multiple + sentences and a sentence can have multiple words). In the following + LoDTensor Y, the lod_level is 2. It means there are 2 sequences: the + first sequence's length is 2 (it has 2 sub-sequences), and the second one's + length is 1.
The first sequence's 2 sub-sequences have lengths 2 and 2, + respectively, and the second sequence's 1 sub-sequence has length 3. + + y.lod = [[2, 1], [2, 2, 3]] + y.shape = [2+2+3, ...] + + Note: + In the above description, LoD is length-based. In Paddle's internal + implementation, lod is offset-based. Hence, internally, + y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (the length-based + equivalent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]). + + Sometimes LoD is called recursive_sequence_length to be more + self-explanatory. In this case, it must be length-based. Due to historical + reasons, when LoD is called lod in the public API, it might be offset-based. + Users should be careful about it. + + )DOC") .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) .def("__init__", @@ -620,7 +663,23 @@ All parameter, weight, gradient are variables in Paddle. // -- python binds for parallel executor. py::class_ pe(m, "ParallelExecutor"); - py::class_ exec_strategy(pe, "ExecutionStrategy"); + py::class_ exec_strategy(pe, "ExecutionStrategy", R"DOC( + ExecutionStrategy allows the user to more precisely control how to run + the program in ParallelExecutor by setting its properties. + + The available properties include: + use_cuda (bool): Whether to use CUDA or not. Default True. + num_threads (int): The number of threads that are used to run the + operators in ParallelExecutor. If it is not set, it will be + set in ParallelExecutor according to the device count. + Default 0. + allow_op_delay (bool): Whether to delay the communication operators + to run. Default False. + num_iteration_per_drop_scope (int): The number of iterations between + consecutive drops of the local scopes. Default 100. + + )DOC"); + exec_strategy.def(py::init()) .def_property( "num_threads", @@ -658,7 +717,25 @@ All parameter, weight, gradient are variables in Paddle. : ExecutionStrategy::kDefault; }); - py::class_ build_strategy(pe, "BuildStrategy"); + py::class_ build_strategy(pe, "BuildStrategy", R"DOC( + BuildStrategy allows the user to more precisely control how to + build the SSA Graph in ParallelExecutor by setting its properties. + + The available properties include: + reduce_strategy (str): There are two reduce strategies, 'AllReduce' + and 'Reduce'. If you want all parameters to be optimized + on all devices, choose 'AllReduce'; if you choose + 'Reduce', all parameters will be evenly allocated to different + devices for optimization, and the optimized parameters will then be + broadcast to the other devices. Default 'AllReduce'. + gradient_scale_strategy (str): There are two ways of defining loss@grad, + 'CoeffNumDevice' and 'Customized'. By default, ParallelExecutor + sets the loss@grad according to the number of devices. If you want + to customize loss@grad, you can choose 'Customized'. + Default 'CoeffNumDevice'. + debug_graphviz_path (str): The path to which the SSA Graph is written as a + graphviz file. It is useful for debugging. Default "". 
+)DOC"); py::enum_(build_strategy, "ReduceStrategy") .value("Reduce", BuildStrategy::ReduceStrategy::kReduce) diff --git a/paddle/fluid/train/CMakeLists.txt b/paddle/fluid/train/CMakeLists.txt index 6cd9cbe379874e5ab7e40c1349e0483ff45bb63a..fae28fcb4c3102240438b62c203c65281f029192 100644 --- a/paddle/fluid/train/CMakeLists.txt +++ b/paddle/fluid/train/CMakeLists.txt @@ -4,7 +4,6 @@ function(train_test TARGET_NAME) set(multiValueArgs ARGS) cmake_parse_arguments(train_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(arg_list "") if(train_test_ARGS) foreach(arg ${train_test_ARGS}) diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index d9214d0b8cef948dfca2ac9b1e3597fefc990786..b97e63ecc85872ebccb21f37bb4f2f00d49b4599 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -598,9 +598,9 @@ EOF EOF if [[ ${WITH_GPU} == "ON" ]]; then - NCCL_DEPS="apt-get install -y --allow-downgrades libnccl2=2.2.13-1+cuda${CUDA_MAJOR} libnccl-dev=2.2.13-1+cuda${CUDA_MAJOR} &&" + NCCL_DEPS="apt-get install -y --allow-downgrades libnccl2=2.2.13-1+cuda${CUDA_MAJOR} libnccl-dev=2.2.13-1+cuda${CUDA_MAJOR} || true" else - NCCL_DEPS="" + NCCL_DEPS="true" fi if [[ ${WITH_FLUID_ONLY:-OFF} == "OFF" ]]; then @@ -614,9 +614,8 @@ EOF cat >> ${PADDLE_ROOT}/build/Dockerfile < self.max_norm: + output = self.max_norm * y_np / norm + else: + output = y_np + self.assertTrue( + np.allclose( + np.array(out_tensor), output, atol=1e-5, equal_nan=False)) + + def test_clip_by_norm_with_selected_ros(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + + for place in places: + self.check_with_place(place) + + def config_test_case(self): + self.max_norm = 1.0 + self.max_relative_error = 0.006 + self.grad_shape = (4, 1) + self.grad_clipped_shape = (3, 1) + self.grad_rows = [0, 0, 1, 2] + self.grad_clipped_rows = [0, 1, 2] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py index 775c2253ab3b27708b745b85fc007fcb504d1aed..6a129b6df9bf1830fdf5eb5cb9ae0c5e4f7bb4ec 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py @@ -16,6 +16,8 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core +from paddle.fluid.op import Operator class ElementwiseMulOp(OpTest): @@ -115,5 +117,56 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): } +class TestElementWiseMulSelectedRows(OpTest): + def setUp(self): + self.rows = [0, 1, 2, 3, 4, 5, 6] + self.feature = 12 + self.height = 100 + self.input_shape = (len(self.rows), self.feature) + + def prepare_input(self, scope, place): + self.input = { + "X": np.random.random(self.input_shape).astype("float32"), + "Y": np.random.random(self.input_shape).astype("float32") + } + + def init_input(in_name): + x_selected_rows = scope.var(in_name).get_selected_rows() + x_selected_rows.set_height(self.height) + x_selected_rows.set_rows(self.rows) + x_array = self.input[in_name] + x_tensor = x_selected_rows.get_tensor() + x_tensor.set(x_array, place) + + init_input("X") + init_input("Y") + + def create_out_selected_row(self, scope): + return scope.var('Out').get_selected_rows() + + def check_result(self, 
out_selected_rows): + assert out_selected_rows.height() == self.height + assert out_selected_rows.rows() == self.rows + out_tensor = np.array(out_selected_rows.get_tensor()) + assert out_tensor.shape == self.input_shape + + def check_with_place(self, place): + scope = core.Scope() + self.prepare_input(scope, place) + + out_selected_rows = self.create_out_selected_row(scope) + out_selected_rows.set_height(0) + out_selected_rows.set_rows([]) + + elementwise_mul = Operator("elementwise_mul", X='X', Y='Y', Out='Out') + elementwise_mul.run(scope, place) + self.check_result(out_selected_rows) + + def test_elewisemul_with_selected_rows_input(self): + places = [core.CPUPlace()] + for place in places: + self.check_with_place(place) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_embedding_fc_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fused_embedding_fc_lstm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..70ca521d3387ac11cd41d8496b4d094667232d4c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_embedding_fc_lstm_op.py @@ -0,0 +1,218 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +from test_lstm_op import lstm, ACTIVATION + + +def fc(x, w, b): + return np.dot(x, w) + b + + +def fused_embedded_fc_lstm( + ids, # T x 1 + lod, # 1 x N + embeddings=None, # Dict_size x M + wx=None, # M x 4D + bx=None, # 1 x 4D + h0=None, # N x D + c0=None, # N x D + w_h=None, # D x 4D + w_b=None, # 1 x 4D + w_c=None, # 1 x 3D + is_reverse=False, + act_gate=None, + act_cell=None, + act_cand=None): + # Make a lookup for embeddings and pass result into lstm reference + T = ids.shape[0] + M = embeddings.shape[1] + x = embeddings[ids].reshape([T, M]) + return lstm( + fc(x, wx, bx), lod, h0, c0, w_h, w_b, w_c, is_reverse, act_gate, + act_cell, act_cand) + + +class TestFusionLSTMOp(OpTest): + def set_conf(self): + pass + + def setUp(self): + self.op_type = 'fused_embedding_fc_lstm' + self.lod = [[2, 3, 5, 4]] + self.M = 8 # Embedding size + self.D = 16 # Hidden size + self.dict_size = 18 + self.has_initial_state = False + self.use_peepholes = False + self.is_reverse = False + self.act_gate = 'sigmoid' + self.act_cell = 'tanh' + self.act_cand = 'tanh' + self.set_conf() + + T = sum(self.lod[0]) + bs = len(self.lod[0]) + + # this is the weight of fc + wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32') + # this is the bias of fc + bx = np.random.normal(size=(1, 4 * self.D)).astype('float32') + + if self.use_peepholes: + b = np.random.normal(size=(1, 7 * self.D)).astype('float32') + else: + b = np.random.normal(size=(1, 4 * self.D)).astype('float32') + w_b = np.copy(b[:, 0:4 * self.D]) + w_c = b[:, 4 * self.D:] if self.use_peepholes else None + + # low is 0 , high is voc_size - 1 + ids = np.random.randint( + low=0, high=self.dict_size - 
1, size=(T, 1)).astype("int64") + # embeddings as they were trained , so each entry is of M size + embeddings = np.random.random( + (self.dict_size, self.M)).astype("float32") + + # multiply embeddings via Weights + fc_embeddings = np.dot(embeddings, wx) + + # bias should be manually added into the bias of this fused embedding fc LSTM + b[0, 0:4 * self.D] += bx[0, :] + combined_biases = b[:, 0:4 * self.D] + # So let broadcast it , so they can be added + ones = np.ones([self.dict_size, 1]) + broadcasted_biases = np.dot(ones, combined_biases) + # Sum biases with Wx*embeddings + fc_embeddings += broadcasted_biases + + if self.has_initial_state: + h0 = np.random.normal(size=(bs, self.D)).astype('float32') + c0 = np.random.normal(size=(bs, self.D)).astype('float32') + else: + h0 = np.zeros((bs, self.D)).astype('float32') + c0 = np.zeros((bs, self.D)).astype('float32') + + wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32') + + h, c = fused_embedded_fc_lstm( + ids, self.lod, embeddings, wx, bx, h0, c0, wh, w_b, w_c, + self.is_reverse, ACTIVATION[self.act_gate], + ACTIVATION[self.act_cell], ACTIVATION[self.act_cand]) + + self.inputs = { + 'Ids': (ids, self.lod), + 'Embeddings': fc_embeddings, + 'WeightH': wh, + 'Bias': b + } + + if self.has_initial_state: + self.inputs['H0'] = h0 + self.inputs['C0'] = c0 + + self.outputs = { + 'Hidden': (h, self.lod), + 'Cell': (c, self.lod), + } + self.attrs = { + 'use_peepholes': self.use_peepholes, + 'is_reverse': self.is_reverse, + 'gate_activation': self.act_gate, + 'cell_activation': self.act_cell, + 'candidate_activation': self.act_cand + } + + def test_check_output(self): + for use_seq in {True, False}: + self.attrs['use_seq'] = use_seq + self.check_output() + + +class TestFusionLSTMOpInit(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + + +class TestFusionLSTMOpReverse(TestFusionLSTMOp): + def set_conf(self): + self.is_reverse = True + + +class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + self.is_reverse = True + + +class TestFusionLSTMOpMD1(TestFusionLSTMOp): + def set_conf(self): + self.M = 36 + self.D = 8 + + +class TestFusionLSTMOpMD2(TestFusionLSTMOp): + def set_conf(self): + self.M = 8 + self.D = 8 + + +class TestFusionLSTMOpMD3(TestFusionLSTMOp): + def set_conf(self): + self.M = 15 + self.D = 3 + + +class TestFusionLSTMOpBS1(TestFusionLSTMOp): + def set_conf(self): + self.lod = [[3]] + self.D = 16 + + +class TestFusionLSTMOpPeepholes(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + + +class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + + +class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.is_reverse = True + + +class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + self.is_reverse = True + + +class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.lod = [[2]] + self.D = 8 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_isfinite_op.py b/python/paddle/fluid/tests/unittests/test_isfinite_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d96ae15c7288c9a8d585d8d70d2aa8922b8f22b3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_isfinite_op.py @@ -0,0 
+1,97 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest + + +class TestInf(OpTest): + def setUp(self): + self.op_type = "isinf" + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + x[0] = np.inf + x[-1] = np.inf + + self.inputs = {'X': x} + self.outputs = {'Out': np.array(True).astype(self.dtype)} + + def init_dtype(self): + pass + + def test_output(self): + self.check_output() + + +class TestFP16Inf(TestInf): + def init_dtype(self): + self.dtype = np.float16 + + +class TestNAN(OpTest): + def setUp(self): + self.op_type = "isnan" + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + x[0] = np.nan + x[-1] = np.nan + + self.inputs = {'X': x} + self.outputs = {'Out': np.array(True).astype(self.dtype)} + + def init_dtype(self): + pass + + def test_output(self): + self.check_output() + + +class TestFP16NAN(TestNAN): + def init_dtype(self): + self.dtype = np.float16 + + +class TestIsfinite(OpTest): + def setUp(self): + self.op_type = "isfinite" + self.dtype = np.float32 + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + x[0] = np.inf + x[-1] = np.nan + out = np.isinf(x) | np.isnan(x) + + self.inputs = {'X': x} + self.outputs = {'Out': np.array(False).astype(self.dtype)} + + def init_dtype(self): + pass + + def test_output(self): + self.check_output() + + +class TestFP16Isfinite(TestIsfinite): + def init_dtype(self): + self.dtype = np.float16 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 328f0f0011381b77cccb8b2d9b266aa53b259473..8fc8125a773543eea768783155ad152c475535b5 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -243,5 +243,87 @@ class TestKeepDimReduceSumMultiAxises(OpTest): self.check_grad(['X'], 'Out') +class TestReduceSumWithDimOne(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random((10, 1, 1)).astype("float64")} + self.attrs = {'dim': [1, 2], 'keep_dim': True} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), + keepdims=True) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestReduceSumWithNumelOne(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random((1, 1)).astype("float64")} + self.attrs = {'dim': [1], 'keep_dim': False} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), + keepdims=False) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestReduceMeanWithDimOne(OpTest): + def setUp(self): + 
self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((10, 1, 1)).astype("float64")} + self.attrs = {'dim': [1], 'keep_dim': False} + self.outputs = { + 'Out': self.inputs['X'].mean( + axis=tuple(self.attrs['dim']), keepdims=False) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestReduceMeanWithNumelOne(OpTest): + def setUp(self): + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((1, 1)).astype("float64")} + self.attrs = {'dim': [1], 'keep_dim': True} + self.outputs = { + 'Out': self.inputs['X'].mean( + axis=tuple(self.attrs['dim']), keepdims=True) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestReduceAll(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.inputs = {'X': np.random.random((1, 1, 1)).astype("float64")} + self.attrs = {'reduce_all': True, 'keep_dim': False} + self.outputs = {'Out': self.inputs['X'].sum()} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index ecdbe27f4d90268d755a712e25289cfaf4715f29..91db85b8ec6a32fee3b7aa8ab76429a4a197fcc3 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -788,7 +788,8 @@ in a single call.") tuple: (main_program, startup_program), of type "Program" """ pserver_prog = self.get_pserver_program(endpoint) - pserver_startup = self.get_startup_program(endpoint) + pserver_startup = self.get_startup_program( + endpoint, pserver_program=pserver_prog) return pserver_prog, pserver_startup def get_startup_program(self, diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 43d51b03e81895d7322d9e28a9c40b6d7cc69206..c402535b27142e94af339a6c18401ba20bc6564d 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -124,7 +124,7 @@ class InferenceTranspiler(object): next_op = self.block.ops[i + 1] if next_op.type == 'relu': # modify bnorm OP to include relu - current_op.set_attr("fuse_relu", True) + current_op._set_attr("fuse_relu", True) # remove relu OP self.block._remove_op(i + 1) i = i + 1 @@ -454,7 +454,7 @@ class InferenceTranspiler(object): :type eltwise_op: Operator ''' - conv_op.set_attr("fuse_eltwise", True) + conv_op._set_attr("fuse_eltwise", True) self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0] self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0] diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index 5b9459b670ac8583ee0e65a3c1b51f6248bb6303..b2ef9f75809004d9df0003217c2dafcd69e83890 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -15,7 +15,7 @@ __all__ = [ 'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader', - 'multiprocess_reader' + 'multiprocess_reader', 'Fake' ] from threading import Thread @@ -504,3 +504,39 @@ class PipeReader: yield decomp_buff else: break + + +class Fake(object): + """ + fake reader will cache the first data it read and yield it out for data_num times. 
+ It is used to cache a data from real reader and use it for speed testing. + + :param reader: the origin reader + :param data_num: times that this reader will yield data. + + :return: a fake reader. + + Examples: + .. code-block:: python + + def reader(): + for i in range(10): + yield i + + fake_reader = Fake()(reader, 100) + """ + + def __init__(self): + self.data = None + self.yield_num = 0 + + def __call__(self, reader, data_num): + def fake_reader(): + if self.data is None: + self.data = next(reader()) + while self.yield_num < data_num: + yield self.data + self.yield_num += 1 + self.yield_num = 0 + + return fake_reader diff --git a/python/paddle/reader/tests/decorator_test.py b/python/paddle/reader/tests/decorator_test.py index c324092f8850e4bd64955aa9c987746b5cec54b5..b9af8348e16c051db64d57a9594aee303d83aef2 100644 --- a/python/paddle/reader/tests/decorator_test.py +++ b/python/paddle/reader/tests/decorator_test.py @@ -203,5 +203,21 @@ class TestMultiProcessReader(unittest.TestCase): self.reader_test(use_pipe=True) +class TestFakeReader(unittest.TestCase): + def test_fake_reader(self): + def reader(): + for i in range(10): + yield i + + data_num = 100 + fake_reader = paddle.reader.Fake()(reader, data_num) + for _ in range(10): + i = 0 + for data in fake_reader(): + self.assertEqual(data, 0) + i += 1 + self.assertEqual(i, data_num) + + if __name__ == '__main__': unittest.main()
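As an aside to the LoDTensor docstring added in pybind.cc above: the length-based and offset-based LoD representations it describes can be converted with a few lines of plain Python. The sketch below is illustrative only (the helper names are hypothetical and not part of the fluid API or of this patch), using the docstring's own y.lod example.

def lengths_to_offsets(lod):
    # length-based -> offset-based, e.g. [[2, 1], [2, 2, 3]] -> [[0, 2, 3], [0, 2, 4, 7]]
    result = []
    for level in lod:
        offsets = [0]
        for length in level:
            offsets.append(offsets[-1] + length)
        result.append(offsets)
    return result

def offsets_to_lengths(lod):
    # offset-based -> length-based, e.g. [[0, 2, 3], [0, 2, 4, 7]] -> [[2, 1], [2, 2, 3]]
    return [[level[i + 1] - level[i] for i in range(len(level) - 1)]
            for level in lod]

assert lengths_to_offsets([[2, 1], [2, 2, 3]]) == [[0, 2, 3], [0, 2, 4, 7]]
assert offsets_to_lengths([[0, 2, 3], [0, 2, 4, 7]]) == [[2, 1], [2, 2, 3]]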