diff --git a/paddle/fluid/framework/details/execution_strategy.h b/paddle/fluid/framework/details/execution_strategy.h index 5183be878eb49cccc68603c3fdd8023be5578036..15c496130c2b6c7643ff96661be09e5ac4870344 100644 --- a/paddle/fluid/framework/details/execution_strategy.h +++ b/paddle/fluid/framework/details/execution_strategy.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include // for size_t namespace paddle { namespace framework { @@ -26,6 +27,7 @@ struct ExecutionStrategy { bool allow_op_delay_{false}; size_t num_iteration_per_drop_scope_{100}; ExecutorType type_{kDefault}; + bool dry_run_{false}; }; } // namespace details diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 98fc390e72fab3701538fd6f974460fa5114fdb0..2b2329b9698908fdbe3385f1d555d756c47fc5c0 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -128,7 +128,9 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( size_t complete = 0; while (op_to_run != nullptr) { try { - op_to_run->Run(strategy_.use_cuda_); + if (LIKELY(!strategy_.dry_run_)) { + op_to_run->Run(strategy_.use_cuda_); + } ++complete; } catch (...) { exception_.Catch(std::current_exception()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index dc63effd1b7c8fe5bb3fc91058eb855e552d3926..2d2bdb604f2d08adbaa0b38d04b8e377b2e6ab6c 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -211,7 +211,9 @@ void ThreadedSSAGraphExecutor::RunOp( if (VLOG_IS_ON(10)) { VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); } - op->Run(strategy_.use_cuda_); + if (LIKELY(!strategy_.dry_run_)) { + op->Run(strategy_.use_cuda_); + } VLOG(10) << op << " " << op->Name() << " Done "; running_ops_--; ready_var_q->Extend(op->Outputs()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index dbb0b498d995a897b109bd4ef98521b2193276ed..5c0bc169eaf3f54596eb8e08b7bf80a82253c9b2 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -48,7 +48,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { // Use topological sort algorithm FeedFetchList Run(const std::vector &fetch_tensors) override; - ~ThreadedSSAGraphExecutor() {} + ~ThreadedSSAGraphExecutor() final = default; private: void RunOp(const std::shared_ptr> &ready_var_q, diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index a45b9ec7a20ac3629d182f009b735d4d82fb5dc2..dfb107688ad7281765049cd9849d56b8a61bdd37 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -38,9 +38,20 @@ class ParallelExecutorPrivate { explicit ParallelExecutorPrivate(const std::vector &places) : places_(places) {} + ~ParallelExecutorPrivate() { + if (own_local_scope_) { + for (size_t i = 1; i < local_scopes_.size(); ++i) { + // Skip the first scope, since it is the global scope. + Scope *local_scope = local_scopes_[i]; + if (global_scope_->HasKid(local_scope)) { + global_scope_->DeleteScope(local_scope); + } + } + } + } std::vector places_; std::vector local_scopes_; - Scope *global_scope_; + Scope *global_scope_; // not owned std::unique_ptr executor_; #ifdef PADDLE_WITH_CUDA @@ -306,16 +317,6 @@ ParallelExecutor::~ParallelExecutor() { for (auto &p : member_->places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); } - - if (member_->own_local_scope_) { - for (size_t i = 1; i < member_->local_scopes_.size(); ++i) { - Scope *local_scope = member_->local_scopes_[i]; - if (member_->global_scope_->HasKid(local_scope)) { - member_->global_scope_->DeleteScope(local_scope); - } - } - } - // member_ must be destructed before gcs_ since the destructor of // ReferenceCountOpHandle use raw pointers of gcs_ inside. member_.reset(); diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index d31c8e3b7d66a0cdb2c4725783c9a24f049c666d..e5678cf607a8ff3763e79c1f321a81c021846fb1 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -1,5 +1,5 @@ if(WITH_TESTING) - include(test.cmake) # some generic cmake funtion for inference + include(tests/test.cmake) # some generic cmake funtion for inference endif() # analysis and tensorrt must be added before creating static library, # otherwise, there would be undefined reference to them in static library. diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 0a37d3968c39d2c244bbd82161afddf6330e421d..7bcf2dd1eeb17e802c5647df31945284ae08fa95 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -18,6 +18,21 @@ namespace paddle { namespace inference { namespace tensorrt { +bool to_skip_merging_optimize(TensorRTEngine* engine_, + const std::vector& filters, + const std::vector& strides, + const std::vector& paddings, + std::string input_name) { + if (engine_->itensor_quote_num[input_name] > 0) { + return true; + } + if (filters[0] == 1 && filters[1] == 1 && strides[0] == 1 && + strides[1] == 1 && paddings[0] == 0 && paddings[1] == 0) + engine_->itensor_quote_num[input_name] += 1; + + return false; +} + class Conv2dOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, @@ -31,6 +46,7 @@ class Conv2dOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(op_desc.Output("Output").size(), 1); auto* X = engine_->GetITensor(op_desc.Input("Input").front()); + // Declare weights auto* Y_v = scope.FindVar(op_desc.Input("Filter").front()); PADDLE_ENFORCE_NOT_NULL(Y_v); @@ -83,7 +99,10 @@ class Conv2dOpConverter : public OpConverter { std::move(weight_tensor); layer->getOutput(0)->setName(output_name.c_str()); engine_->SetITensor(output_name, layer->getOutput(0)); - if (test_mode) { + + if (test_mode || + to_skip_merging_optimize(engine_, {filter_h, filter_w}, strides, + paddings, op_desc.Input("Input").front())) { engine_->DeclareOutput(output_name); } } diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 14e9e14d33d637ee68e37593cc48721e5169499f..9e0f95844761db7571c5313726d34685a9aa66b2 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -133,6 +133,10 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset, buffer_sizes_[name] = 0; } +bool TensorRTEngine::HasDeclared(const std::string &name) { + return buffer_sizes_.count(name) > 0; +} + void TensorRTEngine::DeclareOutput(const std::string &name) { PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", name); diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index bd3ba4cea6551a7f6651e311e2649de191a6faa1..d9d3827321127631c0af6e5cfd2dfdd640cee146 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -91,6 +91,8 @@ class TensorRTEngine : public EngineBase { const std::string& name); // Set the itensor_map_[name] as the network's output, and set its name. void DeclareOutput(const std::string& name); + // Check if the ITensor has been declared + bool HasDeclared(const std::string& name); // GPU memory address for an ITensor with specific name. One can operate on // these memory directly for acceleration, for example, output the converted @@ -132,6 +134,16 @@ class TensorRTEngine : public EngineBase { std::unordered_map> weight_map; + // TODO: (NHZLX) + // In the normal case, the paddle-trt exists bug when runing the googlenet. + // When there are more than two convolutions of 1 * 1 with the same input, the + // paddle-tensorrt will do the merging optimization, which fuse those conv + // into + // one conv, and then trigger bug. So, We should use strategy to avoid this + // optimization for the time being. This bug will be fixed in the future. + std::unordered_map + itensor_quote_num; + private: // the max batch size int max_batch_; diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index b57a26b47026d1ecffab23b65c3eeb7de58f94eb..2ca84c80058b35840aff5d072cdc99ecf5165f8e 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -1,5 +1,11 @@ set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor) +function(download_model install_dir model_name) + if (NOT EXISTS ${install_dir}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name}) + endif() +endfunction() + function(download_model_and_data install_dir model_name data_name) if (NOT EXISTS ${install_dir}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL} ${model_name}) @@ -13,6 +19,13 @@ function(inference_analysis_api_test target install_dir filename) ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt) endfunction() +function(inference_analysis_api_test_with_fake_data target install_dir filename model_name) + download_model(${install_dir} ${model_name}) + inference_analysis_test(${target} SRCS ${filename} + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${install_dir}/model) +endfunction() + # RNN1 if(NOT APPLE) set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1") @@ -61,17 +74,13 @@ inference_analysis_api_test(test_analyzer_seq_conv1 ${SEQ_CONV1_INSTALL_DIR} ana # ocr set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") if (NOT EXISTS ${OCR_INSTALL_DIR}) - inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") + inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.cdn.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") endif() inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) # resnet50 -set(RESNET50_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/resnet50") -if (NOT EXISTS ${RESNET50_INSTALL_DIR}) - inference_download_and_uncompress(${RESNET50_INSTALL_DIR} ${INFERENCE_URL} "resnet50_model.tar.gz") -endif() -inference_analysis_test(test_analyzer_resnet50 SRCS analyzer_resnet50_tester.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${RESNET50_INSTALL_DIR}/model) +inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 + "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz") # anakin if (WITH_ANAKIN AND WITH_MKL) # only needed in CI diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index c2151eea0823f80feb17b014c1f739d2a15ae862..e5c8dfd22a006d5271248c5b083aab2c22417502 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -30,25 +30,7 @@ void SetConfig(AnalysisConfig *cfg) { } void SetInput(std::vector> *inputs) { - PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); - - PaddleTensor input; - // channel=3, height/width=318 - std::vector shape({FLAGS_batch_size, 3, 318, 318}); - input.shape = shape; - input.dtype = PaddleDType::FLOAT32; - - // fill input data, for profile easily, do not use random data here. - size_t size = FLAGS_batch_size * 3 * 318 * 318; - input.data.Resize(size * sizeof(float)); - float *input_data = static_cast(input.data.data()); - for (size_t i = 0; i < size; i++) { - *(input_data + i) = static_cast(i) / size; - } - - std::vector input_slots; - input_slots.assign({input}); - (*inputs).emplace_back(input_slots); + SetFakeImageInput(inputs, FLAGS_infer_model); } // Easy for profiling independently. @@ -61,13 +43,6 @@ void profile(bool use_mkldnn = false) { std::vector> input_slots_all; SetInput(&input_slots_all); TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads); - - if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { - PADDLE_ENFORCE_EQ(outputs.size(), 1UL); - size_t size = GetSize(outputs[0]); - // output is a 512-dimension feature - EXPECT_EQ(size, 512 * FLAGS_batch_size); - } } TEST(Analyzer_resnet50, profile) { profile(); } @@ -83,8 +58,7 @@ TEST(Analyzer_resnet50, fuse_statis) { auto predictor = CreatePaddlePredictor(cfg); auto fuse_statis = GetFuseStatis( static_cast(predictor.get()), &num_ops); - ASSERT_TRUE(fuse_statis.count("fc_fuse")); - EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); + LOG(INFO) << "num_ops: " << num_ops; } // Compare result of NativeConfig and AnalysisConfig diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 19c3f532d5dcb7588793fa21fa179f6b48649103..8c5888d8da7b33eeca77311c10dd818648e8e524 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -25,6 +25,7 @@ #include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h" +#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/platform/profiler.h" DEFINE_string(infer_model, "", "model path"); @@ -105,6 +106,34 @@ std::unordered_map GetFuseStatis(PaddlePredictor *predictor, return fuse_statis; } +void SetFakeImageInput(std::vector> *inputs, + const std::string &dirname) { + // Set fake_image_data + PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); + std::vector> feed_target_shapes = + GetFeedTargetShapes(dirname, true, "model", "params"); + int dim1 = feed_target_shapes[0][1]; + int dim2 = feed_target_shapes[0][2]; + int dim3 = feed_target_shapes[0][3]; + + PaddleTensor input; + std::vector shape({FLAGS_batch_size, dim1, dim2, dim3}); + input.shape = shape; + input.dtype = PaddleDType::FLOAT32; + + // fill input data, for profile easily, do not use random data here. + size_t size = FLAGS_batch_size * dim1 * dim2 * dim3; + input.data.Resize(size * sizeof(float)); + float *input_data = static_cast(input.data.data()); + for (size_t i = 0; i < size; i++) { + *(input_data + i) = static_cast(i) / size; + } + + std::vector input_slots; + input_slots.assign({input}); + (*inputs).emplace_back(input_slots); +} + void TestOneThreadPrediction( const AnalysisConfig &config, const std::vector> &inputs, diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index 91111f2af56065bbf57ba3a41bddd55ecced1060..75840a9c437d956da4f542a38b2532ea20ee96c5 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -93,11 +93,16 @@ void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) { } } -TEST(trt_models_test, main) { - std::vector infer_models = {"mobilenet", "resnet50", - "resnext50"}; - for (auto &model_dir : infer_models) { - CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + model_dir); - } +TEST(trt_models_test, mobilenet) { + CompareTensorRTWithFluid(1, FLAGS_dirname + "/mobilenet"); +} + +TEST(trt_models_test, resnet50) { + CompareTensorRTWithFluid(1, FLAGS_dirname + "/resnet50"); } + +TEST(trt_models_test, resnext50) { + CompareTensorRTWithFluid(1, FLAGS_dirname + "/resnext50"); +} + } // namespace paddle diff --git a/paddle/fluid/inference/test.cmake b/paddle/fluid/inference/tests/test.cmake similarity index 100% rename from paddle/fluid/inference/test.cmake rename to paddle/fluid/inference/tests/test.cmake diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 94f0550df57e79fa68c135f5c9c4b7effe6ac156..2118fcfd4bb1589947617e462f09971fcc090b98 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include -#include "paddle/fluid/framework/ir/graph_to_program_pass.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/profiler.h" @@ -94,15 +93,15 @@ void CheckError(const paddle::framework::LoDTensor& output1, std::unique_ptr InitProgram( paddle::framework::Executor* executor, paddle::framework::Scope* scope, - const std::string& dirname, const bool is_combined = false) { + const std::string& dirname, const bool is_combined = false, + const std::string& prog_filename = "__model_combined__", + const std::string& param_filename = "__params_combined__") { std::unique_ptr inference_program; if (is_combined) { // All parameters are saved in a single file. // Hard-coding the file names of program and parameters in unittest. // The file names should be consistent with that used in Python API // `fluid.io.save_inference_model`. - std::string prog_filename = "__model_combined__"; - std::string param_filename = "__params_combined__"; inference_program = paddle::inference::Load(executor, scope, dirname + "/" + prog_filename, dirname + "/" + param_filename); @@ -115,12 +114,15 @@ std::unique_ptr InitProgram( } std::vector> GetFeedTargetShapes( - const std::string& dirname, const bool is_combined = false) { + const std::string& dirname, const bool is_combined = false, + const std::string& prog_filename = "__model_combined__", + const std::string& param_filename = "__params_combined__") { auto place = paddle::platform::CPUPlace(); auto executor = paddle::framework::Executor(place); auto* scope = new paddle::framework::Scope(); - auto inference_program = InitProgram(&executor, scope, dirname, is_combined); + auto inference_program = InitProgram(&executor, scope, dirname, is_combined, + prog_filename, param_filename); auto& global_block = inference_program->Block(0); const std::vector& feed_target_names = @@ -136,15 +138,6 @@ std::vector> GetFeedTargetShapes( return feed_target_shapes; } -void Compile(paddle::framework::ProgramDesc* program) { - std::unique_ptr g( - new paddle::framework::ir::Graph(*program)); - auto pass = paddle::framework::ir::PassRegistry::Instance().Get( - "graph_to_program_pass"); - pass->SetNotOwned("program", program); - pass->Apply(std::move(g)); -} - template void TestInference(const std::string& dirname, const std::vector& cpu_feeds, @@ -182,7 +175,6 @@ void TestInference(const std::string& dirname, paddle::platform::DeviceContextPool::Instance().Get(place)); inference_program = InitProgram(&executor, scope, dirname, is_combined); } - Compile(inference_program.get()); // Disable the profiler and print the timing information paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, @@ -261,5 +253,3 @@ void TestInference(const std::string& dirname, delete scope; } - -USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/operators/distributed/grpc_variable_response.cc b/paddle/fluid/operators/distributed/grpc_variable_response.cc index 9e54aafb2d2ecb38cfcd9e1cc5242a56fe4ddc8b..d6d219d4369ba785e5c369538d4a18dc682952c1 100644 --- a/paddle/fluid/operators/distributed/grpc_variable_response.cc +++ b/paddle/fluid/operators/distributed/grpc_variable_response.cc @@ -286,10 +286,10 @@ int GRPCVariableResponse::Parse(Source* source) { platform::EnableProfiler(platform::ProfilerState::kCPU); } else if (profiling == platform::kDisableProfiler && platform::IsProfileEnabled()) { - // TODO(panyx0718): Should we allow to customize file dir. platform::DisableProfiler( platform::EventSortingKey::kDefault, - string::Sprintf("/tmp/profile_ps_%lld", listener_id)); + string::Sprintf("%s_%lld", FLAGS_rpc_server_profile_path, + listener_id)); } break; } diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 40143887e510532cb8b0fb4a82aae3cbf7bfd320..025528fe70b8f4d353ab92f29b1bd71c77cf7850 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -51,7 +51,6 @@ bool RequestSendHandler::Handle(const std::string& varname, // Async if (!sync_mode_) { VLOG(3) << "async process var: " << varname; - rpc_server_->Profiler().OneStep(); try { executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), scope); diff --git a/paddle/fluid/operators/distributed/rpc_server.cc b/paddle/fluid/operators/distributed/rpc_server.cc index 084480ae48b8b9267ade1a840f6a70519cb28e48..3e30ed4ac86bd2cb3f7c4301163e54a947c3d5b4 100644 --- a/paddle/fluid/operators/distributed/rpc_server.cc +++ b/paddle/fluid/operators/distributed/rpc_server.cc @@ -20,42 +20,10 @@ #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/platform/profiler.h" -DEFINE_int32(rpc_server_profile_period, 0, - "the period of listen_and_serv to do profile"); -DEFINE_string(rpc_server_profile_path, "/dev/null", - "the profile log file path"); - namespace paddle { namespace operators { namespace distributed { -RPCServerProfiler::RPCServerProfiler(int profile_period, - const std::string& profile_log_path) - : profile_period_(profile_period), profile_log_path_(profile_log_path) { - step_ = 0; -} - -void RPCServerProfiler::OneStep() { - PADDLE_ENFORCE_LE(step_, profile_period_, - "step_ should not be larger then " - "profile_period_"); - if (profile_period_ <= 0) { - return; - } - - if (step_ == 0) { - auto pf_state = paddle::platform::ProfilerState::kCPU; - paddle::platform::EnableProfiler(pf_state); - } - if (step_ == profile_period_) { - paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kTotal, - profile_log_path_); - step_ = 0; - } else { - step_++; - } -} - void RPCServer::ShutDown() { LOG(INFO) << "RPCServer ShutDown "; ShutDownImpl(); diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h index f3e61e1575ced0b9ffbad23e6973121daca9751b..c78c5007a7f262f15305b6c284e8c4fbddef42a0 100644 --- a/paddle/fluid/operators/distributed/rpc_server.h +++ b/paddle/fluid/operators/distributed/rpc_server.h @@ -23,30 +23,14 @@ #include "paddle/fluid/operators/distributed/request_handler.h" -DECLARE_int32(rpc_server_profile_period); -DECLARE_string(rpc_server_profile_path); - namespace paddle { namespace operators { namespace distributed { -class RPCServerProfiler { - public: - RPCServerProfiler(int profile_period, const std::string& profile_log_path); - void OneStep(); - - private: - const int profile_period_; - std::string profile_log_path_; - int step_; -}; - class RPCServer { public: explicit RPCServer(const std::string& address, int client_num) : cur_cond_(0), - profiler_(FLAGS_rpc_server_profile_period, - FLAGS_rpc_server_profile_path), bind_address_(address), exit_flag_(false), selected_port_(0), @@ -86,7 +70,6 @@ class RPCServer { void Complete(); void ResetBarrierCounter(); - RPCServerProfiler& Profiler() { return profiler_; } bool NeedResetAllVars(); @@ -101,7 +84,6 @@ class RPCServer { std::unordered_map rpc_cond_map_; std::atomic cur_cond_; std::condition_variable rpc_cond_; - RPCServerProfiler profiler_; protected: std::string bind_address_; diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index c4854d50b6371064003a10e18efc9e5f160d9a42..b2f73b67dc9bf944892187abd2e5709e54449d7d 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -16,6 +16,9 @@ #include #include "paddle/fluid/operators/distributed/sendrecvop_utils.h" +DEFINE_string(rpc_server_profile_path, "./profile_ps", + "the profile log file path"); + namespace paddle { namespace operators { namespace distributed { diff --git a/paddle/fluid/operators/distributed/variable_response.h b/paddle/fluid/operators/distributed/variable_response.h index f20a6038cefc28ff0569e2523cf77ddd172aa4e8..4c7fcbbdfb305ce6b4fc9d1edd9738899b200ec6 100644 --- a/paddle/fluid/operators/distributed/variable_response.h +++ b/paddle/fluid/operators/distributed/variable_response.h @@ -27,6 +27,8 @@ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h" +DECLARE_string(rpc_server_profile_path); + namespace paddle { namespace operators { namespace distributed { diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 865799589c4a5525f845516022171a9117fa0bac..1d8b1411cddf4fe16d2d00313c519cc173e1504d 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -134,7 +134,6 @@ void ListenAndServOp::RunSyncLoop( rpc_service_->ResetBarrierCounter(); while (true) { - rpc_service_->Profiler().OneStep(); // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. rpc_service_->SetCond(distributed::kRequestSend); diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index a651e0265a0ebfaca50214aa5a59f674a18cf30c..cb200ec8d6ea533d546f3e01a16a48c88b14f677 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -28,7 +28,7 @@ __device__ __forceinline__ double real_log(double x) { return log(x); } __device__ __forceinline__ platform::float16 real_log( const platform::float16& val) { - return static_cast(hlog(static_cast(val))); + return static_cast(logf(static_cast(val))); } template diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h index 87220d4019fc9337fb8355172ca4f1372cfd4558..b072b4c20a171d148bd892c162436d03da404fb9 100644 --- a/paddle/fluid/operators/math/fc_compute.h +++ b/paddle/fluid/operators/math/fc_compute.h @@ -36,7 +36,7 @@ inline void FCCompute(const BlasT& blas, const int M, .template Get>(N); for (int i = 0; i < M; i++) { T* dst = Y + i * N; - vaddrelu->Compute(B, dst, dst); + vaddrelu->Compute(B, dst, dst, N); } } else { const auto& vadd = jitkernel::KernelPool::Instance() @@ -47,7 +47,7 @@ inline void FCCompute(const BlasT& blas, const int M, #endif for (int i = 0; i < M; i++) { T* dst = Y + i * N; - vadd->Compute(B, dst, dst); + vadd->Compute(B, dst, dst, N); } } } diff --git a/paddle/fluid/operators/math/jit_code.cc b/paddle/fluid/operators/math/jit_code.cc index 9e2cc18c7a5e396be40b2336382f68a17f8a2bf9..a92e5d351e71a55bca2845ce275780950d096031 100644 --- a/paddle/fluid/operators/math/jit_code.cc +++ b/paddle/fluid/operators/math/jit_code.cc @@ -24,19 +24,29 @@ namespace gen { using namespace platform::jit; // NOLINT -bool VMulJitCode::init(int d) { +bool VVVJitCode::init(int d) { // It's not necessary to use avx512 since it would slow down the frequency // and this kernel is not compute bound. return MayIUse(avx); } -void VMulJitCode::generate() { +void VVVJitCode::generate() { // do not need push stack, and do not need save avx512reg if do not use avx512 int offset = 0; + if (with_relu_) { + vxorps(ymm_zero, ymm_zero, ymm_zero); + } for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) { vmovups(ymm_src1, ptr[param1 + offset]); vmovups(ymm_src2, ptr[param2 + offset]); - vmulps(ymm_dst, ymm_src1, ymm_src2); + if (type_ == operand_type::mul) { + vmulps(ymm_dst, ymm_src1, ymm_src2); + } else if (type_ == operand_type::add) { + vaddps(ymm_dst, ymm_src1, ymm_src2); + } + if (with_relu_) { + vmaxps(ymm_dst, ymm_zero, ymm_dst); + } vmovups(ptr[param3 + offset], ymm_dst); offset += sizeof(float) * AVX_FLOAT_BLOCK; } @@ -44,7 +54,14 @@ void VMulJitCode::generate() { if (rest >= 4) { vmovups(xmm_src1, ptr[param1 + offset]); vmovups(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); + if (type_ == operand_type::mul) { + vmulps(xmm_dst, xmm_src1, xmm_src2); + } else if (type_ == operand_type::add) { + vaddps(xmm_dst, xmm_src1, xmm_src2); + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovups(ptr[param3 + offset], xmm_dst); offset += sizeof(float) * 4; rest -= 4; @@ -52,7 +69,14 @@ void VMulJitCode::generate() { if (rest >= 2) { vmovq(xmm_src1, ptr[param1 + offset]); vmovq(xmm_src2, ptr[param2 + offset]); - vmulps(xmm_dst, xmm_src1, xmm_src2); + if (type_ == operand_type::mul) { + vmulps(xmm_dst, xmm_src1, xmm_src2); + } else if (type_ == operand_type::add) { + vaddps(xmm_dst, xmm_src1, xmm_src2); + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovq(ptr[param3 + offset], xmm_dst); offset += sizeof(float) * 2; rest -= 2; @@ -60,12 +84,18 @@ void VMulJitCode::generate() { if (rest > 0) { vmovss(xmm_src1, ptr[param1 + offset]); vmovss(xmm_src2, ptr[param2 + offset]); - vmulss(xmm_dst, xmm_src1, xmm_src2); + if (type_ == operand_type::mul) { + vmulss(xmm_dst, xmm_src1, xmm_src2); + } else if (type_ == operand_type::add) { + vaddss(xmm_dst, xmm_src1, xmm_src2); + } + if (with_relu_) { + vmaxps(xmm_dst, xmm_zero, xmm_dst); + } vmovss(ptr[param3 + offset], xmm_dst); } ret(); } - } // namespace gen } // namespace jitkernel } // namespace math diff --git a/paddle/fluid/operators/math/jit_code.h b/paddle/fluid/operators/math/jit_code.h index 6007b290815de0ceaa2226962c5273ae7da72e7e..73692ebc67c71f6190f2d18bd50071a28a35d4c9 100644 --- a/paddle/fluid/operators/math/jit_code.h +++ b/paddle/fluid/operators/math/jit_code.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/operators/math/jit_gen.h" - namespace paddle { namespace operators { namespace math { @@ -29,28 +29,47 @@ using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; using Label = Xbyak::Label; -class VMulJitCode : public JitCode { +// function: vec = Operand(vec, vec) (maybe with relu) +typedef enum { mul = 0, add } operand_type; + +class VVVJitCode : public JitCode { public: - DECLARE_JIT_CODE(VMulJitCode); - explicit VMulJitCode(int d, size_t code_size = 256 * 1024, - void* code_ptr = nullptr) - : JitCode(code_size, code_ptr), num_(d) {} + const char* name() const override { + std::string base = "VVVJitCode"; + if (type_ == operand_type::mul) { + base += "_Mul"; + } else if (type_ == operand_type::add) { + base += "_Add"; + } + base += (with_relu_ ? "_relu" : ""); + return base.c_str(); + } + explicit VVVJitCode(int d, operand_type type, bool with_relu, + size_t code_size = 256 * 1024, void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), + num_(d), + type_(type), + with_relu_(with_relu) {} static bool init(int d); void generate() override; private: int num_; + operand_type type_; + bool with_relu_; reg64_t param1{abi_param1}; reg64_t param2{abi_param2}; reg64_t param3{abi_param3}; xmm_t xmm_src1 = xmm_t(0); xmm_t xmm_src2 = xmm_t(1); - xmm_t xmm_dst = xmm_t(2); + xmm_t xmm_dst = xmm_t(1); + xmm_t xmm_zero = xmm_t(2); ymm_t ymm_src1 = ymm_t(0); ymm_t ymm_src2 = ymm_t(1); - ymm_t ymm_dst = ymm_t(2); + ymm_t ymm_dst = ymm_t(1); + ymm_t ymm_zero = ymm_t(2); }; } // namespace gen diff --git a/paddle/fluid/operators/math/jit_kernel.h b/paddle/fluid/operators/math/jit_kernel.h index 7b6027aa267803ff8ff830deabda536b1b27fec8..04e0b81d3e7c696ac2f5ee78db90fb3c89ab345d 100644 --- a/paddle/fluid/operators/math/jit_kernel.h +++ b/paddle/fluid/operators/math/jit_kernel.h @@ -71,26 +71,26 @@ class VMulKernel : public Kernel { template class VAddKernel : public Kernel { public: - virtual void Compute(const T *x, const T *y, T *z) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template -class VScalKernel : public Kernel { +class VAddReluKernel : public Kernel { public: - virtual void Compute(const T a, const T *x, T *y) const = 0; - virtual void Compute(const T a, T *x) const = 0; + void (*Compute)(const T *, const T *, T *, int); }; template -class VAddBiasKernel : public Kernel { +class VScalKernel : public Kernel { public: virtual void Compute(const T a, const T *x, T *y) const = 0; + virtual void Compute(const T a, T *x) const = 0; }; template -class VAddReluKernel : public Kernel { +class VAddBiasKernel : public Kernel { public: - virtual void Compute(const T *x, const T *y, T *z) const = 0; + virtual void Compute(const T a, const T *x, T *y) const = 0; }; template diff --git a/paddle/fluid/operators/math/jit_kernel_blas.cc b/paddle/fluid/operators/math/jit_kernel_blas.cc index 8a988f8f482e4a4963f70c39bccd89387c1e0059..9acb349f663cca1d38fa7840c3308dfa17a43ab1 100644 --- a/paddle/fluid/operators/math/jit_kernel_blas.cc +++ b/paddle/fluid/operators/math/jit_kernel_blas.cc @@ -42,6 +42,21 @@ void VMulRefer(const T* x, const T* y, T* z, int n) { } } +template +void VAddRefer(const T* x, const T* y, T* z, int n) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + } +} + +template +void VAddReluRefer(const T* x, const T* y, T* z, int n) { + for (int i = 0; i < n; ++i) { + z[i] = x[i] + y[i]; + z[i] = z[i] > 0 ? z[i] : 0; + } +} + #ifdef PADDLE_WITH_MKLML template void VMulMKL(const T* x, const T* y, T* z, int n); @@ -50,28 +65,45 @@ template <> void VMulMKL(const float* x, const float* y, float* z, int n) { platform::dynload::vsMul(n, x, y, z); } + template <> void VMulMKL(const double* x, const double* y, double* z, int n) { platform::dynload::vdMul(n, x, y, z); } + +template +void VAddMKL(const T* x, const T* y, T* z, int n); + +template <> +void VAddMKL(const float* x, const float* y, float* z, int n) { + platform::dynload::vsAdd(n, x, y, z); +} + +template <> +void VAddMKL(const double* x, const double* y, double* z, int n) { + platform::dynload::vdAdd(n, x, y, z); +} #endif +#define DECLARE_STATIC_FUNC \ + static inline std::string name(int d) { \ + PADDLE_THROW("DType should be either float or double"); \ + } \ + static inline bool useJIT(int d) { return false; } \ + static inline bool useMKL(int d) { return false; } + /* VMUL JitKernel */ template class VMulKernelImpl : public VMulKernel { public: - static inline std::string name(int d) { - PADDLE_THROW("DType should be either float or double"); - } - static inline bool useJIT(int d) { return false; } - static inline bool useMKL(int d) { return false; } - + DECLARE_STATIC_FUNC; explicit VMulKernelImpl(int d) : VMulKernel() { #ifdef PADDLE_WITH_XBYAK if (useJIT(d)) { // roughly estimate the size of code size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; - jitcode_.reset(new gen::VMulJitCode(d, sz > 4096 ? sz : 4096)); + jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false, + sz > 4096 ? sz : 4096)); this->Compute = jitcode_->getCode(); return; @@ -89,14 +121,14 @@ class VMulKernelImpl : public VMulKernel { #ifdef PADDLE_WITH_XBYAK private: - std::unique_ptr jitcode_{nullptr}; + std::unique_ptr jitcode_{nullptr}; #endif }; #ifdef PADDLE_WITH_XBYAK template <> bool VMulKernelImpl::useJIT(int d) { - return gen::VMulJitCode::init(d); + return gen::VVVJitCode::init(d); } #endif @@ -112,63 +144,89 @@ bool VMulKernelImpl::useMKL(int d) { } #endif -REGISTER_JITKERNEL(vmul, VMulKernel); - -/* VADD JitKernel */ -template +/* VAdd JitKernel */ +template class VAddKernelImpl : public VAddKernel { public: - explicit VAddKernelImpl(int d) : VAddKernel() { this->num_ = d; } - void Compute(const T* x, const T* y, T* z) const override { - for (int i = 0; i < this->num_; ++i) { - z[i] = x[i] + y[i]; + DECLARE_STATIC_FUNC; + explicit VAddKernelImpl(int d) : VAddKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false, + sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; } - } -}; - +#endif #ifdef PADDLE_WITH_MKLML -#define MKL_FLOAT(isa, block) \ - template <> \ - void VAddKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - platform::dynload::vsAdd(this->num_, x, y, z); \ + if (useMKL(d)) { + this->Compute = VAddMKL; + return; + } +#endif + this->Compute = VAddRefer; } -#define MKL_DOUBLE(isa, block) \ - template <> \ - void VAddKernelImpl::Compute( \ - const double* x, const double* y, double* z) const { \ - platform::dynload::vdAdd(this->num_, x, y, z); \ - } + private: + std::unique_ptr jitcode_{nullptr}; +}; -FOR_EACH_ISA(MKL_FLOAT, kGT16); -FOR_EACH_ISA_BLOCK(MKL_DOUBLE); +#ifdef PADDLE_WITH_XBYAK +template <> +bool VAddKernelImpl::useJIT(int d) { + return gen::VVVJitCode::init(d); +} #endif -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 tmpx, tmpy; \ - tmpx = _mm256_loadu_ps(x); \ - tmpy = _mm256_loadu_ps(y); \ - tmpx = _mm256_add_ps(tmpx, tmpy); \ - _mm256_storeu_ps(z, tmpx); \ - } -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); +#ifdef PADDLE_WITH_MKLML +template <> +bool VAddKernelImpl::useMKL(int d) { + return d > 512; +} + +template <> +bool VAddKernelImpl::useMKL(int d) { + return true; +} #endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); + +/* VAddRelu JitKernel */ +template +class VAddReluKernelImpl : public VAddReluKernel { + public: + DECLARE_STATIC_FUNC; + explicit VAddReluKernelImpl(int d) : VAddReluKernel() { +#ifdef PADDLE_WITH_XBYAK + if (useJIT(d)) { + size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; + jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true, + sz > 4096 ? sz : 4096)); + this->Compute = + jitcode_->getCode(); + return; + } #endif -#ifdef __AVX512F__ -INTRI8_FLOAT(jit::avx512f); + this->Compute = VAddReluRefer; + } + + private: + std::unique_ptr jitcode_{nullptr}; +}; + +#ifdef PADDLE_WITH_XBYAK +template <> +bool VAddReluKernelImpl::useJIT(int d) { + return gen::VVVJitCode::init(d); +} #endif -// TODO(TJ): eq16 test and complete avx512 -#undef INTRI8_FLOAT -#undef MKL_FLOAT -#undef MKL_DOUBLE +#undef DECLARE_STATIC_FUNC + +REGISTER_JITKERNEL(vmul, VMulKernel); +REGISTER_JITKERNEL(vadd, VAddKernel); +REGISTER_JITKERNEL(vaddrelu, VAddReluKernel); /* VSCAL JitKernel */ template @@ -405,98 +463,9 @@ class VIdentityKernelImpl : public VIdentityKernel { void Compute(const T* x, T* y) const override {} }; -/* VAddRelu JitKernel */ -template -class VAddReluKernelImpl : public VAddReluKernel { - public: - explicit VAddReluKernelImpl(int d) : VAddReluKernel() { this->num_ = d; } - void Compute(const T* x, const T* y, T* z) const override { - for (int i = 0; i < this->num_; ++i) { - z[i] = x[i] + y[i]; - z[i] = z[i] > 0 ? z[i] : 0; - } - } -}; - -#define INTRI8_FLOAT(isa) \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 tmpx = _mm256_loadu_ps(x); \ - __m256 tmpy = _mm256_loadu_ps(y); \ - tmpy = _mm256_add_ps(tmpx, tmpy); \ - tmpy = _mm256_max_ps(tmpy, _mm256_setzero_ps()); \ - _mm256_storeu_ps(z, tmpy); \ - } - -#define INTRI16_FLOAT(isa) \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - __m256 tmp0 = _mm256_loadu_ps(x); \ - __m256 tmp1 = _mm256_loadu_ps(y); \ - tmp0 = _mm256_add_ps(tmp0, tmp1); \ - tmp0 = _mm256_max_ps(tmp0, zeros); \ - tmp1 = _mm256_loadu_ps(x + 8); \ - __m256 tmp2 = _mm256_loadu_ps(y + 8); \ - tmp1 = _mm256_add_ps(tmp1, tmp2); \ - tmp1 = _mm256_max_ps(tmp1, zeros); \ - _mm256_storeu_ps(z, tmp0); \ - _mm256_storeu_ps(z + 8, tmp1); \ - } - -#define INTRI_COMMON_FLOAT(isa, block) \ - template <> \ - VAddReluKernelImpl::VAddReluKernelImpl(int d) \ - : VAddReluKernel() { \ - this->num_ = d; \ - this->end_ = d - d % AVX_FLOAT_BLOCK; \ - this->rest_ = d - this->end_; \ - } \ - template <> \ - void VAddReluKernelImpl::Compute( \ - const float* x, const float* y, float* z) const { \ - __m256 zeros = _mm256_setzero_ps(); \ - for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \ - __m256 tmpx = _mm256_loadu_ps(x + i); \ - __m256 tmpy = _mm256_loadu_ps(y + i); \ - tmpy = _mm256_add_ps(tmpx, tmpy); \ - tmpy = _mm256_max_ps(tmpy, zeros); \ - _mm256_storeu_ps(z + i, tmpy); \ - } \ - for (int i = this->end_; i < this->num_; ++i) { \ - z[i] = x[i] + y[i]; \ - z[i] = z[i] > 0 ? z[i] : 0; \ - } \ - } - -#ifdef __AVX__ -INTRI8_FLOAT(jit::avx); -INTRI16_FLOAT(jit::avx); -INTRI_COMMON_FLOAT(jit::avx, kGT16); -#endif -#ifdef __AVX2__ -INTRI8_FLOAT(jit::avx2); -INTRI16_FLOAT(jit::avx2); -INTRI_COMMON_FLOAT(jit::avx2, kGT16); -#endif -#ifdef __AVX512F__ -// TODO(TJ): refine avx512 -INTRI8_FLOAT(jit::avx512f); -INTRI16_FLOAT(jit::avx512f); -INTRI_COMMON_FLOAT(jit::avx512f, kGT16); -#endif - -#undef INTRI8_FLOAT -#undef INTRI16_FLOAT -#undef INTRI_COMMON_FLOAT - -REGISTER_JITKERNEL_DEPRECATED(vadd, VAddKernel); REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel); REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel); REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel); -REGISTER_JITKERNEL_DEPRECATED(vaddrelu, VAddReluKernel); REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel); } // namespace jitkernel diff --git a/paddle/fluid/operators/math/jit_kernel_rnn.cc b/paddle/fluid/operators/math/jit_kernel_rnn.cc index d0932a37bb85bbc41f662a106c8ef5693a72efeb..ba3e917377cf12192a068a9d71238442e12d5e5e 100644 --- a/paddle/fluid/operators/math/jit_kernel_rnn.cc +++ b/paddle/fluid/operators/math/jit_kernel_rnn.cc @@ -181,7 +181,7 @@ class LSTMKernelImpl : public LSTMKernel { act_cand_d_->Compute(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); - vadd_d_->Compute(gates + d_, gates + d2_, ct); + vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* H_t = act_cell(C_t) * ogated */ act_cell_d_->Compute(ct, gates + d2_); @@ -291,16 +291,16 @@ class PeepholeKernelImpl : public LSTMKernel { /* get fgated and igated*/ vmul_d_->Compute(wp_data, ct_1, checked, d_); vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_); - vadd_d2_->Compute(checked, gates + d_, gates + d_); + vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_); act_gate_d2_->Compute(gates + d_, gates + d_); /* C_t = C_t-1 * fgated + cand_gated * igated*/ act_cand_d_->Compute(gates, gates); vmul_d_->Compute(gates, gates + d_, gates + d_, d_); vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_); - vadd_d_->Compute(gates + d_, gates + d2_, ct); + vadd_d_->Compute(gates + d_, gates + d2_, ct, d_); /* get ogated*/ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); - vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); act_gate_d_->Compute(gates + d3_, gates + d3_); /* H_t = act_cell(C_t) * ogated */ act_cell_d_->Compute(ct, gates + d2_); @@ -314,7 +314,7 @@ class PeepholeKernelImpl : public LSTMKernel { vmul_d_->Compute(gates, gates + d_, ct, d_); /* get outgated, put W_oc * C_t on igated */ vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_); - vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_); + vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_); /* H_t = act_cell(C_t) * ogated */ act_gate_d_->Compute(gates + d3_, gates + d3_); act_cell_d_->Compute(ct, gates + d2_); diff --git a/paddle/fluid/operators/math/jit_kernel_test.cc b/paddle/fluid/operators/math/jit_kernel_test.cc index 667a95fe1a247cf9d3d63dae74f7e0fa9c2309ca..9a19424691fad70c161ca6036c5cdfd3b2b22ada 100644 --- a/paddle/fluid/operators/math/jit_kernel_test.cc +++ b/paddle/fluid/operators/math/jit_kernel_test.cc @@ -371,7 +371,7 @@ void lstm_ctht_better( vtanh_d->Compute(gates, gates); vmul_d->Compute(gates, gates + d, gates + d, d); vmul_d->Compute(ct_1, gates + d2, gates + d2, d); - vadd_d->Compute(gates + d, gates + d2, ct); + vadd_d->Compute(gates + d, gates + d2, ct, d); /* H_t = act_cell(C_t) * ogated */ vtanh_d->Compute(ct, gates + d2); vmul_d->Compute(gates + d2, gates + d * 3, ht, d); @@ -695,7 +695,7 @@ TEST(JitKernel, vadd) { auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data); + ker->Compute(x_data, y_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); @@ -723,8 +723,8 @@ void vaddrelu_better( const paddle::operators::math::jitkernel::VAddKernel>& vadd, const std::shared_ptr< const paddle::operators::math::jitkernel::VReluKernel>& vrelu, - const float* x, const float* y, float* z) { - vadd->Compute(x, y, z); + const float* x, const float* y, float* z, int d) { + vadd->Compute(x, y, z, d); vrelu->Compute(z, z); } @@ -752,12 +752,12 @@ TEST(JitKernel, vaddrelu) { auto trefe = GetCurrentUS(); auto tmkls = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data); + vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d); } auto tmkle = GetCurrentUS(); auto ttgts = GetCurrentUS(); for (int i = 0; i < repeat; ++i) { - ker->Compute(x_data, y_data, ztgt_data); + ker->Compute(x_data, y_data, ztgt_data, d); } auto ttgte = GetCurrentUS(); VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat @@ -801,7 +801,11 @@ TEST(JitKernel, pool) { std::dynamic_pointer_cast(pvmul_d)); const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfjit4"); - EXPECT_EQ(pvmul_f, pvmul_from_key); +#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32) + EXPECT_EQ(pvmul_from_key, nullptr); +#else + EXPECT_EQ(pvmul_from_key, pvmul_f); +#endif const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit"); EXPECT_TRUE(pvmul_from_key2 == nullptr); } diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h index d4ba0f9c33c91811647f9d19a332f139c16b0eb2..673f86da76ee0712b4d941f5b33594f89926b973 100644 --- a/paddle/fluid/operators/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt_engine_op.h @@ -223,7 +223,9 @@ class TensorRTEngineKernel : public framework::OpKernel { // Add outputs for (auto& output : output_maps) { - engine->DeclareOutput(output); + if (!engine->HasDeclared(output)) { + engine->DeclareOutput(output); + } } engine->FreezeNetwork(); diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 07abe1dd5c426e697d1598c9fa3e07cb48aa435a..2211e5504373b4a30e5fda0db22a41bdcd9f2421 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -116,6 +116,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { platform::SetNumThreads(FLAGS_paddle_num_threads); #endif +#if !defined(_WIN32) && !defined(__APPLE__) && !defined(__OSX__) if (platform::jit::MayIUse(platform::jit::avx)) { #ifndef __AVX__ LOG(WARNING) << "AVX is available, Please re-compile on local machine"; @@ -157,8 +158,9 @@ void InitDevices(bool init_p2p, const std::vector devices) { AVX_GUIDE(AVX, NonAVX); } #endif - #undef AVX_GUIDE + +#endif } void InitGLOG(const std::string &prog_name) { diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index da46a1abe12258b47b2fd4afb5f146daf15e026d..56bf9e31a35fdec5b7f04849068ff96ac9776c0e 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -226,7 +226,7 @@ RecordBlock::~RecordBlock() { void EnableProfiler(ProfilerState state) { PADDLE_ENFORCE(state != ProfilerState::kDisabled, - "Can't enbale profling, since the input state is ", + "Can't enable profiling, since the input state is ", "ProfilerState::kDisabled"); std::lock_guard l(profiler_mu); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index fc821e04a0baf9278295da18ee5a69afcf2c4605..238cc19189cfd74afa38bdcb5f5c802f9521dfea 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -742,7 +742,12 @@ All parameter, weight, gradient are variables in Paddle. will clean up the temp variables at the end of the current iteration. 2. In some NLP model, it may cause the GPU memory is insufficient, in this case, you should reduce `num_iteration_per_drop_scope`. - )DOC"); + )DOC") + .def_property("_dry_run", + [](const ExecutionStrategy &self) { return self.dry_run_; }, + [](ExecutionStrategy &self, bool dry_run) { + self.dry_run_ = dry_run; + }); exec_strategy.def_property( "use_experimental_executor", diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 737c8be8147a7efaf9b89827f063430146d3c078..c4cfd8e4680a3564b099eb4d8e3587e45f907572 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -118,7 +118,6 @@ def __bootstrap__(): ] if core.is_compiled_with_dist(): read_env_flags.append('rpc_deadline') - read_env_flags.append('rpc_server_profile_period') read_env_flags.append('rpc_server_profile_path') read_env_flags.append('enable_rpc_profiler') read_env_flags.append('rpc_send_thread_num') diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 22c60c1cbe4faa8577fa655766e42694652e498d..8936d884dd9e1ebbe5f688c11430b64e51ad8bd5 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -65,7 +65,7 @@ def is_persistable(var): Examples: .. code-block:: python - param = fluid.default_main_program().global_block().var('fc.w') + param = fluid.default_main_program().global_block().var('fc.b') res = fluid.io.is_persistable(param) """ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ @@ -625,8 +625,13 @@ def save_inference_model(dirname, main_program._distributed_lookup_table, main_program._endpoints) - if not os.path.isdir(dirname): + # when a pserver and a trainer running on the same machine, mkdir may conflict + try: os.makedirs(dirname) + except OSError as e: + if e.errno != errno.EEXIST: + raise + if model_filename is not None: model_basename = os.path.basename(model_filename) else: diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 80b50022dd1ac5ec739029f6cfff3f7f170ada00..d1c926c4e4d41d55130a37e0bf2492f56fde0658 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -60,7 +60,7 @@ def data(name, For example if shape=[1], the resulting shape is [-1, 1]. 2. If shape contains -1, such as shape=[1, -1], append_batch_size will be enforced to be be False (ineffective). - dtype(int|float): The type of data : float32, float_16, int etc + dtype(basestring): The type of data : float32, float_16, int etc type(VarType): The output type. By default it is LOD_TENSOR. lod_level(int): The LoD Level. 0 means the input data is not a sequence. stop_gradient(bool): A boolean that mentions whether gradient should flow. diff --git a/python/paddle/fluid/recordio_writer.py b/python/paddle/fluid/recordio_writer.py index a69c0c29d4675d3e6b9b2a2d766b8be9935092cf..076a942cdde5623faa570bf98f889e8145b60f8b 100644 --- a/python/paddle/fluid/recordio_writer.py +++ b/python/paddle/fluid/recordio_writer.py @@ -41,9 +41,6 @@ def convert_reader_to_recordio_file( """ Convert a Python Reader to a recordio file. - Please see :ref:`api_guide_python_reader` and :ref:`api_guide_reader_op` for - details. - Examples: >>> import paddle.fluid as fluid diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_dry_run.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_dry_run.py new file mode 100644 index 0000000000000000000000000000000000000000..c93740669f40aee3a6c143d153cfd0f5bb72dbd9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_dry_run.py @@ -0,0 +1,80 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest +import logging +import six + + +class TestBase(unittest.TestCase): + def main(self, + network_func, + iter=100, + iter_per_pe=100, + use_gpu=True, + use_experimental_executor=False): + if use_gpu and not fluid.core.is_compiled_with_cuda(): + logging.warning( + "Paddle is not compiled with CUDA, skip GPU unittests") + return + + main_prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.Scope() + with fluid.program_guard(main_prog, startup_prog): + with fluid.scope_guard(scope): + loss = network_func() + fluid.Executor( + fluid.CUDAPlace(0) + if use_gpu else fluid.CPUPlace()).run(startup_prog) + + for _ in six.moves.xrange(iter): + exe_strategy = fluid.ExecutionStrategy() + exe_strategy._dry_run = True + exe_strategy.use_experimental_executor = use_experimental_executor + pe = fluid.ParallelExecutor( + use_cuda=True, + loss_name=loss.name, + main_program=main_prog, + exec_strategy=exe_strategy) + for _ in six.moves.xrange(iter_per_pe): + pe.run([]) + + +class TestMNISTDryRun(TestBase): + def test_mnist_dry_run(self): + for use_gpu in (False, True): + for use_experimental_executor in (False, True): + self.main( + network_func=TestMNISTDryRun.network_func, + use_gpu=use_gpu, + use_experimental_executor=use_experimental_executor) + + @staticmethod + def network_func(): + img = fluid.layers.data(name='img', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = img + for _ in six.moves.xrange(10): + hidden = fluid.layers.fc(input=img, size=200, act='tanh') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + fluid.optimizer.Adam().minimize(avg_loss) + return avg_loss + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py index af3745987aa3eae96968bdc6b5c9cd951e9ca6fa..3eecc4670152e72443f731c71d7db67ca8e02e72 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py @@ -14,30 +14,18 @@ from __future__ import print_function -from parallel_executor_test_base import TestParallelExecutorBase -import paddle.fluid as fluid -import paddle.fluid.core as core -import numpy as np -import paddle -import paddle.dataset.mnist as mnist import unittest -import os -MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio" +import numpy as np +import paddle.fluid.core as core +import os +import paddle.fluid as fluid +from parallel_executor_test_base import TestParallelExecutorBase def simple_fc_net(use_feed): - if use_feed: - img = fluid.layers.data(name='image', shape=[784], dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - else: - reader = fluid.layers.open_files( - filenames=[MNIST_RECORDIO_FILE], - shapes=[[-1, 784], [-1, 1]], - lod_levels=[0, 0], - dtypes=['float32', 'int64']) - reader = fluid.layers.io.double_buffer(reader) - img, label = fluid.layers.read_file(reader) + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') hidden = img for _ in range(4): hidden = fluid.layers.fc( @@ -53,17 +41,8 @@ def simple_fc_net(use_feed): def fc_with_batchnorm(use_feed): - if use_feed: - img = fluid.layers.data(name='image', shape=[784], dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - else: - reader = fluid.layers.open_files( - filenames=[MNIST_RECORDIO_FILE], - shapes=[[-1, 784], [-1, 1]], - lod_levels=[0, 0], - dtypes=['float32', 'int64']) - reader = fluid.layers.io.double_buffer(reader) - img, label = fluid.layers.read_file(reader) + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') hidden = img for _ in range(1): @@ -88,19 +67,6 @@ class TestMNIST(TestParallelExecutorBase): @classmethod def setUpClass(cls): os.environ['CPU_NUM'] = str(4) - # Convert mnist to recordio file - with fluid.program_guard(fluid.Program(), fluid.Program()): - reader = paddle.batch(mnist.train(), batch_size=4) - feeder = fluid.DataFeeder( - feed_list=[ # order is image and label - fluid.layers.data( - name='image', shape=[784]), - fluid.layers.data( - name='label', shape=[1], dtype='int64'), - ], - place=fluid.CPUPlace()) - fluid.recordio_writer.convert_reader_to_recordio_file( - MNIST_RECORDIO_FILE, reader, feeder) def _init_data(self): np.random.seed(5) @@ -111,10 +77,6 @@ class TestMNIST(TestParallelExecutorBase): def _compare_reduce_and_allreduce(self, model, use_cuda): if use_cuda and not core.is_compiled_with_cuda(): return - self.check_network_convergence( - model, use_cuda=use_cuda, use_reduce=True) - self.check_network_convergence( - model, use_cuda=use_cuda, allow_op_delay=True, use_reduce=True) img, label = self._init_data() @@ -140,9 +102,6 @@ class TestMNIST(TestParallelExecutorBase): def check_simple_fc_convergence(self, use_cuda, use_reduce=False): if use_cuda and not core.is_compiled_with_cuda(): return - self.check_network_convergence(simple_fc_net, use_cuda=use_cuda) - self.check_network_convergence( - simple_fc_net, use_cuda=use_cuda, allow_op_delay=True) img, label = self._init_data() @@ -199,8 +158,6 @@ class TestMNIST(TestParallelExecutorBase): if use_cuda and not core.is_compiled_with_cuda(): return - self.check_network_convergence(fc_with_batchnorm, use_cuda=use_cuda) - img, label = self._init_data() self.check_network_convergence( diff --git a/python/setup.py.in b/python/setup.py.in index ee19294ad5c884cf73a4f14290f61f0b345ea8c7..b1ff9f3a5c3d877edb6bc6a12efce053a44b4c9c 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -14,7 +14,8 @@ RC = 0 def git_commit(): try: cmd = ['git', 'rev-parse', 'HEAD'] - git_commit = subprocess.Popen(cmd, stdout = subprocess.PIPE).communicate()[0].strip() + git_commit = subprocess.Popen(cmd, stdout = subprocess.PIPE, + cwd="@PADDLE_SOURCE_DIR@").communicate()[0].strip() except: git_commit = 'Unknown' git_commit = git_commit.decode() @@ -44,7 +45,7 @@ def get_patch(): def is_taged(): try: cmd = ['git', 'describe', '--exact-match', '--tags', 'HEAD', '2>/dev/null'] - git_tag = subprocess.Popen(cmd, stdout = subprocess.PIPE).communicate()[0].strip() + git_tag = subprocess.Popen(cmd, stdout = subprocess.PIPE, cwd="@PADDLE_SOURCE_DIR@").communicate()[0].strip() git_tag = git_tag.decode() except: return False @@ -55,8 +56,7 @@ def is_taged(): return False def write_version_py(filename='paddle/version.py'): - cnt = ''' -# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY + cnt = '''# THIS FILE IS GENERATED FROM PADDLEPADDLE SETUP.PY # full_version = '%(major)d.%(minor)d.%(patch)s' major = '%(major)d'