diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 17107e0698757997854e4627d30de60d9a9df11b..89726bf9859e71ee04c2f9380554090845fd44e5 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -109,7 +109,8 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" -"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op") +"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" +"fusion_transpose_flatten_concat_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 83c8478685ae903c1cac41337d8904c1330e7a9d..0975497a17dcdb5d2db5091dd61bd2562b5d5188 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -117,7 +117,7 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) if (NOT WIN32) -cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto) +cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference data_transform lod_tensor profiler transfer_scope_cache) else() diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index f6c82995e19ff2c24ca4447ff48fc9b3ca3d51c4..3dc571d75706b732fe9b254897b6cbd2e206cfc3 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -392,8 +392,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, int64_t max_memory_size = GetEagerDeletionThreshold(); std::unique_ptr> gc; - // WhileOp would set keep_kids to false - // WhileGradOp would need the scopes created in WhileOp + // WhileOp would set keep_kids to true, + // because WhileGradOp needs the scopes created in WhileOp. // Perhaps, we should not perform eager deletion in WhileOp // The scopes and variables created by WhileOp would be deleted // in WhileGradOp. diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc index e52a8317e2113a9489f8c05bcf47bc96bea33c64..f6219a14173094d15e9c60a2e26f98da1b04ec2e 100644 --- a/paddle/fluid/framework/transfer_scope_cache.cc +++ b/paddle/fluid/framework/transfer_scope_cache.cc @@ -17,16 +17,28 @@ namespace paddle { namespace framework { +// Holds all the transfer scope across the process. std::unordered_map& global_transfer_data_cache() { - thread_local auto* x = new std::unordered_map; + typedef std::unordered_map map_t; + thread_local std::unique_ptr x(new map_t); return *x; } +// Holds all the transfer scope for this thread. std::unordered_set& global_transfer_scope_cache() { - thread_local auto* x = new std::unordered_set; + typedef std::unordered_set set_t; + thread_local std::unique_ptr x(new set_t); return *x; } +// Try to create a transfer scope. If one cached scope has match the +// requirement, just return that one. +// Inputs: +// @type0: the source kernel type. +// @type1: the target kernel type. +// @scope: the execution scope of this op. +// Returns: A scope used to hold the transfer data across the different kernel +// type. Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, const Scope* scope) { Scope* new_scope{nullptr}; @@ -46,27 +58,5 @@ Scope* TryCreateTransferScope(OpKernelType type0, OpKernelType type1, return new_scope; } -void RemoveKidsFromTransferScopeCache(Scope* scope) { - auto it = global_transfer_scope_cache().find(scope); - if (it != global_transfer_scope_cache().end()) { - global_transfer_scope_cache().erase(it); - } - for (auto* s : scope->kids()) { - auto it = global_transfer_scope_cache().find(s); - if (it != global_transfer_scope_cache().end()) { - global_transfer_scope_cache().erase(it); - } - } - - // remove global transfer data cache - auto& cache = global_transfer_data_cache(); - for (auto it = cache.begin(); it != cache.end();) { - if (it->second == scope) - it = cache.erase(it); - else - it++; - } -} - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 2c5364b72402befd2c34e5f542ce5c6b2add621d..058a5b5f460d2bd3c4c0248929dd0c87f7506930 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -4,6 +4,7 @@ endif() # analysis and tensorrt must be added before creating static library, # otherwise, there would be undefined reference to them in static library. add_subdirectory(analysis) +add_subdirectory(utils) if (TENSORRT_FOUND) add_subdirectory(tensorrt) endif() diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index e9969b84f33483b048951f704de1e13e51cbeaea..eda251c5346a6d970ecd0956f976cbef41e6c1c1 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -30,7 +30,9 @@ cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc) cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis naive_executor zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder ir_pass_manager) cc_library(zero_copy_tensor SRCS details/zero_copy_tensor.cc DEPS scope lod_tensor enforce) cc_library(zero_copy_tensor_dummy SRCS details/zero_copy_tensor_dummy.cc) -cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config analysis_config paddle_pass_builder DEPS zero_copy_tensor) +cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS + lod_tensor scope paddle_pass_builder reset_tensor_array analysis_config + analysis_config paddle_pass_builder zero_copy_tensor reset_tensor_array) cc_test(test_paddle_inference_api SRCS api_tester.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index c132ce326c6b22ea235e6fb8c570678cb54e22ef..72ac534384e822b661b978a3c45c37fbf9b03060 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -31,6 +31,7 @@ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #endif #include "paddle/fluid/inference/utils/singleton.h" +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -174,7 +175,6 @@ bool AnalysisPredictor::Run(const std::vector &inputs, inference::Timer timer; timer.tic(); // set feed variable - std::vector feeds; framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get(); if (!SetFeed(inputs, scope)) { LOG(ERROR) << "fail to set feed"; @@ -215,17 +215,29 @@ bool AnalysisPredictor::SetFeed(const std::vector &inputs, framework::DDim ddim = framework::make_ddim(inputs[i].shape); void *input_ptr; if (inputs[i].dtype == PaddleDType::INT64) { - input_ptr = input.mutable_data(ddim, platform::CPUPlace()); + input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::FLOAT32) { - input_ptr = input.mutable_data(ddim, platform::CPUPlace()); + input_ptr = input.mutable_data(ddim, place_); } else { LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; return false; } - // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. - std::memcpy(static_cast(input_ptr), inputs[i].data.data(), - inputs[i].data.length()); + if (platform::is_cpu_place(place_)) { + // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. + std::memcpy(static_cast(input_ptr), inputs[i].data.data(), + inputs[i].data.length()); + } else { +#ifdef PADDLE_WITH_CUDA + auto dst_gpu_place = boost::get(place_); + memory::Copy(dst_gpu_place, static_cast(input_ptr), + platform::CPUPlace(), inputs[i].data.data(), + inputs[i].data.length(), + 0); // stream 0 for sync copy +#else + PADDLE_THROW("Not compile with CUDA, should not reach here."); +#endif + } // TODO(Superjomn) Low performance, need optimization for heavy LoD copy. framework::LoD lod; for (auto &level : inputs[i].lod) { diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 66a8e513961d74b96a98b01393048112ded65482..0f88ad14b0a6c0c40b74a80d524b2b7fc4a6c5ee 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" +#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -138,7 +139,6 @@ bool NativePaddlePredictor::Run(const std::vector &inputs, Timer timer; timer.tic(); // set feed variable - std::vector feeds; framework::Scope *scope = sub_scope_ != nullptr ? sub_scope_ : scope_.get(); if (!SetFeed(inputs, scope)) { LOG(ERROR) << "fail to set feed"; @@ -194,17 +194,30 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, framework::DDim ddim = framework::make_ddim(inputs[i].shape); void *input_ptr; if (inputs[i].dtype == PaddleDType::INT64) { - input_ptr = input.mutable_data(ddim, platform::CPUPlace()); + input_ptr = input.mutable_data(ddim, place_); } else if (inputs[i].dtype == PaddleDType::FLOAT32) { - input_ptr = input.mutable_data(ddim, platform::CPUPlace()); + input_ptr = input.mutable_data(ddim, place_); } else { LOG(ERROR) << "unsupported feed type " << inputs[i].dtype; return false; } - // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. - std::memcpy(static_cast(input_ptr), inputs[i].data.data(), - inputs[i].data.length()); + if (platform::is_cpu_place(place_)) { + // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. + std::memcpy(static_cast(input_ptr), inputs[i].data.data(), + inputs[i].data.length()); + } else { +#ifdef PADDLE_WITH_CUDA + auto dst_gpu_place = boost::get(place_); + memory::Copy(dst_gpu_place, static_cast(input_ptr), + platform::CPUPlace(), inputs[i].data.data(), + inputs[i].data.length(), + 0); // stream 0 for sync copy +#else + PADDLE_THROW("Not compile with CUDA, should not reach here."); +#endif + } + // TODO(Superjomn) Low performance, need optimization for heavy LoD copy. framework::LoD lod; for (auto &level : inputs[i].lod) { diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 49683eab07a2f5bc008272038a27bdb277396284..8fb464c0f5443f116815b14324f6cbc966dc6482 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -46,8 +46,6 @@ if(WITH_GPU) endif() endif(NOT WIN32) endif() - -include_directories("D:/Paddle/") include_directories("${PADDLE_LIB}") include_directories("${PADDLE_LIB}/third_party/install/protobuf/include") include_directories("${PADDLE_LIB}/third_party/install/glog/include") diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index e8bd13037ed6c2c3c639b76f6f3561921fb6ee37..7dc88d9dd052c59aaa59b7802ee5a38ea9d89bc6 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -74,7 +74,7 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana # ocr set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") if (NOT EXISTS ${OCR_INSTALL_DIR}) - inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") + inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") endif() inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) @@ -88,31 +88,31 @@ inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet # anakin if (WITH_ANAKIN AND WITH_MKL) # only needed in CI - # anakin rnn1 - set(ANAKIN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/anakin") - set(ANAKIN_RNN1_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/rnn1") - inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn.anakin2.model.bin") - inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn_data.txt") - cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc - ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin - --datapath=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn_data.txt - DEPS inference_anakin_api_shared SERIAL) - # anakin mobilenet - if(WITH_GPU) - set(ANAKIN_MOBILENET_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/mobilenet") - inference_download(${ANAKIN_MOBILENET_INSTALL_DIR} ${INFERENCE_URL} "mobilenet_v2.anakin.bin") - cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc - ARGS --model=${ANAKIN_MOBILENET_INSTALL_DIR}/mobilenet_v2.anakin.bin - DEPS inference_anakin_api_shared dynload_cuda SERIAL) - endif() + # anakin rnn1 + set(ANAKIN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/anakin") + set(ANAKIN_RNN1_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/rnn1") + inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn.anakin2.model.bin") + inference_download(${ANAKIN_RNN1_INSTALL_DIR} ${INFERENCE_URL} "anakin_test%2Fditu_rnn_data.txt") + cc_test(test_anakin_rnn1 SRCS anakin_rnn1_tester.cc + ARGS --model=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn.anakin2.model.bin + --datapath=${ANAKIN_RNN1_INSTALL_DIR}/anakin_test%2Fditu_rnn_data.txt + DEPS inference_anakin_api_shared SERIAL) + # anakin mobilenet + if(WITH_GPU) + set(ANAKIN_MOBILENET_INSTALL_DIR "${ANAKIN_INSTALL_DIR}/mobilenet") + inference_download(${ANAKIN_MOBILENET_INSTALL_DIR} ${INFERENCE_URL} "mobilenet_v2.anakin.bin") + cc_test(test_anakin_mobilenet SRCS anakin_mobilenet_tester.cc + ARGS --model=${ANAKIN_MOBILENET_INSTALL_DIR}/mobilenet_v2.anakin.bin + DEPS inference_anakin_api_shared dynload_cuda SERIAL) + endif() endif() if(WITH_GPU AND TENSORRT_FOUND) - set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt") - if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR}) - inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz") - endif() - inference_analysis_test(test_trt_models SRCS trt_models_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL) + set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt") + if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR}) + inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz") + endif() + inference_analysis_test(test_trt_models SRCS trt_models_tester.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL) endif() diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2104e4ac7222258ee025bd5acd60b1db251df654 --- /dev/null +++ b/paddle/fluid/inference/utils/CMakeLists.txt @@ -0,0 +1,2 @@ +cc_library(benchmark SRCS benchmark.cc DEPS enforce) +cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark) diff --git a/paddle/fluid/inference/utils/benchmark.cc b/paddle/fluid/inference/utils/benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..021edc2de5e90023fcd1431dd2025450e7462bd9 --- /dev/null +++ b/paddle/fluid/inference/utils/benchmark.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/utils/benchmark.h" +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { + +std::string Benchmark::SerializeToString() const { + std::stringstream ss; + ss << "-----------------------------------------------------\n"; + ss << "name\t"; + ss << "batch_size\t"; + ss << "num_threads\t"; + ss << "latency\t"; + ss << "qps"; + ss << '\n'; + + ss << name_ << "\t"; + ss << batch_size_ << "\t"; + ss << num_threads_ << "\t"; + ss << latency_ << "\t"; + ss << 1000 / latency_; + ss << '\n'; + return ss.str(); +} +void Benchmark::PersistToFile(const std::string &path) const { + std::ofstream file(path, std::ios::app); + PADDLE_ENFORCE(file.is_open(), "Can not open %s to add benchmark", path); + file << SerializeToString(); + file.flush(); + file.close(); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/utils/benchmark.h b/paddle/fluid/inference/utils/benchmark.h new file mode 100644 index 0000000000000000000000000000000000000000..80e8f77adb4ff2cc81a2a3dd0c44e4e304800122 --- /dev/null +++ b/paddle/fluid/inference/utils/benchmark.h @@ -0,0 +1,52 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace paddle { +namespace inference { + +/* + * Helper class to calculate the performance. + */ +struct Benchmark { + int batch_size() const { return batch_size_; } + void SetBatchSize(int x) { batch_size_ = x; } + + int num_threads() const { return num_threads_; } + void SetNumThreads(int x) { num_threads_ = x; } + + bool use_gpu() const { return use_gpu_; } + void SetUseGpu() { use_gpu_ = true; } + + int latency() const { return latency_; } + void SetLatency(int x) { latency_ = x; } + + const std::string& name() const { return name_; } + void SetName(const std::string& name) { name_ = name; } + + std::string SerializeToString() const; + void PersistToFile(const std::string& path) const; + + private: + bool use_gpu_{false}; + int batch_size_{0}; + int latency_; + int num_threads_{1}; + std::string name_; +}; + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/utils/benchmark_tester.cc b/paddle/fluid/inference/utils/benchmark_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb255474082b27180a8b3176b5f880c0d38f6c3b --- /dev/null +++ b/paddle/fluid/inference/utils/benchmark_tester.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/utils/benchmark.h" +#include +#include + +using namespace paddle::inference; +TEST(Benchmark, basic) { + Benchmark benchmark; + benchmark.SetName("key0"); + benchmark.SetBatchSize(10); + benchmark.SetUseGpu(); + benchmark.SetLatency(220); + LOG(INFO) << "benchmark:\n" << benchmark.SerializeToString(); +} + +TEST(Benchmark, PersistToFile) { + Benchmark benchmark; + benchmark.SetName("key0"); + benchmark.SetBatchSize(10); + benchmark.SetUseGpu(); + benchmark.SetLatency(220); + + benchmark.PersistToFile("1.log"); + benchmark.PersistToFile("1.log"); + benchmark.PersistToFile("1.log"); +} \ No newline at end of file diff --git a/paddle/fluid/memory/allocation/retry_allocator_test.cc b/paddle/fluid/memory/allocation/retry_allocator_test.cc index a0ce2875cb8337a59ec03730e5cf66d2fc622001..f0b215dac252475217a403e680a23559280b0e8d 100644 --- a/paddle/fluid/memory/allocation/retry_allocator_test.cc +++ b/paddle/fluid/memory/allocation/retry_allocator_test.cc @@ -41,7 +41,7 @@ TEST(RetryAllocator, RetryAllocator) { size_t thread_num = 32; size_t sleep_time = 40; - size_t extra_time = 2; + size_t extra_time = 10; // Reserve to perform more tests in the future std::vector> allocators; diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 5d468316e8eacb73c4a4ce81c784880bb5e46c2d..a0397acab1267365b8aeba30a63152b61b5b25bb 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -1,2 +1,6 @@ include(operators) -register_operators() +register_operators(EXCLUDES fusion_transpose_flatten_concat_op) +if (WITH_GPU) + op_library(fusion_transpose_flatten_concat_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_transpose_flatten_concat);\n") +endif() diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..39356c9afccbf9af3eacf99a6bccb15e18f7e485 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h" +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class TransposeFlattenConcatFusionOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, + "Inputs(X) of ConcatOp should be empty."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ConcatOp should not be null."); + + auto ins = ctx->GetInputsDim("X"); + const size_t n = ins.size(); + PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0."); + + std::vector trans_axis = + ctx->Attrs().Get>("trans_axis"); + int flatten_axis = ctx->Attrs().Get("flatten_axis"); + int concat_axis = ctx->Attrs().Get("concat_axis"); + + size_t x_rank = ins[0].size(); + size_t trans_axis_size = trans_axis.size(); + PADDLE_ENFORCE_EQ(x_rank, trans_axis_size, + "The input tensor's rank(%d) " + "should be equal to the permutation axis's size(%d)", + x_rank, trans_axis_size); + + auto dims0 = + GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[0])); + std::vector out_dims(dims0); + for (size_t i = 1; i < n; i++) { + auto dimsi = + GetFlattenShape(flatten_axis, GetPermuteShape(trans_axis, ins[i])); + for (int j = 0; j < static_cast(dims0.size()); j++) { + if (j == concat_axis) { + out_dims[concat_axis] += dimsi[j]; + } else { + PADDLE_ENFORCE_EQ(out_dims[j], dimsi[j], + "After flatting, the %d-th dim should be save " + "except the specify axis.", + j); + } + } + } + if (out_dims[concat_axis] < 0) { + out_dims[concat_axis] = -1; + } + ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); + } +}; + +class TransposeFlattenConcatFusionOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensor) The input tensor, tensors with rank up to 6 are supported.") + .AsDuplicable(); + AddOutput("Out", "(Tensor)The output tensor."); + AddAttr>( + "trans_axis", + "(vector) A list of values, and the size of the list should be " + "the same with the input tensor rank. This operator permutes the input " + "tensor's axes according to the values given."); + AddAttr("flatten_axis", + "(int)" + "Indicate up to which input dimensions (exclusive) should be" + "flattened to the outer dimension of the output. The value" + "for axis must be in the range [0, R], where R is the rank of" + "the input tensor. When axis = 0, the shape of the output" + "tensor is (1, (d_0 X d_1 ... d_n), where the shape of the" + "input tensor is (d_0, d_1, ... d_n)."); + AddAttr("concat_axis", + "The axis along which the input tensors will be concatenated. " + "It should be 0 or 1, since the tensor is 2D after flatting."); + AddComment(R"DOC( + + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fusion_transpose_flatten_concat, + ops::TransposeFlattenConcatFusionOp, + ops::TransposeFlattenConcatFusionOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ccb670d73c803bb1b9827f0f30b99d272bfce79 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.cu.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +template +using CudnnDataType = platform::CudnnDataType; + +template +class TransposeFlattenConcatFusionKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto odims = out->dims(); + + std::vector trans_axis = ctx.Attr>("trans_axis"); + int flatten_axis = ctx.Attr("flatten_axis"); + int concat_axis = ctx.Attr("concat_axis"); + + int rank = ins[0]->dims().size(); + // use at least 4D in cudnnTransformTensor + int max_dim = rank < 4 ? 4 : rank; + std::vector stride_x(max_dim, 0); + std::vector stride_y(max_dim, 0); + std::vector dims_y(max_dim, 0); + + cudnnTensorDescriptor_t in_desc; + cudnnTensorDescriptor_t out_desc; + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&in_desc)); + CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&out_desc)); + cudnnDataType_t cudnn_dtype = CudnnDataType::type; + + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + + T* odata = out->data(); + for (size_t k = 0; k < ins.size(); ++k) { + auto perm_shape = GetPermuteShape(trans_axis, ins[k]->dims()); + int osize = 1; + auto idims = ins[k]->dims(); + for (int i = 0; i < rank; i++) { + stride_x[i] = 1; + for (int j = trans_axis[i] + 1; j < rank; j++) { + stride_x[i] *= idims[j]; + } + dims_y[i] = perm_shape[i]; + osize *= perm_shape[i]; + } + stride_y[rank - 1] = 1; + for (int i = rank - 2; i >= 0; i--) { + if (((i + 1) == flatten_axis) && (concat_axis == 1)) { + stride_y[i] = odims[1]; + } else { + stride_y[i] = stride_y[i + 1] * perm_shape[i + 1]; + } + } + + // Since concat is aftern flatten, the output is 2D tensor. + // If concat_axis is 0, each input's permutated tensor is continuous. + // If concat_axis is 1, the stride of 0-th dim of each input's + // permutated tensor is odims()[1]. + + for (int i = rank; i < max_dim; i++) { + stride_x[i] = 1; + stride_y[i] = 1; + dims_y[i] = 1; + } + + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + in_desc, cudnn_dtype, max_dim, dims_y.data(), stride_x.data())); + CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( + out_desc, cudnn_dtype, max_dim, dims_y.data(), stride_y.data())); + + CUDNN_ENFORCE(platform::dynload::cudnnTransformTensor( + handle, CudnnDataType::kOne(), in_desc, + static_cast(ins[k]->data()), + CudnnDataType::kZero(), out_desc, static_cast(odata))); + if (concat_axis == 0) { + odata += osize; + } else { + auto flat_shape = GetFlattenShape(flatten_axis, perm_shape); + odata += flat_shape[1]; + } + } + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(in_desc)); + CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(out_desc)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(fusion_transpose_flatten_concat, + ops::TransposeFlattenConcatFusionKernel, + ops::TransposeFlattenConcatFusionKernel); diff --git a/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h new file mode 100644 index 0000000000000000000000000000000000000000..66d5bea679fc85ce6b1ba64921107aef987ccaa8 --- /dev/null +++ b/paddle/fluid/operators/fused/fusion_transpose_flatten_concat_op.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ddim.h" + +namespace paddle { +namespace operators { + +inline std::vector GetPermuteShape(const std::vector& axis, + const framework::DDim& in_dims) { + std::vector out_dims(in_dims.size()); + for (size_t i = 0; i < axis.size(); i++) { + out_dims[i] = in_dims[axis[i]]; + } + return out_dims; +} + +inline std::vector GetFlattenShape(const int axis, + const std::vector& in_dims) { + int64_t outer = 1, inner = 1; + for (int i = 0; i < static_cast(in_dims.size()); ++i) { + if (i < axis) { + outer *= in_dims[i]; + } else { + inner *= in_dims[i]; + } + } + std::vector out_shape(2); + out_shape[0] = outer; + out_shape[1] = inner; + return out_shape; +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index a6843f20a59a23bd4e875b0f96524cc8d7aa46d6..1b55527fd33e879c5c6fe702a53d0a23bebc2b14 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -67,6 +67,7 @@ class LookupSparseTableOp : public framework::OperatorBase { framework::proto::VarType::FP32, "The sparse table only support FP32"); w_t->Get(ids_t, out_t, true, is_test); + out_t->set_lod(ids_t.lod()); } }; diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 19b2c68c823adbed82319f7b04992baedd5d41f9..76cc796a9b8e21849b1d86e512cd70752fd027ac 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -127,6 +127,9 @@ class SumKernel : public framework::OpKernel { math::scatter::MergeAdd merge_add; merge_add(context.template device_context(), inputs, out); + + out->SyncIndex(); + } else { // no data, just set a empty out tensor. out->mutable_value()->mutable_data(framework::make_ddim({0}), diff --git a/paddle/fluid/operators/tensor_array_to_tensor_op.cc b/paddle/fluid/operators/tensor_array_to_tensor_op.cc index 96dc123f6a36e1a2b6ae04e0d97dffe1e10ac4ea..58a74ec2c104f66e9e884cffd00e7fa6622e4714 100644 --- a/paddle/fluid/operators/tensor_array_to_tensor_op.cc +++ b/paddle/fluid/operators/tensor_array_to_tensor_op.cc @@ -106,9 +106,9 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase { out_inx_dim[0] = inx.size(); out_inx.Resize(out_inx_dim); + auto &local_scope = scope.NewScope(); std::string var_name = "out_index"; - framework::Variable *tmp_index_var = - const_cast(scope).Var(var_name); + framework::Variable *tmp_index_var = local_scope.Var(var_name); auto &tmp_index_tensor = *(tmp_index_var->GetMutable()); tmp_index_tensor.Resize(out_inx_dim); @@ -128,12 +128,12 @@ class LoDTensorArray2TensorOp : public framework::OperatorBase { out_dims[axis] = out_dim_sum; out.Resize(out_dims); - LodTensorArray2LodTensorVector(scope, base_name, Input("X"), &names); - // Invoke Reshape Op + LodTensorArray2LodTensorVector(local_scope, base_name, Input("X"), &names); + // Invoke concat Op auto concat_op = framework::OpRegistry::CreateOp( "concat", {{"X", names}}, {{"Out", {Output("Out")}}}, attrs); - concat_op->Run(scope, place); + concat_op->Run(local_scope, place); } }; diff --git a/paddle/fluid/platform/dynload/cublas.cc b/paddle/fluid/platform/dynload/cublas.cc index 361d3439b844e9f68d3fba0a0e41ec457118a4a9..41648c32fe6f98bb0b78ea7891065e5586f70463 100644 --- a/paddle/fluid/platform/dynload/cublas.cc +++ b/paddle/fluid/platform/dynload/cublas.cc @@ -32,6 +32,9 @@ CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP); CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP); #endif +#ifdef CUBLAS_BLAS_ROUTINE_EACH_R4 +CUBLAS_BLAS_ROUTINE_EACH_R4(DEFINE_WRAP); +#endif } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index ff80bd525c167eda00f67d392c7b3b71436ee820..ced789b90d067218c3b01d124cfd2c93dc94e528 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -90,23 +90,33 @@ CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) // APIs available after CUDA 8.0 #if CUDA_VERSION >= 8000 -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmEx); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmStridedBatched); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasHgemmStridedBatched); +#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \ + __macro(cublasGemmEx); \ + __macro(cublasSgemmStridedBatched); \ + __macro(cublasDgemmStridedBatched); \ + __macro(cublasCgemmStridedBatched); \ + __macro(cublasZgemmStridedBatched); \ + __macro(cublasHgemmStridedBatched); + +CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) #endif // APIs available after CUDA 9.0 #if CUDA_VERSION >= 9000 -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSetMathMode); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGetMathMode); +#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) \ + __macro(cublasSetMathMode); \ + __macro(cublasGetMathMode); + +CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) #endif +// APIs available after CUDA 9.1 #if CUDA_VERSION >= 9010 -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmBatchedEx); -DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasGemmStridedBatchedEx); +#define CUBLAS_BLAS_ROUTINE_EACH_R4(__macro) \ + __macro(cublasGemmBatchedEx); \ + __macro(cublasGemmStridedBatchedEx); + +CUBLAS_BLAS_ROUTINE_EACH_R4(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) #endif #undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index babb862122a0e923809cb76a924ef5c8b621443e..ef43d13e18698748717dff35c85b243edec44592 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -31,6 +31,11 @@ int main(int argc, char** argv) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) new_argv.push_back( strdup("--tryfromenv=fraction_of_gpu_memory_to_use,allocator_strategy")); +#elif __clang__ + new_argv.push_back( + strdup("--tryfromenv=use_mkldnn,initial_cpu_memory_in_" + "mb,allocator_strategy")); + new_argv.push_back(strdup("--undefok=use_mkldnn,initial_cpu_memory_in_mb")); #else new_argv.push_back( strdup("--tryfromenv=use_pinned_memory,use_mkldnn,initial_cpu_memory_in_" diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index d851b9dfaa9fb5eeefc5aa1e03e30a56e323f758..f7fefb3e5b767e25373665058d4fd6a298fb3d60 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -91,6 +91,7 @@ def __bootstrap__(): """ import sys import os + import platform from . import core in_test = 'unittest' in sys.modules @@ -110,14 +111,17 @@ def __bootstrap__(): print('PLEASE USE OMP_NUM_THREADS WISELY.', file=sys.stderr) os.environ['OMP_NUM_THREADS'] = str(num_threads) - + sysstr = platform.system() read_env_flags = [ - 'use_pinned_memory', 'check_nan_inf', 'benchmark', 'eager_delete_scope', - 'use_mkldnn', 'use_ngraph', 'initial_cpu_memory_in_mb', - 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', - "dist_threadpool_size", 'eager_delete_tensor_gb', 'allocator_strategy', + 'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_mkldnn', + 'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem', + 'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size", + 'eager_delete_tensor_gb', 'allocator_strategy', 'reader_queue_speed_test_mode', 'print_sub_graph_dir' ] + if 'Darwin' not in sysstr: + read_env_flags.append('use_pinned_memory') + if os.name != 'nt': read_env_flags.append('warpctc_dir') read_env_flags.append('cpu_deterministic') diff --git a/python/paddle/fluid/contrib/utils/__init__.py b/python/paddle/fluid/contrib/utils/__init__.py index df6d367782327f7b22e72ab88d6b6cc26c9d5bc9..6e479bdc2b93c1189ba07a6f20b2408c34110b93 100644 --- a/python/paddle/fluid/contrib/utils/__init__.py +++ b/python/paddle/fluid/contrib/utils/__init__.py @@ -13,8 +13,10 @@ # limitations under the License. from __future__ import print_function - +from . import lookup_table_utils +from .lookup_table_utils import * from . import hdfs_utils from .hdfs_utils import * +__all__ = lookup_table_utils.__all__ __all__ = hdfs_utils.__all__ diff --git a/python/paddle/fluid/contrib/utils/lookup_table_utils.py b/python/paddle/fluid/contrib/utils/lookup_table_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2418238f98d8e2b9af0cf4290f6088c11e1b92 --- /dev/null +++ b/python/paddle/fluid/contrib/utils/lookup_table_utils.py @@ -0,0 +1,256 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import time +import logging + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid import io +from paddle.fluid import Program + +__all__ = [ + "load_inference_model", "load_persistable_vars", + "convert_dist_to_sparse_program" +] + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +_logger = logging.getLogger("lookup_table_utils") +_logger.setLevel(logging.INFO) + +model_filename = "__model__" +lookup_table_dir = "__lookup_table__" + + +def __insert_lookup_sparse_table_op(main_program, idx, ids, w, out): + main_program.global_block()._insert_op( + index=idx, + type="lookup_sparse_table", + inputs={"Ids": [ids], + "W": [w]}, + outputs={"Out": [out]}, + attrs={ + "is_distributed": False, + "is_sparse": True, + "grad_inplace": False + }) + + +def __get_prefetch_op_tuples(main_program): + # current lookup tables op is split_ids->prefetch->merge_ids + prefetch_op_tuples = None + op_types = [op.type for op in main_program.global_block().ops] + + for i in range(len(op_types)): + if op_types[i] == "prefetch": + if op_types[i - 1] == "split_ids" and op_types[i + + 1] == "merge_ids": + split_ids_op_id = i - 1 + split_ids_inputs = main_program.global_block().ops[i - 1].input( + "Ids") + prefetch_op_inputs = main_program.global_block().ops[i].input( + "X") + prefetch_op_outputs = main_program.global_block().ops[i].output( + "Out") + merge_ids_outputs = main_program.global_block().ops[ + i + 1].output("Out") + + need_delete_vars = [] + need_delete_vars.extend(prefetch_op_inputs) + need_delete_vars.extend(prefetch_op_outputs) + + prefetch_op_tuples = (split_ids_op_id, split_ids_inputs, + merge_ids_outputs, need_delete_vars) + break + return prefetch_op_tuples + + +def convert_dist_to_sparse_program(main_program): + if not main_program._distributed_lookup_table: + _logger.warn( + "There are no distributed lookup tables need to be converted") + return + + # create table param and grad var in pserver program + origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table) + emb_var = main_program._distributed_lookup_table + main_program.global_block()._rename_var(emb_var, origin_emb_var) + origin_param_var = main_program.global_block().vars[origin_emb_var] + + param_var = main_program.global_block().create_var( + name=emb_var, + shape=origin_param_var.shape, + dtype=origin_param_var.dtype, + type=core.VarDesc.VarType.SELECTED_ROWS, + persistable=True) + # parameter must be selected rows + param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + main_program._sync_with_cpp() + + prefetch_op_tuples = __get_prefetch_op_tuples(main_program) + + split_ids_id = prefetch_op_tuples[0] + + for idx in range(split_ids_id + 2, split_ids_id - 1, -1): + main_program.global_block()._remove_op(idx) + main_program.desc.flush() + + in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2]) + + for in_out_pair in in_out_pairs: + idx = split_ids_id + ids = main_program.global_block().vars[in_out_pair[0]] + out = main_program.global_block().vars[in_out_pair[1]] + __insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out) + main_program.desc.flush() + return main_program + + +def load_persistable_vars(executor, dirname, program, lookup_table_var): + def _is_checkpoint_var(exclude_fluid_vars=None): + """ + the checkpoint will not save or load all the variables. + var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. + + : param var(Variable) + """ + + if exclude_fluid_vars is None: + exclude_fluid_vars = [] + + def is_valid(var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW: + return False + # @GRAD are named for gradient variables, checkpoint will not save it. + if "@GRAD" in var.name: + return False + # .trainer_ are named for distribute train variables, checkpoint will not save it. + if ".trainer_" in var.name: + return False + + # .block is named for distribute train variables, checkpoint will not save it. + if ".block" in var.name: + return False + + if "tmp_" in var.name: + return False + + if var.name in exclude_fluid_vars: + return False + + return var.persistable + + return is_valid + + def _load_lookup_table_vars(executor, dirname, main_program, + lookup_table_vars): + if not os.path.isdir(dirname): + raise ValueError("There is no directory named '%s'", dirname) + + lookup_table_dirname = os.path.join(dirname, lookup_table_dir) + + emb_var_name = lookup_table_vars[0] + emb_var = main_program.global_block().var(emb_var_name) + + emb_files = [] + for emb_name in os.listdir(lookup_table_dirname): + if emb_var_name in emb_name: + emb_files.append(emb_name) + + convert_program = Program() + global_block = convert_program.global_block() + + emb_var = global_block.create_var( + name=emb_var.name, + shape=emb_var.shape, + dtype=emb_var.dtype, + type=core.VarDesc.VarType.SELECTED_ROWS, + persistable=True) + emb_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + + sums = [] + + for i, emb_file in enumerate(emb_files): + var_name = "{}_{}".format(emb_var.name, i) + param_var = global_block.create_var( + name=var_name, + shape=emb_var.shape, + dtype=emb_var.dtype, + type=core.VarDesc.VarType.SELECTED_ROWS, + persistable=True) + param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + global_block.append_op( + type='load', + inputs={}, + outputs={'Out': [param_var]}, + attrs={ + 'file_path': os.path.join(lookup_table_dirname, var_name) + }) + sums.append(param_var) + global_block.append_op( + type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={}) + global_block.append_op(type='delete_var', inputs={'X': sums}) + executor.run(convert_program) + + _logger.info("Start Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + lookup_table_vars = [lookup_table_var] + + io.load_vars( + executor, + dirname=dirname, + main_program=program, + predicate=_is_checkpoint_var(lookup_table_vars), + filename=None) + + _load_lookup_table_vars(executor, dirname, program, lookup_table_vars) + + _logger.info("Finish Load Sparse Program With " + "Distributed Lookup Table Vars from {}, time = {}".format( + dirname, time.ctime())) + + +def load_inference_model(dirname, executor, lookup_table_var_name): + if not os.path.isdir(dirname): + raise ValueError("There is no directory named '%s'", dirname) + + local_model = os.path.join(dirname, model_filename) + + with open(local_model, "rb") as f: + program_desc_str = f.read() + + program = Program.parse_from_string(program_desc_str) + + if not core._is_program_version_supported(program._version()): + raise ValueError("Unsupported program version: %d\n" % + program._version()) + + # Binary data also need version. + load_persistable_vars(executor, dirname, program, lookup_table_var_name) + + feed_target_names = program.desc.get_feed_target_names() + fetch_target_names = program.desc.get_fetch_target_names() + fetch_targets = [ + program.global_block().var(name) for name in fetch_target_names + ] + + return [program, feed_target_names, fetch_targets] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fd03dff386cad21c727ca0f266fa1b37ad65b4ad..b991187d424108db176ebd6996d7d161f11dcd3d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1698,6 +1698,7 @@ class Program(object): p._copy_param_info_from(self) p._copy_data_info_from(self) + p._copy_dist_param_info_from(self) return p def _prune(self, targets): @@ -1938,6 +1939,25 @@ class Program(object): "program, with represent the same topology") self.global_block()._copy_param_info_from(other.global_block()) + def _copy_dist_param_info_from(self, other): + """ + Copy the information of distributed information from other program. + + Args: + other(Program): Other program + + Returns: + None + """ + if not isinstance(other, Program): + raise TypeError("_copy_dist_param_info_from should be invoked with " + "Program") + self._is_distributed = other._is_distributed + self._is_chief = other._is_chief + self._slice_vars_and_attrs = other._slice_vars_and_attrs + self._endpoints = other._endpoints + self._distributed_lookup_table = other._distributed_lookup_table + def _copy_data_info_from(self, other): """ Copy the information of data variables from other program. diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 8936d884dd9e1ebbe5f688c11430b64e51ad8bd5..26d7af87b34fa03c1146f54d4753f5e1601217d6 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -165,6 +165,7 @@ def save_vars(executor, save_vars( executor, + main_program=main_program, dirname=dirname, vars=list(filter(predicate, main_program.list_vars())), filename=filename) @@ -172,11 +173,18 @@ def save_vars(executor, save_program = Program() save_block = save_program.global_block() + if main_program is None: + main_program = default_main_program() + if not isinstance(main_program, Program): + raise TypeError("program should be as Program type or None") + save_var_map = {} for each_var in vars: # NOTE: don't save the variable which type is RAW if each_var.type == core.VarDesc.VarType.RAW: continue + if each_var.name == main_program._distributed_lookup_table: + continue new_var = _clone_var_in_block_(save_block, each_var) if filename is None: save_block.append_op( @@ -198,6 +206,16 @@ def save_vars(executor, outputs={}, attrs={'file_path': os.path.join(dirname, filename)}) + # if there is lookup table, the trainer 0 will notify all pserver to save. + if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: + lookup_table_filename = os.path.join(dirname, "__lookup_table__") + attrs = {} + attrs['epmap'] = main_program._endpoints + attrs['dir'] = lookup_table_filename + attrs['lookup_table'] = main_program._distributed_lookup_table + save_block.append_op( + type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs) + executor.run(save_program) @@ -379,11 +397,22 @@ def load_vars(executor, load_prog = Program() load_block = load_prog.global_block() + if main_program is None: + main_program = default_main_program() + if not isinstance(main_program, Program): + raise TypeError("program should be as Program type or None") + + load_slice_vars = [] + for each_var in main_program._slice_vars_and_attrs: + load_slice_vars.append(each_var[2].name) + load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) if each_var.type == core.VarDesc.VarType.RAW: continue + if each_var.name in load_slice_vars: + continue new_var = _clone_var_in_block_(load_block, each_var) if filename is None: load_block.append_op( @@ -406,9 +435,6 @@ def load_vars(executor, attrs={'file_path': os.path.join(dirname, filename)}) executor.run(load_prog) - if main_program is None: - main_program = default_main_program() - # load slice vars on pserver, if have it. _load_slice_up_vars(executor, dirname, main_program._slice_vars_and_attrs) @@ -618,13 +644,6 @@ def save_inference_model(dirname, if main_program is None: main_program = default_main_program() - # if there is lookup table, the trainer 0 will notify all pserver to save. - if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table: - lookup_table_filename = os.path.join(dirname, "__lookup_table__") - _save_lookup_tables_by_notify(executor, lookup_table_filename, - main_program._distributed_lookup_table, - main_program._endpoints) - # when a pserver and a trainer running on the same machine, mkdir may conflict try: os.makedirs(dirname) @@ -642,6 +661,9 @@ def save_inference_model(dirname, # it can only be loaded for inference directly. If it's false, the whole # original program and related meta are saved so that future usage can be # more flexible. + + origin_program = main_program.clone() + if export_for_deployment: main_program = main_program.clone() global_block = main_program.global_block() @@ -666,8 +688,11 @@ def save_inference_model(dirname, with open(model_basename + ".main_program", "wb") as f: f.write(main_program.desc.serialize_to_string()) + main_program._copy_dist_param_info_from(origin_program) + if params_filename is not None: params_filename = os.path.basename(params_filename) + save_persistables(executor, dirname, main_program, params_filename) @@ -897,6 +922,9 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): slice_var = var_tuple[2] end = start + slice_var.shape[0] + orig_var_name = orig_var.name + orig_var.name = "{}.origin".format(orig_var_name) + clone_orig_var = load_block.create_var( name=orig_var.name, type=orig_var.type, @@ -915,7 +943,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): type='load', inputs={}, outputs={'Out': [clone_orig_var]}, - attrs={'file_path': os.path.join(dirname, clone_orig_var.name)}) + attrs={'file_path': os.path.join(dirname, orig_var_name)}) load_block.append_op( type="slice", inputs={'Input': clone_orig_var}, @@ -924,6 +952,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs): 'starts': [start], 'ends': [end]}) need_delete_vars.append(clone_orig_var) + load_block.append_op( type='delete_var', inputs={'X': need_delete_vars}, ) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 9730fbf510cbe8c323b761b29821710f2c14a81d..05138bf94598f649ef7fdbaa94896b6ba0884416 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -896,9 +896,10 @@ def array_to_lod_tensor(x, table): def increment(x, value=1.0, in_place=True): """ - This function performs an operation that increments each value in the + This function performs an operation that increments the value in the input :math:`x` by an amount: :math:`value` as mentioned in the input - parameter. This operation is performed in-place by default. + parameter. This operation is performed in-place by default. Notice that + the number of elements in :math:`x` must be equal to 1. Args: x (Variable|list): The tensor that has the input values. @@ -911,7 +912,8 @@ def increment(x, value=1.0, in_place=True): Examples: .. code-block:: python - data = fluid.layers.data(name='data', shape=[32, 32], dtype='float32') + data = fluid.layers.data(name='data', shape=[1], dtype='float32', + append_batch_size=False) data = fluid.layers.increment(x=data, value=3.0, in_place=True) """ helper = LayerHelper("increment", **locals()) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6d0e0ea240f758725b5b368edb7f47753ebbaf07..7af1f380e701f867921a16d9f0a91bcfad5e23ea 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6972,18 +6972,18 @@ def prelu(x, mode, param_attr=None, name=None): """ Equation: - y = \max(0, x) + alpha \min(0, x) + y = \max(0, x) + alpha * \min(0, x) Args: x (Variable): The input tensor. - param_attr(ParamAttr|None): The parameter attribute for the learnable - weight (alpha). - mode (string): The mode for weight sharing - all: all elements share same weight - channel:elements in a channel share same weight - element:each element has a weight - name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight (alpha). + mode (string): The mode for weight sharing. It supports all, channel + and element. all: all elements share same weight + channel:elements in a channel share same weight + element:each element has a weight + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. Returns: Variable: The output tensor with the same shape as input. @@ -6992,7 +6992,7 @@ def prelu(x, mode, param_attr=None, name=None): .. code-block:: python - x = fluid.layers.data(name="x", shape=[10,10], dtype="float32") + x = fluid.layers.data(name="x", shape=[10,10], dtype="float32") mode = 'channel' output = fluid.layers.prelu(x,mode) """ diff --git a/python/paddle/fluid/tests/unittests/test_fusion_transpose_flatten_concat_op.py b/python/paddle/fluid/tests/unittests/test_fusion_transpose_flatten_concat_op.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa7f76495abc03646ced1f183731f30d50c4223 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fusion_transpose_flatten_concat_op.py @@ -0,0 +1,105 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +class TestFusionTransposeFlattenConcationOp(OpTest): + def setUp(self): + self.init_test_case() + self.op_type = "fusion_transpose_flatten_concat" + + ins = [] + flats = [] + for i in range(len(self.shapes)): + in_shape = self.shapes[i] + a = np.random.random(in_shape).astype("float32") + ins.append(("x%d" % i, a)) + + b = a.transpose(self.trans_axis) + flat_shape = (np.prod(b.shape[:self.flatten_axis]), + np.prod(b.shape[self.flatten_axis:])) + c = b.reshape(flat_shape) + flats.append(c) + out = np.concatenate(flats, axis=self.concat_axis) + + self.inputs = {'X': ins} + self.attrs = { + 'trans_axis': list(self.trans_axis), + 'flatten_axis': self.flatten_axis, + 'concat_axis': self.concat_axis + } + self.outputs = {'Out': out} + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + self.check_output_with_place(place, 1e-6) + else: + pass + + def init_test_case(self): + self.shapes = [(3, 4, 17, 17), (3, 8, 7, 7), (3, 12, 5, 5)] + self.trans_axis = (0, 2, 3, 1) + self.flatten_axis = 1 + self.concat_axis = 1 + + +class TestCase1(TestFusionTransposeFlattenConcationOp): + def init_test_case(self): + self.shapes = [(3, 4, 18, 17), (3, 8, 18, 7), (6, 12, 9, 5)] + self.trans_axis = (0, 2, 3, 1) + self.flatten_axis = 2 + self.concat_axis = 1 + + +class TestCase2(TestFusionTransposeFlattenConcationOp): + def init_test_case(self): + self.shapes = [(3, 8, 20, 17), (3, 8, 19, 17), (3, 8, 40, 17)] + self.trans_axis = (0, 2, 3, 1) + self.flatten_axis = 2 + self.concat_axis = 0 + + +class TestCase3(TestFusionTransposeFlattenConcationOp): + def init_test_case(self): + self.shapes = [(3, 8, 20, 17), (3, 8, 19, 17), (3, 8, 40, 17)] + self.trans_axis = (0, 3, 2, 1) + self.flatten_axis = 1 + self.concat_axis = 1 + + +class TestCase4(TestFusionTransposeFlattenConcationOp): + def init_test_case(self): + self.shapes = [(3, 8, 9, 17), (8, 3, 9, 17), (4, 6, 9, 17)] + self.trans_axis = (0, 2, 1, 3) + self.flatten_axis = 3 + self.concat_axis = 1 + + +class TestCase5(TestFusionTransposeFlattenConcationOp): + def init_test_case(self): + self.shapes = [(3, 8, 9, 17, 2), (3, 8, 2, 17, 9), (3, 17, 9, 8, 2)] + self.trans_axis = (0, 2, 1, 4, 3) + self.flatten_axis = 1 + self.concat_axis = 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 89bc24802751340b6d4657be8673d714f3d3dc2b..ebd0d18d36eed4fffed86ba0903ff76f6052ef7a 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -644,6 +644,9 @@ in a single call.") else: recv_inputs.append(single_trainer_var) + self._slice_params_and_optimizes = self._get_slice_vars_and_attrs( + endpoint) + # step 3 # Create a union-find data structure from optimize ops, # If two ops are connected, we could add these two ops @@ -766,7 +769,7 @@ in a single call.") grad_to_block_id, merged_var, lr_ops) -# dedup grad to ids list + # dedup grad to ids list grad_to_block_id = list(set(grad_to_block_id)) # append global ops if global_ops: @@ -827,8 +830,8 @@ in a single call.") attrs=attrs) # add distributed attrs - pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs( - endpoint) + pserver_program._slice_vars_and_attrs = list( + self._slice_params_and_optimizes.values()) pserver_program._sync_with_cpp() # save pserver program to generate pserver side startup relatively. @@ -941,12 +944,12 @@ to transpile() call.") outputs={"Out": startup_tmpvar}) # add slice vars - s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint) + s_prog._slice_vars_and_attrs = pserver_program._slice_vars_and_attrs return s_prog def _get_slice_vars_and_attrs(self, endpoint): - slice_vars_and_attrs = [] + slice_vars_and_attrs = {} block_suffix = "block" for param in self.param_grad_ep_mapping[endpoint]["params"]: orig_var_name, block_name, _ = self._get_varname_parts(param.name) @@ -960,8 +963,7 @@ to transpile() call.") slice_vars = self.param_var_mapping[orig_var_name] for slice_var in slice_vars[:block_idx]: skip_dim0 += slice_var.shape[0] - slice_vars_and_attrs.append([orig_var, skip_dim0, param]) - + slice_vars_and_attrs[param.name] = [orig_var, skip_dim0, param] return slice_vars_and_attrs # ====================== private transpiler functions ===================== @@ -1662,10 +1664,10 @@ to transpile() call.") if key in ["Param", "Grad", "LearningRate"]: continue var = self.origin_program.global_block().vars[opt_op.input(key)[0]] + param_var = new_inputs["Param"] # update accumulator variable shape - param_shape = new_inputs["Param"].shape - new_shape = self._get_optimizer_input_shape(opt_op.type, key, - var.shape, param_shape) + new_shape = self._get_optimizer_input_shape( + opt_op.type, key, var.shape, param_var.shape) tmpvar = pserver_block.create_var( name=var.name, persistable=var.persistable, @@ -1673,6 +1675,13 @@ to transpile() call.") shape=new_shape) new_inputs[key] = tmpvar + # var shape been changed + if new_shape != var.shape: + slice_var_args = self._slice_params_and_optimizes[ + param_var.name] + self._slice_params_and_optimizes[ + var.name] = [var, slice_var_args[1], tmpvar] + # change output's ParamOut variable outputs = self._get_output_map_from_op( self.origin_program.global_block().vars, opt_op)