未验证 提交 b160d09e 编写于 作者: A Aurelius84 提交者: GitHub

[JIT] Add Predictor for JITLayer (#47379)

* add predictor_engine

* add predictor_engine

* fix zero shape

* fix lodTensor

* fix unittest

* fix code style

* update CmakeList
上级 0972d6ac
...@@ -150,6 +150,9 @@ struct Argument { ...@@ -150,6 +150,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string); DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool); DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
// For JITLayer
DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);
// The overall graph to work on. // The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
// The overall Scope to work on. // The overall Scope to work on.
......
...@@ -55,7 +55,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { ...@@ -55,7 +55,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
argument->model_params_path(), argument->model_params_path(),
argument->scope_ptr(), argument->scope_ptr(),
place, place,
argument->model_from_memory_valid() && argument->model_from_memory()); argument->model_from_memory_valid() && argument->model_from_memory(),
argument->skip_load_params());
argument->SetMainProgram(program.release()); argument->SetMainProgram(program.release());
} else { } else {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
...@@ -114,10 +115,11 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel( ...@@ -114,10 +115,11 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
const std::string &params_path, const std::string &params_path,
framework::Scope *scope, framework::Scope *scope,
const platform::Place &place, const platform::Place &place,
bool model_from_memory) { bool model_from_memory,
bool skip_load_params) {
framework::Executor exe(place); framework::Executor exe(place);
if (!model_from_memory) { if (!model_from_memory) {
return Load(&exe, scope, program_path, params_path); return Load(&exe, scope, program_path, params_path, !skip_load_params);
} else { } else {
return LoadFromMemory(&exe, scope, program_path, params_path); return LoadFromMemory(&exe, scope, program_path, params_path);
} }
......
...@@ -43,7 +43,8 @@ class IrGraphBuildPass : public AnalysisPass { ...@@ -43,7 +43,8 @@ class IrGraphBuildPass : public AnalysisPass {
const std::string &params_path, const std::string &params_path,
framework::Scope *scope, framework::Scope *scope,
const platform::Place &place, const platform::Place &place,
bool model_from_memory); bool model_from_memory,
bool skip_load_params);
std::string model_binary_str_; std::string model_binary_str_;
}; };
......
...@@ -484,6 +484,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -484,6 +484,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(custom_device_type_); CP_MEMBER(custom_device_type_);
CP_MEMBER(custom_device_id_); CP_MEMBER(custom_device_id_);
// JITLayer relate
CP_MEMBER(apply_optim_);
CP_MEMBER(skip_load_params_);
if (use_gpu_) { if (use_gpu_) {
PADDLE_ENFORCE_EQ(use_xpu_, PADDLE_ENFORCE_EQ(use_xpu_,
false, false,
......
...@@ -168,20 +168,27 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, ...@@ -168,20 +168,27 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
LOG(ERROR) << "unsupported feed type " << pt.dtype; LOG(ERROR) << "unsupported feed type " << pt.dtype;
return false; return false;
} }
// NOTE(Aurelius84): Some kernels support zero shape input
PADDLE_ENFORCE_NOT_NULL( // without memory holder, we should skip enforce logic.
input_ptr, bool has_zero_dim = (phi::product(ddim) == 0);
paddle::platform::errors::Fatal( if (has_zero_dim) {
"Cannot convert to LoDTensor because LoDTensor creation failed.")); VLOG(3) << "Found zero dim from input with ddim: " << ddim;
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
pt.data.data(), input_ptr,
paddle::platform::errors::InvalidArgument( paddle::platform::errors::Fatal(
"The data contained in the input PaddleTensor is illegal.")); "Cannot convert to LoDTensor because LoDTensor creation failed."));
PADDLE_ENFORCE_NOT_NULL(
pt.data.data(),
paddle::platform::errors::InvalidArgument(
"The data contained in the input PaddleTensor is illegal."));
}
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy( if (input_ptr != nullptr) {
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length()); std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
}
} else if (platform::is_ipu_place(place)) { } else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU #ifdef PADDLE_WITH_IPU
std::memcpy( std::memcpy(
...@@ -529,6 +536,11 @@ bool AnalysisPredictor::PrepareProgram( ...@@ -529,6 +536,11 @@ bool AnalysisPredictor::PrepareProgram(
// If the program is passed from external, no need to optimize it, this // If the program is passed from external, no need to optimize it, this
// logic is used in the clone scenario. // logic is used in the clone scenario.
inference_program_ = program; inference_program_ = program;
if (config_.apply_optim_) {
VLOG(3)
<< "apply_optim is enabled, will call OptimizeInferenceProgram().";
OptimizeInferenceProgram();
}
} }
executor_->CreateVariables(*inference_program_, 0, false, sub_scope_); executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);
...@@ -1065,11 +1077,12 @@ void AnalysisPredictor::PrepareArgument() { ...@@ -1065,11 +1077,12 @@ void AnalysisPredictor::PrepareArgument() {
false, false,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Either model_dir or prog_file should be set.")); "Either model_dir or prog_file should be set."));
std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
argument_.SetModelProgramPath(config_.prog_file()); argument_.SetModelProgramPath(config_.prog_file());
argument_.SetModelParamsPath(config_.params_file()); argument_.SetModelParamsPath(config_.params_file());
} }
// For JITLayer
argument_.SetSkipLoadParams(config_.skip_load_params_);
argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_); argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_);
......
...@@ -965,6 +965,10 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -965,6 +965,10 @@ struct PD_INFER_DECL AnalysisConfig {
void Exp_SetBlackListOpsForMixedModel( void Exp_SetBlackListOpsForMixedModel(
const std::unordered_set<std::string>& black_list); const std::unordered_set<std::string>& black_list);
void SetApplyOptim(bool value) { apply_optim_ = value; }
void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
protected: protected:
// Update the config. // Update the config.
void Update(); void Update();
...@@ -1167,6 +1171,13 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -1167,6 +1171,13 @@ struct PD_INFER_DECL AnalysisConfig {
// fleet exe related // fleet exe related
DistConfig dist_config_{}; DistConfig dist_config_{};
// jit engine related
// NOTE(Aureliue84): In case of Predictor in JITLayer, program is from outer
// which means Predictor should apply optimization by calling
// PrepareProgram(). So we add this flag to control the process.
bool apply_optim_{false};
bool skip_load_params_{false};
}; };
} // namespace paddle } // namespace paddle
...@@ -160,11 +160,11 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -160,11 +160,11 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
return main_program; return main_program;
} }
std::unique_ptr<framework::ProgramDesc> Load( std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
framework::Executor* executor, framework::Scope* scope,
framework::Scope* scope, const std::string& prog_filename,
const std::string& prog_filename, const std::string& param_filename,
const std::string& param_filename) { bool load_params) {
std::string program_desc_str; std::string program_desc_str;
ReadBinaryFile(prog_filename, &program_desc_str); ReadBinaryFile(prog_filename, &program_desc_str);
...@@ -175,13 +175,14 @@ std::unique_ptr<framework::ProgramDesc> Load( ...@@ -175,13 +175,14 @@ std::unique_ptr<framework::ProgramDesc> Load(
true, true,
platform::errors::Unavailable("Model version %ld is not supported.", platform::errors::Unavailable("Model version %ld is not supported.",
main_program->Version())); main_program->Version()));
if (load_params) {
LoadPersistables(executor, LoadPersistables(executor,
scope, scope,
*main_program, *main_program,
"", "",
param_filename, param_filename,
false /* model_from_memory */); false /* model_from_memory */);
}
return main_program; return main_program;
} }
......
...@@ -42,7 +42,8 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -42,7 +42,8 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
framework::Scope* scope, framework::Scope* scope,
const std::string& prog_filename, const std::string& prog_filename,
const std::string& param_filename); const std::string& param_filename,
bool load_params = true);
std::unique_ptr<framework::ProgramDesc> LoadFromMemory( std::unique_ptr<framework::ProgramDesc> LoadFromMemory(
framework::Executor* executor, framework::Executor* executor,
......
...@@ -35,7 +35,7 @@ cc_library( ...@@ -35,7 +35,7 @@ cc_library(
jit_function jit_function
SRCS function.cc SRCS function.cc
DEPS jit_function_utils jit_executor_engine jit_pe_engine DEPS jit_function_utils jit_executor_engine jit_pe_engine
jit_interpreter_engine) jit_interpreter_engine jit_predictor_engine)
cc_library( cc_library(
jit_layer jit_layer
...@@ -48,6 +48,7 @@ cc_library( ...@@ -48,6 +48,7 @@ cc_library(
jit_executor_engine jit_executor_engine
jit_pe_engine jit_pe_engine
jit_interpreter_engine jit_interpreter_engine
jit_predictor_engine
jit_function) jit_function)
if(WITH_TESTING AND NOT WIN32) if(WITH_TESTING AND NOT WIN32)
......
...@@ -12,3 +12,8 @@ cc_library( ...@@ -12,3 +12,8 @@ cc_library(
jit_interpreter_engine jit_interpreter_engine
SRCS interpreter_engine.cc SRCS interpreter_engine.cc
DEPS standalone_executor) DEPS standalone_executor)
cc_library(
jit_predictor_engine
SRCS predictor_engine.cc
DEPS paddle_inference_api analysis_predictor)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/jit/engine/predictor_engine.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/paddle_api.h"
#include "paddle/fluid/jit/function_utils.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace jit {
static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t);
static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
DenseTensor *t,
const platform::Place &place);
PredictorEngine::PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
const VariableMap &params_dict,
const phi::Place &place)
: info_(info), scope_(new framework::Scope()), place_(place) {
utils::ShareParamsIntoScope(info_->ParamNames(), params_dict, scope_.get());
VLOG(6) << framework::GenScopeTreeDebugInfo(scope_.get());
AnalysisConfig config;
config.SetProgFile(info->ProgramFilePath());
if (platform::is_gpu_place(place_)) {
config.EnableUseGpu(100, place_.GetDeviceId());
} else if (platform::is_cpu_place(place_)) {
config.DisableGpu();
}
config.SetSkipLoadParams(true);
config.SetApplyOptim(true);
config.SwitchIrOptim(true);
predictor_.reset(new AnalysisPredictor(config));
predictor_->Init(
scope_, std::make_shared<framework::ProgramDesc>(info_->ProgramDesc()));
}
std::vector<Tensor> PredictorEngine::operator()(
const std::vector<Tensor> &inputs) {
auto dense_tensors = utils::ToDenseTensors(inputs);
return utils::ToTensors(this->operator()(dense_tensors));
}
std::vector<DenseTensor> PredictorEngine::operator()(
const std::vector<DenseTensor> &inputs) {
for (auto t : inputs) {
VLOG(1) << "inputs is init: " << t.initialized();
}
std::vector<PaddleTensor> pt_inputs;
std::vector<PaddleTensor> pt_outputs;
for (auto &t : inputs) {
auto non_const_t = const_cast<DenseTensor *>(&t);
pt_inputs.emplace_back(DenseTensorToPaddleTensor(non_const_t));
}
predictor_->Run(pt_inputs, &pt_outputs);
std::vector<DenseTensor> outputs;
for (auto &pt : pt_outputs) {
DenseTensor t;
PaddleTensorToDenseTensor(pt, &t, place_);
outputs.emplace_back(t);
}
return outputs;
}
static PaddleTensor DenseTensorToPaddleTensor(DenseTensor *t) {
PaddleTensor pt;
if (framework::TransToProtoVarType(t->dtype()) ==
framework::proto::VarType::INT32) {
pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
pt.dtype = PaddleDType::INT32;
} else if (framework::TransToProtoVarType(t->dtype()) ==
framework::proto::VarType::INT64) {
pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64;
} else if (framework::TransToProtoVarType(t->dtype()) ==
framework::proto::VarType::FP32) {
pt.data.Reset(t->data(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported tensor date type. Now only supports INT64, FP32, INT32."));
}
pt.shape = phi::vectorize<int>(t->dims());
return pt;
}
static bool PaddleTensorToDenseTensor(const PaddleTensor &pt,
DenseTensor *t,
const platform::Place &place) {
framework::DDim ddim = phi::make_ddim(pt.shape);
void *input_ptr;
if (pt.dtype == PaddleDType::INT64) {
input_ptr = t->mutable_data<int64_t>(ddim, place);
} else if (pt.dtype == PaddleDType::FLOAT32) {
input_ptr = t->mutable_data<float>(ddim, place);
} else if (pt.dtype == PaddleDType::INT32) {
input_ptr = t->mutable_data<int32_t>(ddim, place);
} else if (pt.dtype == PaddleDType::FLOAT16) {
input_ptr = t->mutable_data<float16>(ddim, place);
} else {
LOG(ERROR) << "unsupported feed type " << pt.dtype;
return false;
}
PADDLE_ENFORCE_NOT_NULL(
input_ptr,
paddle::platform::errors::Fatal(
"Cannot convert to LoDTensor because LoDTensor creation failed."));
PADDLE_ENFORCE_NOT_NULL(
pt.data.data(),
paddle::platform::errors::InvalidArgument(
"The data contained in the input PaddleTensor is illegal."));
if (platform::is_cpu_place(place)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
} else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with WITH_IPU, should not reach here."));
#endif
} else if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
false,
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<const phi::GPUContext *>(pool.Get(place));
auto dst_gpu_place = place;
memory::Copy(dst_gpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length(),
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with CUDA, should not reach here."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto dst_xpu_place = place;
memory::Copy(dst_xpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with XPU, should not reach here."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"The analysis predictor supports CPU, GPU and XPU now."));
}
return true;
}
} // namespace jit
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/jit/engine/base_engine.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/function_utils.h"
namespace paddle {
class AnalysisPredictor;
namespace framework {
class Scope;
}
namespace jit {
class PredictorEngine : public BaseEngine {
public:
PredictorEngine(const std::shared_ptr<FunctionInfo> &info,
const VariableMap &params_dict,
const phi::Place &place);
~PredictorEngine() noexcept {}
std::vector<Tensor> operator()(const std::vector<Tensor> &inputs);
std::vector<DenseTensor> operator()(const std::vector<DenseTensor> &inputs);
private:
std::shared_ptr<FunctionInfo> info_;
std::shared_ptr<framework::Scope> scope_;
phi::Place place_;
std::shared_ptr<AnalysisPredictor> predictor_;
};
} // namespace jit
} // namespace paddle
...@@ -82,6 +82,14 @@ const std::vector<std::string> FunctionInfo::OutputArgNames() const { ...@@ -82,6 +82,14 @@ const std::vector<std::string> FunctionInfo::OutputArgNames() const {
return schema_.OutputArgNames(); return schema_.OutputArgNames();
} }
const std::string& FunctionInfo::ProgramFilePath() const {
return prog_file_path_;
}
void FunctionInfo::SetProgramFilePath(const std::string& path) {
prog_file_path_ = path;
}
void FunctionInfo::RemoveDescFeedFetch() { void FunctionInfo::RemoveDescFeedFetch() {
utils::RemoveFeedFetch(program_desc_.get()); utils::RemoveFeedFetch(program_desc_.get());
} }
......
...@@ -72,6 +72,10 @@ class FunctionInfo { ...@@ -72,6 +72,10 @@ class FunctionInfo {
const std::vector<std::string> OutputArgNames() const; const std::vector<std::string> OutputArgNames() const;
const std::string& ProgramFilePath() const;
void SetProgramFilePath(const std::string& path);
void RemoveDescFeedFetch(); void RemoveDescFeedFetch();
private: private:
...@@ -79,6 +83,7 @@ class FunctionInfo { ...@@ -79,6 +83,7 @@ class FunctionInfo {
std::vector<std::string> param_names_; std::vector<std::string> param_names_;
std::shared_ptr<framework::ProgramDesc> program_desc_; std::shared_ptr<framework::ProgramDesc> program_desc_;
FunctionSchema schema_; FunctionSchema schema_;
std::string prog_file_path_;
}; };
} // namespace jit } // namespace jit
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/jit/engine/executor_engine.h" #include "paddle/fluid/jit/engine/executor_engine.h"
#include "paddle/fluid/jit/engine/interpreter_engine.h" #include "paddle/fluid/jit/engine/interpreter_engine.h"
#include "paddle/fluid/jit/engine/pe_engine.h" #include "paddle/fluid/jit/engine/pe_engine.h"
#include "paddle/fluid/jit/engine/predictor_engine.h"
#include "paddle/fluid/jit/layer.h" #include "paddle/fluid/jit/layer.h"
#include "paddle/fluid/jit/property.h" #include "paddle/fluid/jit/property.h"
#include "paddle/fluid/jit/serializer_utils.h" #include "paddle/fluid/jit/serializer_utils.h"
...@@ -54,6 +55,7 @@ Layer Deserializer::operator()(const std::string& path, ...@@ -54,6 +55,7 @@ Layer Deserializer::operator()(const std::string& path,
param_names_set.insert(persist_var_names.begin(), persist_var_names.end()); param_names_set.insert(persist_var_names.begin(), persist_var_names.end());
info_map[func_name] = std::make_shared<FunctionInfo>( info_map[func_name] = std::make_shared<FunctionInfo>(
func_name, persist_var_names, program_desc); func_name, persist_var_names, program_desc);
info_map[func_name]->SetProgramFilePath(it.second);
} }
VariableMap params_dict; VariableMap params_dict;
...@@ -70,22 +72,23 @@ Layer Deserializer::operator()(const std::string& path, ...@@ -70,22 +72,23 @@ Layer Deserializer::operator()(const std::string& path,
for (auto it = info_map.begin(); it != info_map.end(); ++it) { for (auto it = info_map.begin(); it != info_map.end(); ++it) {
const std::string& func_name = it->first; const std::string& func_name = it->first;
auto& info = it->second; auto& info = it->second;
VLOG(3) << "Add function type: " << FLAGS_jit_engine_type
<< " Function name: " << func_name;
if (FLAGS_jit_engine_type == "Executor") { if (FLAGS_jit_engine_type == "Executor") {
VLOG(3) << "Add function type: ExecutorEngine. Function name: "
<< func_name;
layer.SetEngine( layer.SetEngine(
func_name, func_name,
utils::MakeEngine<ExecutorEngine>(info, params_dict, place)); utils::MakeEngine<ExecutorEngine>(info, params_dict, place));
} else if (FLAGS_jit_engine_type == "PE") { } else if (FLAGS_jit_engine_type == "PE") {
VLOG(3) << "Add function type: PEEngine. Function name: " << func_name;
layer.SetEngine(func_name, layer.SetEngine(func_name,
utils::MakeEngine<PEEngine>(info, params_dict, place)); utils::MakeEngine<PEEngine>(info, params_dict, place));
} else if (FLAGS_jit_engine_type == "New") { } else if (FLAGS_jit_engine_type == "New") {
VLOG(3) << "Add function type: InterpreterEngine. Function name: "
<< func_name;
layer.SetEngine( layer.SetEngine(
func_name, func_name,
utils::MakeEngine<InterpreterEngine>(info, params_dict, place)); utils::MakeEngine<InterpreterEngine>(info, params_dict, place));
} else if (FLAGS_jit_engine_type == "Predictor") {
layer.SetEngine(
info->FunctionName(),
utils::MakeEngine<PredictorEngine>(info, params_dict, place));
} else { } else {
PD_THROW("Invalid JitLayer engine type."); PD_THROW("Invalid JitLayer engine type.");
} }
......
...@@ -17,6 +17,10 @@ foreach(src ${OPS}) ...@@ -17,6 +17,10 @@ foreach(src ${OPS})
${COLLECTIVE_COMPILE_FLAGS}) ${COLLECTIVE_COMPILE_FLAGS})
endforeach() endforeach()
if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
endif()
register_operators( register_operators(
EXCLUDES EXCLUDES
c_gen_bkcl_id_op c_gen_bkcl_id_op
...@@ -35,10 +39,6 @@ if(WITH_NCCL OR WITH_RCCL) ...@@ -35,10 +39,6 @@ if(WITH_NCCL OR WITH_RCCL)
op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
endif() endif()
if(WITH_GLOO)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} gloo_wrapper)
endif()
if(WITH_XPU_BKCL) if(WITH_XPU_BKCL)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper) set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} collective_helper)
op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS}) op_library(c_gen_bkcl_id_op DEPS ${COLLECTIVE_DEPS})
......
...@@ -4,3 +4,7 @@ if(WITH_UNITY_BUILD) ...@@ -4,3 +4,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake) include(unity_build_rule.cmake)
endif() endif()
register_operators() register_operators()
if(WITH_UNITY_BUILD)
target_link_libraries(paddle_operators_sequence_ops_unity sequence_pooling)
endif()
...@@ -1010,13 +1010,14 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -1010,13 +1010,14 @@ PADDLE_DEFINE_EXPORTED_bool(
* Name: FLAGS_jit_engine_type * Name: FLAGS_jit_engine_type
* Since Version: 2.3.0 * Since Version: 2.3.0
* Value Range: string, {Executor, PE}, * Value Range: string, {Executor, PE},
* default=PE * default=Predictor
* Example: * Example:
* Note: * Note:
* FLAGS_jit_engine_type == Executor, using ExecutorEngine by default * FLAGS_jit_engine_type == Executor, using ExecutorEngine by default
* FLAGS_jit_engine_type == PE, using PEEngine by default * FLAGS_jit_engine_type == PE, using PEEngine by default
* FLAGS_jit_engine_type == New, using InterpreterEngine by default * FLAGS_jit_engine_type == New, using InterpreterEngine by default
* FLAGS_jit_engine_type == Predictor, using inference Predictor by default
*/ */
PADDLE_DEFINE_EXPORTED_string(jit_engine_type, PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"PE", "Predictor",
"Choose default funciton type in JitLayer."); "Choose default funciton type in JitLayer.");
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册