diff --git a/cmake/configure.cmake b/cmake/configure.cmake index e4af34d10ed92c501dd805addb62747c91c00978..c35096e09b5685ee30f20648e7dd461f71e0b1c4 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -97,6 +97,14 @@ if(WITH_GPU) endif() include_directories(${TENSORRT_INCLUDE_DIR}) endif() + if(WITH_ANAKIN) + if(${CUDA_VERSION_MAJOR} VERSION_LESS 8) + message(FATAL_ERROR "Anakin needs CUDA >= 8.0 to compile") + endif() + if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7) + message(FATAL_ERROR "Anakin needs CUDNN >= 7.0 to compile") + endif() + endif() elseif(WITH_AMD_GPU) add_definitions(-DPADDLE_WITH_HIP) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__") diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake index 8b7d91f234594becdda805c089fac0bb4e4e8e44..403873a51017b9bb7c888bed77d89ca9b15b68d2 100644 --- a/cmake/external/anakin.cmake +++ b/cmake/external/anakin.cmake @@ -2,10 +2,22 @@ if (NOT WITH_ANAKIN) return() endif() -set(ANAKIN_INSTALL_DIR "${THIRD_PARTY_PATH}/install/anakin" CACHE PATH - "Anakin install path." FORCE) -set(ANAKIN_INCLUDE "${ANAKIN_INSTALL_DIR}" CACHE STRING "root of Anakin header files") -set(ANAKIN_LIBRARY "${ANAKIN_INSTALL_DIR}" CACHE STRING "path of Anakin library") +INCLUDE(ExternalProject) +set(ANAKIN_SOURCE_DIR ${THIRD_PARTY_PATH}/anakin) +# the anakin install dir is only default one now +set(ANAKIN_INSTALL_DIR ${THIRD_PARTY_PATH}/anakin/src/extern_anakin/output) +set(ANAKIN_INCLUDE ${ANAKIN_INSTALL_DIR}) +set(ANAKIN_LIBRARY ${ANAKIN_INSTALL_DIR}) +set(ANAKIN_SHARED_LIB ${ANAKIN_LIBRARY}/libanakin.so) +set(ANAKIN_SABER_LIB ${ANAKIN_LIBRARY}/libanakin_saber_common.so) + +# TODO(luotao): ANAKIN_MODLE_URL will move to demo ci later. +set(ANAKIN_MODLE_URL "http://paddle-inference-dist.bj.bcebos.com/mobilenet_v2.anakin.bin") +execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_SOURCE_DIR}") +execute_process(COMMAND bash -c "cd ${ANAKIN_SOURCE_DIR}; wget -q --no-check-certificate ${ANAKIN_MODLE_URL}") + +include_directories(${ANAKIN_INCLUDE}) +include_directories(${ANAKIN_INCLUDE}/saber/) set(ANAKIN_COMPILE_EXTRA_FLAGS -Wno-error=unused-but-set-variable -Wno-unused-but-set-variable @@ -20,36 +32,33 @@ set(ANAKIN_COMPILE_EXTRA_FLAGS -Wno-reorder -Wno-error=cpp) -set(ANAKIN_LIBRARY_URL "https://github.com/pangge/Anakin/releases/download/Version0.1.0/anakin.tar.gz") - -# A helper function used in Anakin, currently, to use it, one need to recursively include -# nearly all the header files. 
-function(fetch_include_recursively root_dir) - if (IS_DIRECTORY ${root_dir}) - include_directories(${root_dir}) - endif() - - file(GLOB ALL_SUB RELATIVE ${root_dir} ${root_dir}/*) - foreach(sub ${ALL_SUB}) - if (IS_DIRECTORY ${root_dir}/${sub}) - fetch_include_recursively(${root_dir}/${sub}) - endif() - endforeach() -endfunction() - -if (NOT EXISTS "${ANAKIN_INSTALL_DIR}") - # download library - message(STATUS "Download Anakin library from ${ANAKIN_LIBRARY_URL}") - execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_INSTALL_DIR}") - execute_process(COMMAND bash -c "rm -rf ${ANAKIN_INSTALL_DIR}/*") - execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; wget --no-check-certificate -q ${ANAKIN_LIBRARY_URL}") - execute_process(COMMAND bash -c "mkdir -p ${ANAKIN_INSTALL_DIR}") - execute_process(COMMAND bash -c "cd ${ANAKIN_INSTALL_DIR}; tar xzf anakin.tar.gz") -endif() +ExternalProject_Add( + extern_anakin + ${EXTERNAL_PROJECT_LOG_ARGS} + # TODO(luotao): use PaddlePaddle/Anakin later + GIT_REPOSITORY "https://github.com/luotao1/Anakin" + GIT_TAG "3957ae9263eaa0b1986758dac60a88852afb09be" + PREFIX ${ANAKIN_SOURCE_DIR} + UPDATE_COMMAND "" + CMAKE_ARGS -DUSE_GPU_PLACE=YES + -DUSE_X86_PLACE=YES + -DBUILD_WITH_UNIT_TEST=NO + -DPROTOBUF_ROOT=${THIRD_PARTY_PATH}/install/protobuf + -DMKLML_ROOT=${THIRD_PARTY_PATH}/install/mklml + -DCUDNN_ROOT=${CUDNN_ROOT} + ${EXTERNAL_OPTIONAL_ARGS} + CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ANAKIN_INSTALL_DIR} +) -if (WITH_ANAKIN) - message(STATUS "Anakin for inference is enabled") - message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}") - fetch_include_recursively(${ANAKIN_INCLUDE}) - link_directories(${ANAKIN_LIBRARY}) -endif() +message(STATUS "Anakin for inference is enabled") +message(STATUS "Anakin is set INCLUDE:${ANAKIN_INCLUDE} LIBRARY:${ANAKIN_LIBRARY}") + +add_library(anakin_shared SHARED IMPORTED GLOBAL) +set_property(TARGET anakin_shared PROPERTY IMPORTED_LOCATION ${ANAKIN_SHARED_LIB}) +add_dependencies(anakin_shared extern_anakin protobuf mklml) + +add_library(anakin_saber SHARED IMPORTED GLOBAL) +set_property(TARGET anakin_saber PROPERTY IMPORTED_LOCATION ${ANAKIN_SABER_LIB}) +add_dependencies(anakin_saber extern_anakin protobuf mklml) + +list(APPEND external_project_dependencies anakin_shared anakin_saber) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index aeb081e76e5bc5b9d3d81ce625195c800174ab6c..834ab5a9e527355d3664313d38cd4920f6fbf535 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -143,7 +143,7 @@ if (WITH_ANAKIN AND WITH_GPU) copy(anakin_inference_lib DEPS paddle_inference_api inference_anakin_api SRCS ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/libinference_anakin_api* # compiled anakin api - ${PADDLE_BINARY_DIR}/third_party/install/anakin/*.tar.gz # anakin release + ${ANAKIN_INSTALL_DIR} # anakin release DSTS ${dst_dir}/inference/anakin ${dst_dir}/inference/anakin) list(APPEND inference_deps anakin_inference_lib) endif() diff --git a/doc/fluid/api/executor.rst b/doc/fluid/api/executor.rst index db2842e7f23e74130a966bb347004bee1ccb08fd..f23ecc1f80030f20359ce9675130a167722606c9 100644 --- a/doc/fluid/api/executor.rst +++ b/doc/fluid/api/executor.rst @@ -38,11 +38,3 @@ _switch_scope .. autofunction:: paddle.fluid.executor._switch_scope :noindex: -.. _api_fluid_executor_fetch_var: - -fetch_var ---------- - -.. 
autofunction:: paddle.fluid.executor.fetch_var - :noindex: - diff --git a/doc/fluid/api/fluid.rst b/doc/fluid/api/fluid.rst index 51cdfe0c2ed045a5b3247c4fdec9868d756eae86..7eab58355c3648d929d3b5d98984adce9034f016 100644 --- a/doc/fluid/api/fluid.rst +++ b/doc/fluid/api/fluid.rst @@ -106,22 +106,6 @@ _switch_scope .. autofunction:: paddle.fluid._switch_scope :noindex: -.. _api_fluid_fetch_var: - -fetch_var ---------- - -.. autofunction:: paddle.fluid.fetch_var - :noindex: - -.. _api_fluid_Go: - -Go --- - -.. autoclass:: paddle.fluid.Go - :members: - :noindex: .. _api_fluid_make_channel: diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index dd172ff9c97814c089ddb2e5bf729880cf0c9cdb..2ce73df0248bee244665b033edddcea70407546d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -34,21 +34,10 @@ paddle.fluid.default_main_program ArgSpec(args=[], varargs=None, keywords=None, paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None) -paddle.fluid.Executor.as_lodtensor ArgSpec(args=['self', 'data'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)) paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None) paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) -paddle.fluid.fetch_var ArgSpec(args=['name', 'scope', 'return_numpy'], varargs=None, keywords=None, defaults=(None, True)) -paddle.fluid.Go.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.Go.construct_go_op ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) -paddle.fluid.make_channel ArgSpec(args=['dtype', 'capacity'], varargs=None, keywords=None, defaults=(0,)) -paddle.fluid.channel_send ArgSpec(args=['channel', 'value', 'is_copy'], varargs=None, keywords=None, defaults=(False,)) -paddle.fluid.channel_recv ArgSpec(args=['channel', 'return_value'], varargs=None, keywords=None, defaults=None) -paddle.fluid.channel_close ArgSpec(args=['channel'], varargs=None, keywords=None, defaults=None) -paddle.fluid.Select.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.Select.case ArgSpec(args=['self', 'channel_action_fn', 'channel', 'value', 'is_copy'], varargs=None, keywords=None, defaults=(False,)) -paddle.fluid.Select.default ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Trainer.__init__ ArgSpec(args=['self', 'train_func', 'optimizer_func', 'param_path', 'place', 'parallel', 'checkpoint_config'], varargs=None, keywords=None, defaults=(None, None, False, None)) paddle.fluid.Trainer.save_params ArgSpec(args=['self', 'param_path'], varargs=None, keywords=None, defaults=None) paddle.fluid.Trainer.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) @@ -62,20 +51,16 @@ paddle.fluid.CheckpointConfig.__init__ ArgSpec(args=['self', 'checkpoint_dir', ' paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 
'param_path', 'place', 'parallel'], varargs=None, keywords=None, defaults=(None, False)) paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.InferenceTranspiler.__init__ -paddle.fluid.InferenceTranspiler.fuse_batch_norm ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=None) -paddle.fluid.InferenceTranspiler.fuse_relu_mkldnn ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=None) paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) paddle.fluid.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspilerConfig.__init__ paddle.fluid.ParallelExecutor.__init__ ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id'], varargs=None, keywords='kwargs', defaults=(None, None, None, None, None, 1, 0)) -paddle.fluid.ParallelExecutor.bcast_params ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.ParallelExecutor.run ArgSpec(args=['self', 'fetch_list', 'feed', 'feed_dict', 'return_numpy'], varargs=None, keywords=None, defaults=(None, None, True)) paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ExecutionStrategy) -> None paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.GradientScaleStrategy, arg0: int) -> None @@ -338,14 +323,11 @@ paddle.fluid.contrib.BeamSearchDecoder.read_array ArgSpec(args=['self', 'init', paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array', 'value'], varargs=None, keywords=None, defaults=None) paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.transpiler.DistributeTranspiler.create_splited_vars ArgSpec(args=['self', 'source_var', 'block', 'tag'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 
'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.transpiler.InferenceTranspiler.__init__ -paddle.fluid.transpiler.InferenceTranspiler.fuse_batch_norm ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=None) -paddle.fluid.transpiler.InferenceTranspiler.fuse_relu_mkldnn ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0)) paddle.fluid.transpiler.release_memory ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index c4ab26a2288bb9d8f3cd54a797d2062e0606b219..9318f1089781b30468cf4d3c7151d0dd26e50a9c 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -44,13 +44,13 @@ class DfgPassManagerImpl final : public DfgPassManager { if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { auto trt_teller = [&](const Node* node) { std::unordered_set teller_set( - {"elementwise_add", "mul", "conv2d", "pool2d", "relu"}); + {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"}); if (!node->IsFunction()) return false; const auto* func = static_cast(node); - if (teller_set.count(func->func_type())) + if (teller_set.count(func->func_type())) { return true; - else { + } else { return false; } }; diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 08d0f493ab30d92a121d089d9003bc575429b4dd..83867e0a2c198f265cb36be3a71546796dfbec67 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -45,7 +45,6 @@ endfunction(inference_api_test) cc_library(paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor) - cc_test(test_paddle_inference_api SRCS api_tester.cc DEPS paddle_inference_api) @@ -62,22 +61,18 @@ inference_api_test(test_api_tensorrt_subgraph_engine SRC api_tensorrt_subgraph_e endif() if (WITH_ANAKIN) # only needed in CI - # Due to Anakin do not have official library releases and the versions of protobuf and cuda do not match Paddle's, - # so anakin library will not be merged to our official inference library. To use anakin prediction API, one need to - # compile the libinference_anakin_api.a and compile with anakin.so. - fetch_include_recursively(${ANAKIN_INCLUDE}) # compile the libinference_anakin_api.a and anakin.so. 
- nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc) - nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc) - target_compile_options(inference_anakin_api BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) - target_compile_options(inference_anakin_api_shared BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) - target_link_libraries(inference_anakin_api anakin anakin_saber_common) - target_link_libraries(inference_anakin_api_shared anakin anakin_saber_common) + nv_library(inference_anakin_api SRCS api.cc api_anakin_engine.cc DEPS anakin_shared anakin_saber) + #nv_library(inference_anakin_api_shared SHARED SRCS api.cc api_anakin_engine.cc DEPS anakin) + function(anakin_target target_name) + target_compile_options(${target_name} BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) + endfunction() + anakin_target(inference_anakin_api) + #anakin_target(inference_anakin_api_shared) if (WITH_TESTING) - # this test is unstable, disable it first. - #cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc - #ARGS --model=${ANAKIN_INSTALL_DIR}/mobilenet_v2.anakin.bin - #DEPS inference_anakin_api_shared) - #target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) - endif(WITH_TESTING) + cc_test(inference_anakin_test SRCS api_anakin_engine_tester.cc + ARGS --model=${ANAKIN_SOURCE_DIR}/mobilenet_v2.anakin.bin + DEPS inference_anakin_api dynload_cuda SERIAL) + target_compile_options(inference_anakin_test BEFORE PUBLIC ${ANAKIN_COMPILE_EXTRA_FLAGS}) + endif(WITH_TESTING) endif() diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index b24414e8245b1a4d90acce4fa1ad5690e06b47dd..794534467be066e91db2b4c204913ab2cf12dbfd 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -45,7 +45,7 @@ class PaddleBuf { PaddleBuf(void* data, size_t length) : data_(data), length_(length), memory_owned_{false} {} // Own memory. - PaddleBuf(size_t length) + explicit PaddleBuf(size_t length) : data_(new char[length]), length_(length), memory_owned_(true) {} // Resize to `length` bytes. 
  void Resize(size_t length);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 8f42a37cd3f8978b917b42e8f45a128b8422aa57..6863b035d8cd9dfb21aed3947226a796778912a4 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Add TRT tests
 nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
-activation_op.cc
+activation_op.cc softmax_op.cc
   DEPS tensorrt_engine operator scope framework_proto op_registry)

 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
@@ -21,3 +21,6 @@ nv_test(test_trt_pool2d_op SRCS test_pool2d_op.cc pool2d_op.cc
 nv_test(test_trt_elementwise_op SRCS test_elementwise_op.cc elementwise_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine elementwise_add_op SERIAL)
+
+nv_test(test_trt_softmax_op SRCS test_softmax_op.cc softmax_op.cc
+        DEPS ${FLUID_CORE_MODULES} tensorrt_engine softmax_op SERIAL)
diff --git a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0064f90fd7944403c14d4d47616ea82f681ceb74
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc
@@ -0,0 +1,49 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * SoftMaxOp, ISoftMaxLayer in TRT. This layer doesn't have weights.
+ */
+class SoftMaxOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(4)
+        << "convert a fluid softmax op to tensorrt softmax layer without bias";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax,
+                                       *const_cast<nvinfer1::ITensor*>(input1));
+
+    auto output_name = op_desc.Output("Out")[0];
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    if (test_mode) {
+      engine_->DeclareOutput(output_name);
+    }
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(softmax);
+REGISTER_TRT_OP_CONVERTER(softmax, SoftMaxOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..503ce71f7fb4377bb4304569b7484fb25abdb284
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/test_softmax_op.cc
@@ -0,0 +1,49 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+TEST(SoftMaxOpConverter, main) {
+  framework::Scope scope;
+  std::unordered_set<std::string> parameters;
+  TRTConvertValidation validator(8, parameters, scope, 1000);
+
+  std::vector tensor_shape{8, 10};
+  validator.DeclInputVar("softmax-X", tensor_shape,
+                         nvinfer1::DimsCHW(10, 1, 1));
+  validator.DeclOutputVar("softmax-Out", nvinfer1::DimsCHW(10, 1, 1));
+
+  // Prepare Op description
+  framework::OpDesc desc;
+  desc.SetType("softmax");
+  desc.SetInput("X", {"softmax-X"});
+  desc.SetOutput("Out", {"softmax-Out"});
+
+  LOG(INFO) << "set OP";
+  validator.SetOp(*desc.Proto());
+  LOG(INFO) << "execute";
+
+  validator.Execute(3);
+}
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+USE_OP(softmax);
diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
index 63c2f978f253df11100ecca83acae5eab6a0337d..4265f33f28fe36b1745baf4761c3c85e3a281d6b 100644
--- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h
+++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -79,6 +79,12 @@ class TRTConvertValidation {
   }

   // Declare a Variable as input with random initialization.
+  void DeclInputVar(const std::string& name, const std::vector tensor_dims,
+                    const nvinfer1::Dims& trt_dims) {
+    DeclVar(name, tensor_dims);
+    engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, trt_dims);
+  }
+
   void DeclInputVar(const std::string& name, const nvinfer1::Dims& dims) {
     DeclVar(name, dims);
     // Declare TRT inputs.
@@ -94,12 +100,18 @@ class TRTConvertValidation {
     DeclVar(name, dims);
   }

-  // Declare a variable in a fluid Scope.
-  void DeclVar(const std::string& name, const nvinfer1::Dims& dims,
-               bool is_param = false) {
+  void DeclVar(const std::string& name, const std::vector dim_vec) {
     platform::CPUPlace place;
     platform::CPUDeviceContext ctx(place);
+    auto* x = scope_.Var(name);
+    auto* x_tensor = x->GetMutable<framework::LoDTensor>();
+    x_tensor->Resize(framework::make_ddim(dim_vec));
+    RandomizeTensor(x_tensor, place, ctx);
+  }
+  // Declare a variable in a fluid Scope.
+  void DeclVar(const std::string& name, const nvinfer1::Dims& dims,
+               bool is_param = false) {
     // Init Fluid tensor.
     std::vector dim_vec(dims.d, dims.d + dims.nbDims);
     // There is no batchsize in ITensor's shape, but We should add it to
@@ -107,10 +119,8 @@
     // if_add_batch_ flag is true, add the max batchsize to dim_vec.
if (is_param != true && if_add_batch_ == true) dim_vec.insert(dim_vec.begin(), max_batch_size_); - auto* x = scope_.Var(name); - auto* x_tensor = x->GetMutable(); - x_tensor->Resize(framework::make_ddim(dim_vec)); - RandomizeTensor(x_tensor, place, ctx); + + DeclVar(name, dim_vec); } void SetOp(const framework::proto::OpDesc& desc) { diff --git a/paddle/fluid/operators/elementwise_add_op.cu b/paddle/fluid/operators/elementwise_add_op.cu index 6cbf6066c92b02eb75922587f9da5192bed15580..dfff518f170b56d180b6883c363effb8dbd677b6 100644 --- a/paddle/fluid/operators/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise_add_op.cu @@ -16,60 +16,6 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/platform/float16.h" -namespace paddle { -namespace operators { - -template -__global__ void ElementwiseAddCUDAKernel(const T *x, const T *y, T *z, int n, - int post, int size) { - int idx_x = threadIdx.x + blockIdx.x * blockDim.x; - if (idx_x < size) { - int idx_y = idx_x / post - (idx_x / (n * post)) * n; - z[idx_x] = x[idx_x] + y[idx_y]; - } -} - -template -class ElementwiseAddKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using Tensor = framework::Tensor; - - const auto x = ctx.Input("X"); - const auto y = ctx.Input("Y"); - auto z = ctx.Output("Out"); - auto *z_data = z->mutable_data(ctx.GetPlace()); - - auto &device = *(ctx.cuda_device_context().eigen_device()); - const framework::DDim &x_dim = x->dims(); - framework::DDim y_dim = y->dims(); - int size = x->numel(); - if (x_dim == y_dim) { - auto dim = framework::make_ddim({size}); - auto z_eigen = framework::EigenTensor::From(*z, dim); - auto x_eigen = framework::EigenTensor::From(*x, dim); - auto y_eigen = framework::EigenTensor::From(*y, dim); - z_eigen.device(device) = x_eigen + y_eigen; - } else { - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); - y_dim = trim_trailing_singular_dims(y_dim); - axis = (y_dim.size() == 0) ? 
x_dim.size() : axis; - int pre, n, post; - get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); - int threads = 512; - int grids = (size + threads - 1) / threads; - auto stream = ctx.cuda_device_context().stream(); - ElementwiseAddCUDAKernel<<>>( - x->data(), y->data(), z_data, n, post, size); - } - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; namespace plat = paddle::platform; diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index 0b19723720171a857c946880c246e2247a0023a7..5356105e2e551c0528694091608fc7585dce66d2 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -144,41 +144,16 @@ class ElementwiseAddGradKernel : public framework::OpKernel { auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); + // skip out, x, y + auto* out = dout; + auto *x = dout, *y = dout; - if (dx != nullptr) { - // In fact, we can just share memory, but it may cause a bug of memory - // optimizer - // dx->ShareDataWith(*dout); - framework::TensorCopy(*dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } - - if (dy == nullptr) return; - - const framework::DDim& x_dim = dout->dims(); - framework::DDim y_dim = dy->dims(); - if (x_dim == y_dim) { - // dy->ShareDataWith(*dout); - framework::TensorCopy(*dout, ctx.GetPlace(), - ctx.template device_context(), dy); + if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr && + dy != nullptr && (dx->dims() == dy->dims())) { + elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { - dy->mutable_data(ctx.GetPlace()); - // Perform reduction to dout to calculate dy - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); - y_dim = trim_trailing_singular_dims(y_dim); - axis = (y_dim.size() == 0) ? x_dim.size() : axis; - - auto& device = - *(ctx.template device_context().eigen_device()); - int pre, n, post; - get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); - auto eigen_dout = framework::EigenTensor::From( - *dout, framework::make_ddim({pre, n, post})); - auto eigen_dy = - framework::EigenTensor::From(*dy, framework::make_ddim({n})); - eigen_dy.device(device) = eigen_dout.sum( - framework::EigenDim<2>::From(framework::make_ddim({0, 2}))); + default_elementwise_add_grad(ctx, x, y, out, dout, dx, + dy); } } }; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8e6412fc8670c37d1a3ac97d942c69ff747822ff..40ced8e1c78e24f8df2a046fd95de1d196ff3085 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -666,7 +666,7 @@ All parameter, weight, gradient are variables in Paddle. const std::string &, Scope *, std::vector &, const ExecutionStrategy &, const BuildStrategy &, size_t, size_t>()) - .def("bcast_params", &ParallelExecutor::BCastParamsToDevices) + .def("_bcast_params", &ParallelExecutor::BCastParamsToDevices) // NOTE: even we return a vec* to Python use reference policy. // We still cannot get local_scope from this vector, since the element // of vec will be freed by Python GC. 
We can only return Scope*
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index cccb4abe6cf14ac3e90ba59d970c9398b09b7db2..1ae05dec8de6b9cbc9568f0f4d437833be520f8d 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -48,8 +48,6 @@ from .data_feeder import DataFeeder
 from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
 from .transpiler import DistributeTranspiler, InferenceTranspiler, \
     memory_optimize, release_memory, DistributeTranspilerConfig
-from .concurrency import (Go, make_channel, channel_send, channel_recv,
-                          channel_close, Select)
 from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
 from . import clip
 from . import profiler
@@ -61,7 +59,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable

 Tensor = LoDTensor

-__all__ = framework.__all__ + executor.__all__ + concurrency.__all__ + \
+__all__ = framework.__all__ + executor.__all__ + \
     trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
     parallel_executor.__all__ + lod_tensor.__all__ + [
         'io',
diff --git a/python/paddle/fluid/concurrency.py b/python/paddle/fluid/concurrency.py
index a8c4d66720d6eda857e5960d86fc3b8ec8f11ade..676a52a917dd1f9700ec38de32932938ec339be5 100644
--- a/python/paddle/fluid/concurrency.py
+++ b/python/paddle/fluid/concurrency.py
@@ -19,8 +19,7 @@ from .layers import fill_constant
 from . import core

 __all__ = [
-    'Go', 'make_channel', 'channel_send', 'channel_recv', 'channel_close',
-    'Select'
+    'make_channel', 'channel_send', 'channel_recv', 'channel_close', 'Select'
 ]

@@ -35,10 +34,10 @@ class Go(BlockGuard):
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_type is not None:
             return False
-        self.construct_go_op()
+        self._construct_go_op()
         return super(Go, self).__exit__(exc_type, exc_val, exc_tb)

-    def construct_go_op(self):
+    def _construct_go_op(self):
         main_program = self.helper.main_program
         go_block = main_program.current_block()
         parent_block = main_program.block(main_program.current_block()
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index a6b9440150d6773e61cdc89a377ad10bdeeb824d..b4d998985163c0823ed606b74246da5729954eeb 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -18,9 +18,7 @@ import six
 from .framework import Program, default_main_program, Variable
 from . import core

-__all__ = [
-    'Executor', 'global_scope', 'scope_guard', '_switch_scope', 'fetch_var'
-]
+__all__ = ['Executor', 'global_scope', 'scope_guard', '_switch_scope']

 g_scope = core.Scope()

@@ -171,7 +169,7 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
     return fetch_count > 0

-def fetch_var(name, scope=None, return_numpy=True):
+def _fetch_var(name, scope=None, return_numpy=True):
     """
     Fetch the value of the variable with the given name from the given scope.
@@ -222,6 +220,37 @@ def _get_program_cache_key(feed, fetch_list):
     return str(feed_var_names + fetch_var_names)

+def _as_lodtensor(data, place):
+    """
+    Convert numpy.ndarray to Tensor; it only supports a Tensor without LoD information.
+    For higher dimensional sequence data, please use LoDTensor directly.
+
+    Examples:
+        >>> import paddle.fluid as fluid
+        >>> place = fluid.CPUPlace()
+        >>> exe = fluid.Executor(place)
+        >>> data = np.random.rand(100, 200, 300)
+        >>> np_outs = map(lambda x: fluid.executor._as_lodtensor(x, place), data)
+        >>> ...
+
+    Args:
+        data(numpy.ndarray): an instance of numpy.ndarray
+
+    Returns:
+        LoDTensor
+    """
+    if isinstance(data, list):
+        raise RuntimeError("Some of your feed data hold LoD information. \
+                They can not be completely cast from a list of Python \
+                ndarray to LoDTensor. Please convert data to LoDTensor \
+                directly before feeding the data.\
+                ")
+    # single tensor case
+    tensor = core.LoDTensor()
+    tensor.set(data, place)
+    return tensor
+
+
 class Executor(object):
     """
     An Executor in Python, only support the single-GPU running. For multi-cards, please refer to
@@ -250,35 +279,6 @@ class Executor(object):
         self.program_caches = dict()
         self._closed = False

-    def as_lodtensor(self, data):
-        """
-        Convert numpy.ndarray to Tensor, its only support Tensor without LoD information.
-        For higher dimensional sequence data, please use LoDTensor directly.
-
-        Examples:
-            >>> import paddle.fluid as fluid
-            >>> exe = fluid.executor(fluid.CPUPlace())
-            >>> data = np.array(size=(100, 200, 300))
-            >>> np_outs = map(lambda x: exe.as_lodtensor(x), data)
-            >>> ...
-
-        Args:
-            data(numpy.ndarray): a instance of array
-
-        Returns:
-            LoDTensor
-        """
-        if isinstance(data, list):
-            raise RuntimeError("Some of your feed data hold LoD information. \
-                They can not be completely cast from a list of Python \
-                ndarray to LoDTensor. Please convert data to LoDTensor \
-                directly before feeding the data.\
-                ")
-        # single tensor case
-        tensor = core.LoDTensor()
-        tensor.set(data, self.place)
-        return tensor
-
     def _get_program_cache(self, program_cache_key):
         return self.program_caches.get(program_cache_key, None)

@@ -338,7 +338,7 @@ class Executor(object):
                 feed_target_name = op.desc.output('Out')[0]
                 cur_feed = feed[feed_target_name]
                 if not isinstance(cur_feed, core.LoDTensor):
-                    cur_feed = self.as_lodtensor(cur_feed)
+                    cur_feed = _as_lodtensor(cur_feed, self.place)
                 idx = op.desc.attr('col')
                 core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
             else:
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 7c723ba264c7b03639130d0266ec58510c8b0a20..2618872cc68b54a5b8ce2f80bc4409ec9c5360fb 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -279,19 +279,19 @@ class ParallelExecutor(object):
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()

         if self.is_dist:
-            self.bcast_params()
+            self._bcast_params()

         if return_numpy:
             return executor.as_numpy(arr)

         return [arr[i] for i in range(len(arr))]

-    def bcast_params(self):
+    def _bcast_params(self):
         """
         Broadcast the parameters to other devices. It is used during
         distributed training.
""" - self.executor.bcast_params(set(self.persistable_vars)) + self.executor._bcast_params(set(self.persistable_vars)) @property def device_count(self): diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index c16ac22523a586bf24df8f8e8921210862e06cb9..ac1719020bf6b6883619d613c8f34ae43c4cb224 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -537,5 +537,35 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) +class TestRMSPropOptimizer(TranspilerTest): + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + optimizer = fluid.optimizer.RMSProp(learning_rate=0.1) + optimizer.minimize(avg_cost) + return + + def transpiler_test_impl(self): + pserver, startup = self.get_pserver(self.pserver1_ep) + pserver2, startup2 = self.get_pserver(self.pserver2_ep) + + self.assertEqual(len(pserver.blocks), 3) + # block1~2: optimize pass + self.assertEqual([op.type for op in pserver.blocks[1].ops], + ["sum", "scale", "rmsprop"]) + # the variable #fc_w will be split into two blocks + fc_w_var = startup.global_block().var("fc_w.block1") + self.assertEqual(fc_w_var.shape, (500, 1000)) + moment_var = startup.global_block().var("momentum_1") + self.assertEqual(moment_var.shape, (500, 1000)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fetch_var.py b/python/paddle/fluid/tests/unittests/test_fetch_var.py index 46c3bbb6712c6276e48dd9328d7741a447f28b91..e6f37f0b4ca781e4ec83a00f8f2605ef02716bd7 100644 --- a/python/paddle/fluid/tests/unittests/test_fetch_var.py +++ b/python/paddle/fluid/tests/unittests/test_fetch_var.py @@ -26,7 +26,7 @@ class TestFetchVar(op_test.OpTest): layers.assign(input=val, output=x) exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_main_program(), feed={}, fetch_list=[]) - fetched_x = fluid.fetch_var("x") + fetched_x = fluid.executor._fetch_var("x") self.assertTrue( numpy.array_equal(fetched_x, val), "fetch_x=%s val=%s" % (fetched_x, val)) diff --git a/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py b/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py index 91b1fd2af7d8aaf85d17965f8b02c35ee3990291..f9bda5e4701f693f41fe7041ba0f5ec80b6fc31c 100644 --- a/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py +++ b/python/paddle/fluid/tests/unittests/test_py_reader_push_pop.py @@ -62,7 +62,8 @@ class TestPyReader(unittest.TestCase): next_data = np.random.uniform( low=0, high=1000, size=(batch_size, ) + shape[1:]).astype(dtype) - in_data.append(executor.as_lodtensor(next_data)) + in_data.append( + fluid.executor._as_lodtensor(next_data, place)) self.inputs.append(in_data) diff --git a/python/paddle/fluid/tests/unittests/transformer_model.py b/python/paddle/fluid/tests/unittests/transformer_model.py index 17ab875f6a555c99bbe480a54ea2df3b1f60de2d..868a0248be6833d0e8fed8a26549352562c279c1 100644 --- a/python/paddle/fluid/tests/unittests/transformer_model.py +++ b/python/paddle/fluid/tests/unittests/transformer_model.py 
@@ -22,7 +22,7 @@ pos_enc_param_names = ( "src_pos_enc_table", "trg_pos_enc_table", ) -batch_size = 64 +batch_size = 2 def position_encoding_init(n_position, d_pos_vec): diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 04ff9d975dfc548eb164a80b9175f72b72ad0c35..466e603bdcc0fb710f91e8de1f7325ab3bec9c61 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -751,14 +751,14 @@ class DistributeTranspiler(object): out_name = op.output("Out") ids_var = program.global_block().vars[ids_name[0]] - prefetch_input_vars = self.create_splited_vars( + prefetch_input_vars = self._create_splited_vars( source_var=ids_var, block=program.global_block(), tag="_prefetch_in_") self.all_prefetch_input_vars.append(prefetch_input_vars) out_var = program.global_block().vars[out_name[0]] - prefetch_output_vars = self.create_splited_vars( + prefetch_output_vars = self._create_splited_vars( source_var=out_var, block=program.global_block(), tag="_prefetch_out_") @@ -1040,7 +1040,7 @@ class DistributeTranspiler(object): program.global_block()._sync_with_cpp() return var_mapping - def create_splited_vars(self, source_var, block, tag): + def _create_splited_vars(self, source_var, block, tag): return [ block.create_var( name=str(source_var.name + tag + str(index)), @@ -1184,18 +1184,39 @@ class DistributeTranspiler(object): program = optimize_block.program pserver_block = program.global_block() new_inputs = collections.OrderedDict() + # update param/grad shape first, then other inputs like # moment can use the updated shape + def _get_param_block(opt_op): + # param is already created on global program + param_block = None + for p in self.param_grad_ep_mapping[endpoint]["params"]: + if same_or_split_var(p.name, opt_op.input("Param")[0]): + param_block = p + break + return param_block + for key in opt_op.input_names: if key == "Grad": new_inputs[key] = merged_var + # For RMSProp optimizer + elif key == "Moment" or key == "MeanSquare": + param_block = _get_param_block(opt_op) + if not param_block: + return + moment_var = origin_program.global_block().vars[opt_op.input( + key)[0]] + tmpvar = pserver_block.create_var( + name=moment_var.name, + persistable=moment_var.persistable, + dtype=moment_var.dtype, + # change to use same shape as param + # TODO(typhoonzero): didn't append .block in the var name, + # may affect checkpoint saving? Need to verify. 
+ shape=param_block.shape) + new_inputs[key] = tmpvar elif key == "Param": - # param is already created on global program - param_block = None - for p in self.param_grad_ep_mapping[endpoint]["params"]: - if same_or_split_var(p.name, opt_op.input(key)[0]): - param_block = p - break + param_block = _get_param_block(opt_op) if not param_block: return tmpvar = pserver_block.create_var( @@ -1221,7 +1242,7 @@ class DistributeTranspiler(object): for key in opt_op.input_names: new_shape = None - if key in ["Param", "Grad", "LearningRate"]: + if key in ["Param", "Grad", "LearningRate", "Moment", "MeanSquare"]: continue var = self.origin_program.global_block().vars[opt_op.input(key)[0]] # update accumulator variable shape diff --git a/python/paddle/fluid/transpiler/inference_transpiler.py b/python/paddle/fluid/transpiler/inference_transpiler.py index 142fa5c31d2c558e482001da73fa26d8396c3967..87f20bbccf3138585841952efacef5b0a3cbbace 100644 --- a/python/paddle/fluid/transpiler/inference_transpiler.py +++ b/python/paddle/fluid/transpiler/inference_transpiler.py @@ -57,10 +57,10 @@ class InferenceTranspiler(object): scope = global_scope() if not isinstance(scope, core.Scope): raise TypeError("scope should be as Scope type or None") - self.fuse_batch_norm(program, place, scope) - self.fuse_relu_mkldnn(program) + self._fuse_batch_norm(program, place, scope) + self._fuse_relu_mkldnn(program) - def fuse_relu_mkldnn(self, program): + def _fuse_relu_mkldnn(self, program): ''' Transpile the program by fused relu activation for MKLDNN program. @@ -104,7 +104,7 @@ class InferenceTranspiler(object): # And a better solution will be considered later. program = program.clone() - def fuse_batch_norm(self, program, place, scope): + def _fuse_batch_norm(self, program, place, scope): ''' Transpile the program by fused batch normalization. diff --git a/python/requirements.txt b/python/requirements.txt index c091ecb111bda9d5e83c3ddcae93aed0745f9e4c..f8298a63612cb217ce0e711e78fffdf86b73313d 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,5 +1,5 @@ requests==2.9.2 -numpy>=1.12 +numpy>=1.12,<=1.14 #TODO:change to ">=1.12" when numpy fix bug in 1.15 and higher version protobuf==3.1 recordio>=0.1.0 matplotlib