未验证 提交 f12b5260 编写于 作者: H heliqi 提交者: GitHub

Optimize the onnxruntime code (#41044)

上级 b1ee9d5e
......@@ -23,7 +23,8 @@
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#include "onnxruntime_c_api.h" // NOLINT
#include "onnxruntime_cxx_api.h" // NOLINT
#endif
namespace paddle_infer {
......
......@@ -81,7 +81,7 @@ bool CheckConvertToONNX(const AnalysisConfig &config) {
bool ONNXRuntimePredictor::Init() {
VLOG(3) << "ONNXRuntime Predictor::init()";
// Now ONNXRuntime only suuport CPU
// Now ONNXRuntime only support CPU
const char *device_name = config_.use_gpu() ? "Cuda" : "Cpu";
if (config_.use_gpu()) {
place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
......
......@@ -49,10 +49,6 @@ TEST(ONNXRuntimePredictor, onnxruntime_on) {
ASSERT_TRUE(predictor);
ASSERT_TRUE(!predictor->Clone());
ASSERT_TRUE(predictor->scope_);
ASSERT_TRUE(predictor->sub_scope_);
ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
// Dummy Input Data
std::vector<int64_t> input_shape = {-1, 3, 224, 224};
std::vector<float> input_data(1 * 3 * 224 * 224, 1.0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册