From b160d09eeb9453d05575beb88bc82703791ec977 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 27 Oct 2022 21:03:28 +0800 Subject: [PATCH] [JIT] Add Predictor for JITLayer (#47379) * add predictor_engine * add predictor_engine * fix zero shape * fix lodTensor * fix unittest * fix code style * update CmakeList --- paddle/fluid/inference/analysis/argument.h | 3 + .../analysis/passes/ir_graph_build_pass.cc | 8 +- .../analysis/passes/ir_graph_build_pass.h | 3 +- paddle/fluid/inference/api/analysis_config.cc | 4 + .../fluid/inference/api/analysis_predictor.cc | 37 ++-- .../inference/api/paddle_analysis_config.h | 11 ++ paddle/fluid/inference/io.cc | 25 +-- paddle/fluid/inference/io.h | 3 +- paddle/fluid/jit/CMakeLists.txt | 3 +- paddle/fluid/jit/engine/CMakeLists.txt | 5 + paddle/fluid/jit/engine/predictor_engine.cc | 186 ++++++++++++++++++ paddle/fluid/jit/engine/predictor_engine.h | 50 +++++ paddle/fluid/jit/function_schema.cc | 8 + paddle/fluid/jit/function_schema.h | 5 + paddle/fluid/jit/serializer.cc | 13 +- .../fluid/operators/collective/CMakeLists.txt | 8 +- .../operators/sequence_ops/CMakeLists.txt | 4 + paddle/fluid/platform/flags.cc | 5 +- 18 files changed, 340 insertions(+), 41 deletions(-) create mode 100644 paddle/fluid/jit/engine/predictor_engine.cc create mode 100644 paddle/fluid/jit/engine/predictor_engine.h diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 52d332b2e34..d855dc999ca 100755 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -150,6 +150,9 @@ struct Argument { DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string); DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool); + // For JITLayer + DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool); + // The overall graph to work on. DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); // The overall Scope to work on. diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index cd93238ff2b..e07eaa64615 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -55,7 +55,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { argument->model_params_path(), argument->scope_ptr(), place, - argument->model_from_memory_valid() && argument->model_from_memory()); + argument->model_from_memory_valid() && argument->model_from_memory(), + argument->skip_load_params()); argument->SetMainProgram(program.release()); } else { PADDLE_THROW(platform::errors::PreconditionNotMet( @@ -114,10 +115,11 @@ std::unique_ptr IrGraphBuildPass::LoadModel( const std::string ¶ms_path, framework::Scope *scope, const platform::Place &place, - bool model_from_memory) { + bool model_from_memory, + bool skip_load_params) { framework::Executor exe(place); if (!model_from_memory) { - return Load(&exe, scope, program_path, params_path); + return Load(&exe, scope, program_path, params_path, !skip_load_params); } else { return LoadFromMemory(&exe, scope, program_path, params_path); } diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h index 32902ef0667..69047b73ea0 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h @@ -43,7 +43,8 @@ class IrGraphBuildPass : public AnalysisPass { const std::string ¶ms_path, framework::Scope *scope, const platform::Place &place, - bool model_from_memory); + bool model_from_memory, + bool skip_load_params); std::string model_binary_str_; }; diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 2af92c7e148..be09976bc4d 100755 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -484,6 +484,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(custom_device_type_); CP_MEMBER(custom_device_id_); + // JITLayer relate + CP_MEMBER(apply_optim_); + CP_MEMBER(skip_load_params_); + if (use_gpu_) { PADDLE_ENFORCE_EQ(use_xpu_, false, diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f49c9faeb3d..a78a768a700 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -168,20 +168,27 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, LOG(ERROR) << "unsupported feed type " << pt.dtype; return false; } - - PADDLE_ENFORCE_NOT_NULL( - input_ptr, - paddle::platform::errors::Fatal( - "Cannot convert to LoDTensor because LoDTensor creation failed.")); - PADDLE_ENFORCE_NOT_NULL( - pt.data.data(), - paddle::platform::errors::InvalidArgument( - "The data contained in the input PaddleTensor is illegal.")); + // NOTE(Aurelius84): Some kernels support zero shape input + // without memory holder, we should skip enforce logic. + bool has_zero_dim = (phi::product(ddim) == 0); + if (has_zero_dim) { + VLOG(3) << "Found zero dim from input with ddim: " << ddim; + PADDLE_ENFORCE_NOT_NULL( + input_ptr, + paddle::platform::errors::Fatal( + "Cannot convert to LoDTensor because LoDTensor creation failed.")); + PADDLE_ENFORCE_NOT_NULL( + pt.data.data(), + paddle::platform::errors::InvalidArgument( + "The data contained in the input PaddleTensor is illegal.")); + } if (platform::is_cpu_place(place)) { // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. - std::memcpy( - static_cast(input_ptr), pt.data.data(), pt.data.length()); + if (input_ptr != nullptr) { + std::memcpy( + static_cast(input_ptr), pt.data.data(), pt.data.length()); + } } else if (platform::is_ipu_place(place)) { #ifdef PADDLE_WITH_IPU std::memcpy( @@ -529,6 +536,11 @@ bool AnalysisPredictor::PrepareProgram( // If the program is passed from external, no need to optimize it, this // logic is used in the clone scenario. inference_program_ = program; + if (config_.apply_optim_) { + VLOG(3) + << "apply_optim is enabled, will call OptimizeInferenceProgram()."; + OptimizeInferenceProgram(); + } } executor_->CreateVariables(*inference_program_, 0, false, sub_scope_); @@ -1065,11 +1077,12 @@ void AnalysisPredictor::PrepareArgument() { false, platform::errors::PreconditionNotMet( "Either model_dir or prog_file should be set.")); - std::string dir = inference::analysis::GetDirRoot(config_.prog_file()); argument_.SetModelProgramPath(config_.prog_file()); argument_.SetModelParamsPath(config_.params_file()); } + // For JITLayer + argument_.SetSkipLoadParams(config_.skip_load_params_); argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index c5a4fd5934c..5bc50515bf4 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -965,6 +965,10 @@ struct PD_INFER_DECL AnalysisConfig { void Exp_SetBlackListOpsForMixedModel( const std::unordered_set& black_list); + void SetApplyOptim(bool value) { apply_optim_ = value; } + + void SetSkipLoadParams(bool value) { skip_load_params_ = value; } + protected: // Update the config. void Update(); @@ -1167,6 +1171,13 @@ struct PD_INFER_DECL AnalysisConfig { // fleet exe related DistConfig dist_config_{}; + + // jit engine related + // NOTE(Aureliue84): In case of Predictor in JITLayer, program is from outer + // which means Predictor should apply optimization by calling + // PrepareProgram(). So we add this flag to control the process. + bool apply_optim_{false}; + bool skip_load_params_{false}; }; } // namespace paddle diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index cad5903540b..253df637633 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -160,11 +160,11 @@ std::unique_ptr Load(framework::Executor* executor, return main_program; } -std::unique_ptr Load( - framework::Executor* executor, - framework::Scope* scope, - const std::string& prog_filename, - const std::string& param_filename) { +std::unique_ptr Load(framework::Executor* executor, + framework::Scope* scope, + const std::string& prog_filename, + const std::string& param_filename, + bool load_params) { std::string program_desc_str; ReadBinaryFile(prog_filename, &program_desc_str); @@ -175,13 +175,14 @@ std::unique_ptr Load( true, platform::errors::Unavailable("Model version %ld is not supported.", main_program->Version())); - - LoadPersistables(executor, - scope, - *main_program, - "", - param_filename, - false /* model_from_memory */); + if (load_params) { + LoadPersistables(executor, + scope, + *main_program, + "", + param_filename, + false /* model_from_memory */); + } return main_program; } diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h index 31ed29e425d..36e21f8f36e 100644 --- a/paddle/fluid/inference/io.h +++ b/paddle/fluid/inference/io.h @@ -42,7 +42,8 @@ std::unique_ptr Load(framework::Executor* executor, std::unique_ptr Load(framework::Executor* executor, framework::Scope* scope, const std::string& prog_filename, - const std::string& param_filename); + const std::string& param_filename, + bool load_params = true); std::unique_ptr LoadFromMemory( framework::Executor* executor, diff --git a/paddle/fluid/jit/CMakeLists.txt b/paddle/fluid/jit/CMakeLists.txt index f47de23b0e1..b6db37d82c3 100644 --- a/paddle/fluid/jit/CMakeLists.txt +++ b/paddle/fluid/jit/CMakeLists.txt @@ -35,7 +35,7 @@ cc_library( jit_function SRCS function.cc DEPS jit_function_utils jit_executor_engine jit_pe_engine - jit_interpreter_engine) + jit_interpreter_engine jit_predictor_engine) cc_library( jit_layer @@ -48,6 +48,7 @@ cc_library( jit_executor_engine jit_pe_engine jit_interpreter_engine + jit_predictor_engine jit_function) if(WITH_TESTING AND NOT WIN32) diff --git a/paddle/fluid/jit/engine/CMakeLists.txt b/paddle/fluid/jit/engine/CMakeLists.txt index 5626e9eb1fc..b09e818227d 100644 --- a/paddle/fluid/jit/engine/CMakeLists.txt +++ b/paddle/fluid/jit/engine/CMakeLists.txt @@ -12,3 +12,8 @@ cc_library( jit_interpreter_engine SRCS interpreter_engine.cc DEPS standalone_executor) + +cc_library( + jit_predictor_engine + SRCS predictor_engine.cc + DEPS paddle_inference_api analysis_predictor) diff --git a/paddle/fluid/jit/engine/predictor_engine.cc b/paddle/fluid/jit/engine/predictor_engine.cc new file mode 100644 index 00000000000..d6bdf42b041 --- /dev/null +++ b/paddle/fluid/jit/engine/predictor_engine.cc @@ -0,0 +1,186 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/jit/engine/predictor_engine.h" + +#include "paddle/fluid/inference/api/analysis_predictor.h" +#include "paddle/fluid/inference/api/paddle_api.h" +#include "paddle/fluid/jit/function_utils.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace jit { + +static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t); +static bool PaddleTensorToDenseTensor(const PaddleTensor &pt, + DenseTensor *t, + const platform::Place &place); + +PredictorEngine::PredictorEngine(const std::shared_ptr &info, + const VariableMap ¶ms_dict, + const phi::Place &place) + : info_(info), scope_(new framework::Scope()), place_(place) { + utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get()); + VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get()); + + AnalysisConfig config; + config.SetProgFile(info->ProgramFilePath()); + if (platform::is_gpu_place(place_)) { + config.EnableUseGpu(100, place_.GetDeviceId()); + } else if (platform::is_cpu_place(place_)) { + config.DisableGpu(); + } + config.SetSkipLoadParams(true); + config.SetApplyOptim(true); + config.SwitchIrOptim(true); + + predictor_.reset(new AnalysisPredictor(config)); + + predictor_->Init( + scope_, std::make_shared(info_->ProgramDesc())); +} + +std::vector PredictorEngine::operator()( + const std::vector &inputs) { + auto dense_tensors = utils::ToDenseTensors(inputs); + return utils::ToTensors(this->operator()(dense_tensors)); +} + +std::vector PredictorEngine::operator()( + const std::vector &inputs) { + for (auto t : inputs) { + VLOG(1) << "inputs is init: " << t.initialized(); + } + + std::vector pt_inputs; + std::vector pt_outputs; + for (auto &t : inputs) { + auto non_const_t = const_cast(&t); + pt_inputs.emplace_back(DenseTensorToPaddleTensor(non_const_t)); + } + + predictor_->Run(pt_inputs, &pt_outputs); + + std::vector outputs; + for (auto &pt : pt_outputs) { + DenseTensor t; + PaddleTensorToDenseTensor(pt, &t, place_); + outputs.emplace_back(t); + } + + return outputs; +} + +static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) { + PaddleTensor pt; + + if (framework::TransToProtoVarType(t->dtype()) == + framework::proto::VarType::INT32) { + pt.data.Reset(t->data(), t->numel() * sizeof(int32_t)); + pt.dtype = PaddleDType::INT32; + } else if (framework::TransToProtoVarType(t->dtype()) == + framework::proto::VarType::INT64) { + pt.data.Reset(t->data(), t->numel() * sizeof(int64_t)); + pt.dtype = PaddleDType::INT64; + } else if (framework::TransToProtoVarType(t->dtype()) == + framework::proto::VarType::FP32) { + pt.data.Reset(t->data(), t->numel() * sizeof(float)); + pt.dtype = PaddleDType::FLOAT32; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported tensor date type. Now only supports INT64, FP32, INT32.")); + } + pt.shape = phi::vectorize(t->dims()); + return pt; +} + +static bool PaddleTensorToDenseTensor(const PaddleTensor &pt, + DenseTensor *t, + const platform::Place &place) { + framework::DDim ddim = phi::make_ddim(pt.shape); + void *input_ptr; + if (pt.dtype == PaddleDType::INT64) { + input_ptr = t->mutable_data(ddim, place); + } else if (pt.dtype == PaddleDType::FLOAT32) { + input_ptr = t->mutable_data(ddim, place); + } else if (pt.dtype == PaddleDType::INT32) { + input_ptr = t->mutable_data(ddim, place); + } else if (pt.dtype == PaddleDType::FLOAT16) { + input_ptr = t->mutable_data(ddim, place); + } else { + LOG(ERROR) << "unsupported feed type " << pt.dtype; + return false; + } + + PADDLE_ENFORCE_NOT_NULL( + input_ptr, + paddle::platform::errors::Fatal( + "Cannot convert to LoDTensor because LoDTensor creation failed.")); + PADDLE_ENFORCE_NOT_NULL( + pt.data.data(), + paddle::platform::errors::InvalidArgument( + "The data contained in the input PaddleTensor is illegal.")); + + if (platform::is_cpu_place(place)) { + // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. + std::memcpy( + static_cast(input_ptr), pt.data.data(), pt.data.length()); + } else if (platform::is_ipu_place(place)) { +#ifdef PADDLE_WITH_IPU + std::memcpy( + static_cast(input_ptr), pt.data.data(), pt.data.length()); +#else + PADDLE_THROW(paddle::platform::errors::Fatal( + "Not compile with WITH_IPU, should not reach here.")); +#endif + } else if (platform::is_gpu_place(place)) { + PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), + false, + platform::errors::InvalidArgument( + "Only one choice can be made between CPU and XPU.")); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto *dev_ctx = static_cast(pool.Get(place)); + auto dst_gpu_place = place; + memory::Copy(dst_gpu_place, + static_cast(input_ptr), + platform::CPUPlace(), + pt.data.data(), + pt.data.length(), + dev_ctx->stream()); +#else + PADDLE_THROW(paddle::platform::errors::Fatal( + "Not compile with CUDA, should not reach here.")); +#endif + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU + auto dst_xpu_place = place; + memory::Copy(dst_xpu_place, + static_cast(input_ptr), + platform::CPUPlace(), + pt.data.data(), + pt.data.length()); +#else + PADDLE_THROW(paddle::platform::errors::Fatal( + "Not compile with XPU, should not reach here.")); +#endif + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The analysis predictor supports CPU, GPU and XPU now.")); + } + return true; +} + +} // namespace jit +} // namespace paddle diff --git a/paddle/fluid/jit/engine/predictor_engine.h b/paddle/fluid/jit/engine/predictor_engine.h new file mode 100644 index 00000000000..026b012cbfb --- /dev/null +++ b/paddle/fluid/jit/engine/predictor_engine.h @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/jit/engine/base_engine.h" +#include "paddle/fluid/jit/function_schema.h" +#include "paddle/fluid/jit/function_utils.h" + +namespace paddle { +class AnalysisPredictor; + +namespace framework { +class Scope; +} + +namespace jit { + +class PredictorEngine : public BaseEngine { + public: + PredictorEngine(const std::shared_ptr &info, + const VariableMap ¶ms_dict, + const phi::Place &place); + + ~PredictorEngine() noexcept {} + + std::vector operator()(const std::vector &inputs); + + std::vector operator()(const std::vector &inputs); + + private: + std::shared_ptr info_; + std::shared_ptr scope_; + phi::Place place_; + std::shared_ptr predictor_; +}; + +} // namespace jit +} // namespace paddle diff --git a/paddle/fluid/jit/function_schema.cc b/paddle/fluid/jit/function_schema.cc index 8150d3b2e75..0d2014153e1 100644 --- a/paddle/fluid/jit/function_schema.cc +++ b/paddle/fluid/jit/function_schema.cc @@ -82,6 +82,14 @@ const std::vector FunctionInfo::OutputArgNames() const { return schema_.OutputArgNames(); } +const std::string& FunctionInfo::ProgramFilePath() const { + return prog_file_path_; +} + +void FunctionInfo::SetProgramFilePath(const std::string& path) { + prog_file_path_ = path; +} + void FunctionInfo::RemoveDescFeedFetch() { utils::RemoveFeedFetch(program_desc_.get()); } diff --git a/paddle/fluid/jit/function_schema.h b/paddle/fluid/jit/function_schema.h index 9f593dd7eee..1a760805584 100644 --- a/paddle/fluid/jit/function_schema.h +++ b/paddle/fluid/jit/function_schema.h @@ -72,6 +72,10 @@ class FunctionInfo { const std::vector OutputArgNames() const; + const std::string& ProgramFilePath() const; + + void SetProgramFilePath(const std::string& path); + void RemoveDescFeedFetch(); private: @@ -79,6 +83,7 @@ class FunctionInfo { std::vector param_names_; std::shared_ptr program_desc_; FunctionSchema schema_; + std::string prog_file_path_; }; } // namespace jit diff --git a/paddle/fluid/jit/serializer.cc b/paddle/fluid/jit/serializer.cc index 8e8bb370e81..9c819c52718 100644 --- a/paddle/fluid/jit/serializer.cc +++ b/paddle/fluid/jit/serializer.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/jit/engine/executor_engine.h" #include "paddle/fluid/jit/engine/interpreter_engine.h" #include "paddle/fluid/jit/engine/pe_engine.h" +#include "paddle/fluid/jit/engine/predictor_engine.h" #include "paddle/fluid/jit/layer.h" #include "paddle/fluid/jit/property.h" #include "paddle/fluid/jit/serializer_utils.h" @@ -54,6 +55,7 @@ Layer Deserializer::operator()(const std::string& path, param_names_set.insert(persist_var_names.begin(), persist_var_names.end()); info_map[func_name] = std::make_shared( func_name, persist_var_names, program_desc); + info_map[func_name]->SetProgramFilePath(it.second); } VariableMap params_dict; @@ -70,22 +72,23 @@ Layer Deserializer::operator()(const std::string& path, for (auto it = info_map.begin(); it != info_map.end(); ++it) { const std::string& func_name = it->first; auto& info = it->second; + VLOG(3) << "Add function type: " << FLAGS_jit_engine_type + << " Function name: " << func_name; if (FLAGS_jit_engine_type == "Executor") { - VLOG(3) << "Add function type: ExecutorEngine. Function name: " - << func_name; layer.SetEngine( func_name, utils::MakeEngine(info, params_dict, place)); } else if (FLAGS_jit_engine_type == "PE") { - VLOG(3) << "Add function type: PEEngine. Function name: " << func_name; layer.SetEngine(func_name, utils::MakeEngine(info, params_dict, place)); } else if (FLAGS_jit_engine_type == "New") { - VLOG(3) << "Add function type: InterpreterEngine. Function name: " - << func_name; layer.SetEngine( func_name, utils::MakeEngine(info, params_dict, place)); + } else if (FLAGS_jit_engine_type == "Predictor") { + layer.SetEngine( + info->FunctionName(), + utils::MakeEngine(info, params_dict, place)); } else { PD_THROW("Invalid JitLayer engine type."); } diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt index c94b0c93eb3..e29b3f6639f 100644 --- a/paddle/fluid/operators/collective/CMakeLists.txt +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -17,6 +17,10 @@ foreach(src ${OPS}) ${COLLECTIVE_COMPILE_FLAGS}) endforeach() +if(WITH_GLOO) + set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper) +endif() + register_operators( EXCLUDES c_gen_bkcl_id_op @@ -35,10 +39,6 @@ if(WITH_NCCL OR WITH_RCCL) op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) endif() -if(WITH_GLOO) - set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper) -endif() - if(WITH_XPU_BKCL) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper) op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS}) diff --git a/paddle/fluid/operators/sequence_ops/CMakeLists.txt b/paddle/fluid/operators/sequence_ops/CMakeLists.txt index fe36afd96c5..06281b6f376 100644 --- a/paddle/fluid/operators/sequence_ops/CMakeLists.txt +++ b/paddle/fluid/operators/sequence_ops/CMakeLists.txt @@ -4,3 +4,7 @@ if(WITH_UNITY_BUILD) include(unity_build_rule.cmake) endif() register_operators() + +if(WITH_UNITY_BUILD) + target_link_libraries(paddle_operators_sequence_ops_unity sequence_pooling) +endif() diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 813171240da..bac075c1d90 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -1010,13 +1010,14 @@ PADDLE_DEFINE_EXPORTED_bool( * Name: FLAGS_jit_engine_type * Since Version: 2.3.0 * Value Range: string, {Executor, PE}, - * default=PE + * default=Predictor * Example: * Note: * FLAGS_jit_engine_type == Executor, using ExecutorEngine by default * FLAGS_jit_engine_type == PE, using PEEngine by default * FLAGS_jit_engine_type == New, using InterpreterEngine by default + * FLAGS_jit_engine_type == Predictor, using inference Predictor by default */ PADDLE_DEFINE_EXPORTED_string(jit_engine_type, - "PE", + "Predictor", "Choose default funciton type in JitLayer."); -- GitLab