diff --git a/test/executor_for_test.h b/test/executor_for_test.h index a54a8bb191ad53bf5581a246f8a0ead633f84102..89b546178261a95862122654a0530f59e8ec32cc 100644 --- a/test/executor_for_test.h +++ b/test/executor_for_test.h @@ -114,4 +114,29 @@ class Executor4Test : public Executor { return output_tensor_sptrs; } + + std::shared_ptr predict(const Tensor &t, string input, string output, + const DDim &dDim) { + auto scope = this->program_.scope; + Variable *g_feed_value = scope->Var(input); + auto tensor = g_feed_value->GetMutable(); + tensor->ShareDataWith(t); + + Variable *con_output = scope->Var(output); + auto *output_tensor = con_output->GetMutable(); + output_tensor->mutable_data(dDim); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + std::shared_ptr to_predict_block = + this->to_predict_program_->Block(0); + for (int j = 0; j < this->ops_of_block_[*to_predict_block.get()].size(); + ++j) { + auto op = this->ops_of_block_[*to_predict_block.get()][j]; + op->Run(); + } + + return out_tensor; + } };