未验证 提交 526790e6 编写于 作者: Y Yan Chunwei 提交者: GitHub

infer get program (#15511)

上级 3c224e7e
...@@ -726,6 +726,10 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() { ...@@ -726,6 +726,10 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() {
return need; return need;
} }
std::string AnalysisPredictor::GetSeriazlizedProgram() const {
return inference_program_->Proto()->SerializeAsString();
}
template <> template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<contrib::AnalysisConfig>( std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<contrib::AnalysisConfig>(
const contrib::AnalysisConfig &config) { const contrib::AnalysisConfig &config) {
......
...@@ -75,6 +75,8 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -75,6 +75,8 @@ class AnalysisPredictor : public PaddlePredictor {
void SetMkldnnThreadID(int tid); void SetMkldnnThreadID(int tid);
std::string GetSeriazlizedProgram() const override;
protected: protected:
// For memory optimization. // For memory optimization.
bool need_collect_var_shapes_for_memory_optim(); bool need_collect_var_shapes_for_memory_optim();
......
...@@ -215,6 +215,8 @@ TEST(AnalysisPredictor, memory_optim) { ...@@ -215,6 +215,8 @@ TEST(AnalysisPredictor, memory_optim) {
{ {
// The first predictor help to cache the memory optimize strategy. // The first predictor help to cache the memory optimize strategy.
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config); auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
LOG(INFO) << "serialized program: " << predictor->GetSeriazlizedProgram();
ASSERT_FALSE(predictor->GetSeriazlizedProgram().empty());
// Run several times to check the parameters are not reused by mistake. // Run several times to check the parameters are not reused by mistake.
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
......
...@@ -215,6 +215,14 @@ class PaddlePredictor { ...@@ -215,6 +215,14 @@ class PaddlePredictor {
*/ */
virtual ~PaddlePredictor() = default; virtual ~PaddlePredictor() = default;
/** \brief Get the serialized model program that executes in inference phase.
* Its data type is ProgramDesc, which is a protobuf message.
*/
virtual std::string GetSeriazlizedProgram() const {
assert(false); // Force raise error.
return "NotImplemented";
};
/** The common configs for all the predictors. /** The common configs for all the predictors.
*/ */
struct Config { struct Config {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册