未验证 提交 2190ea09 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference] move IsFloatVar() from tensorrt/ to api/ (#49070)

* move IsFloatVar() from tensorrt/ to api/
上级 c9951dfc
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/inference/tensorrt/op_teller.h"
...@@ -74,7 +75,7 @@ void OutputProcess(framework::ir::Graph *graph, ...@@ -74,7 +75,7 @@ void OutputProcess(framework::ir::Graph *graph,
for (auto *var_node : op_node->outputs) { for (auto *var_node : op_node->outputs) {
if (!trt_outputs.count(var_node)) continue; if (!trt_outputs.count(var_node)) continue;
if (!var_node->Var()->Persistable() && if (!var_node->Var()->Persistable() &&
tensorrt::IsFloatVar(var_node->Var()->GetDataType()) && IsFloatVar(var_node->Var()->GetDataType()) &&
var_node->Var()->GetDataType() != framework::proto::VarType::FP32) { var_node->Var()->GetDataType() != framework::proto::VarType::FP32) {
for (auto *next_op : var_node->outputs) { for (auto *next_op : var_node->outputs) {
// if next_op support mixed_precision, we need to add cast op. // if next_op support mixed_precision, we need to add cast op.
......
...@@ -102,6 +102,11 @@ cc_test( ...@@ -102,6 +102,11 @@ cc_test(
SRCS api_tester.cc SRCS api_tester.cc
DEPS paddle_inference_api) DEPS paddle_inference_api)
cc_test(
inference_api_helper_test
SRCS helper_test.cc
DEPS paddle_inference_api)
if(WITH_ONNXRUNTIME AND WIN32) if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will # Copy onnxruntime for some c++ test in Windows, since the test will
# be build only in CI, so suppose the generator in Windows is Ninja. # be build only in CI, so suppose the generator in Windows is Ninja.
......
...@@ -81,6 +81,15 @@ inline PaddleDType ConvertToPaddleDType( ...@@ -81,6 +81,15 @@ inline PaddleDType ConvertToPaddleDType(
} }
} }
inline bool IsFloatVar(framework::proto::VarType::Type t) {
if (t == framework::proto::VarType::FP16 ||
t == framework::proto::VarType::FP32 ||
t == framework::proto::VarType::FP64 ||
t == framework::proto::VarType::BF16)
return true;
return false;
}
using paddle::framework::DataTypeToString; using paddle::framework::DataTypeToString;
// Timer for timer // Timer for timer
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/api/helper.h"
#include "gtest/gtest.h"
namespace paddle {
TEST(inference_api_helper, DataType) {
ASSERT_TRUE(
paddle::inference::IsFloatVar(paddle::framework::proto::VarType::FP64));
ASSERT_TRUE(
paddle::inference::IsFloatVar(paddle::framework::proto::VarType::FP32));
ASSERT_TRUE(
paddle::inference::IsFloatVar(paddle::framework::proto::VarType::FP16));
ASSERT_TRUE(
paddle::inference::IsFloatVar(paddle::framework::proto::VarType::BF16));
ASSERT_FALSE(
paddle::inference::IsFloatVar(paddle::framework::proto::VarType::INT32));
}
} // namespace paddle
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/dynload/tensorrt.h" #include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -215,14 +214,6 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) { ...@@ -215,14 +214,6 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
return nv_type; return nv_type;
} }
static bool IsFloatVar(framework::proto::VarType::Type t) {
if (t == framework::proto::VarType::FP16 ||
t == framework::proto::VarType::FP32 ||
t == framework::proto::VarType::FP64 ||
t == framework::proto::VarType::BF16)
return true;
return false;
}
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册