未验证 提交 82db4993 编写于 作者: W Wilber 提交者: GitHub

cherry-pick 46942 (#47015)

上级 84333cf5
...@@ -129,17 +129,19 @@ phi::DataType ConvertPrecision(AnalysisConfig::Precision precision) { ...@@ -129,17 +129,19 @@ phi::DataType ConvertPrecision(AnalysisConfig::Precision precision) {
} }
} }
phi::Backend ConvertBackend(AnalysisConfig::Backend backend) { phi::Backend ConvertBackend(paddle_infer::PlaceType backend) {
switch (backend) { switch (backend) {
case AnalysisConfig::Backend::kGPU: case paddle_infer::PlaceType::kGPU:
// NOTE: phi also support phi::Backend::GPUDNN. // NOTE: phi also support phi::Backend::GPUDNN.
return phi::Backend::GPU; return phi::Backend::GPU;
case AnalysisConfig::Backend::kNPU: case paddle_infer::PlaceType::kNPU:
return phi::Backend::NPU; return phi::Backend::NPU;
case AnalysisConfig::Backend::kXPU: case paddle_infer::PlaceType::kXPU:
return phi::Backend::XPU; return phi::Backend::XPU;
case AnalysisConfig::Backend::kCPU: case paddle_infer::PlaceType::kCPU:
return phi::Backend::CPU; return phi::Backend::CPU;
case paddle_infer::PlaceType::kIPU:
return phi::Backend::IPU;
default: default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Paddle Inference not support backend, we now only support GPU, XPU, " "Paddle Inference not support backend, we now only support GPU, XPU, "
...@@ -2320,7 +2322,7 @@ void ConvertToMixedPrecision(const std::string &model_file, ...@@ -2320,7 +2322,7 @@ void ConvertToMixedPrecision(const std::string &model_file,
const std::string &mixed_model_file, const std::string &mixed_model_file,
const std::string &mixed_params_file, const std::string &mixed_params_file,
PrecisionType mixed_precision, PrecisionType mixed_precision,
BackendType backend, paddle_infer::PlaceType backend,
bool keep_io_types, bool keep_io_types,
std::unordered_set<std::string> black_list) { std::unordered_set<std::string> black_list) {
auto phi_backend = paddle::ConvertBackend(backend); auto phi_backend = paddle::ConvertBackend(backend);
......
...@@ -170,13 +170,6 @@ struct PD_INFER_DECL AnalysisConfig { ...@@ -170,13 +170,6 @@ struct PD_INFER_DECL AnalysisConfig {
kBf16, ///< bf16 kBf16, ///< bf16
}; };
enum class Backend {
kCPU = 0,
kGPU,
kXPU,
kNPU,
};
/// ///
/// \brief Set the no-combined model dir path. /// \brief Set the no-combined model dir path.
/// ///
......
...@@ -47,7 +47,6 @@ namespace paddle_infer { ...@@ -47,7 +47,6 @@ namespace paddle_infer {
using PrecisionType = paddle::AnalysisConfig::Precision; using PrecisionType = paddle::AnalysisConfig::Precision;
using Config = paddle::AnalysisConfig; using Config = paddle::AnalysisConfig;
using DistConfig = paddle::DistConfig; using DistConfig = paddle::DistConfig;
using BackendType = paddle::AnalysisConfig::Backend;
/// ///
/// \class Predictor /// \class Predictor
...@@ -198,7 +197,7 @@ PD_INFER_DECL void ConvertToMixedPrecision( ...@@ -198,7 +197,7 @@ PD_INFER_DECL void ConvertToMixedPrecision(
const std::string& mixed_model_file, const std::string& mixed_model_file,
const std::string& mixed_params_file, const std::string& mixed_params_file,
PrecisionType mixed_precision, PrecisionType mixed_precision,
BackendType backend, PlaceType backend,
bool keep_io_types = true, bool keep_io_types = true,
std::unordered_set<std::string> black_list = {}); std::unordered_set<std::string> black_list = {});
......
...@@ -616,13 +616,6 @@ void BindAnalysisConfig(py::module *m) { ...@@ -616,13 +616,6 @@ void BindAnalysisConfig(py::module *m) {
.value("Bfloat16", AnalysisConfig::Precision::kBf16) .value("Bfloat16", AnalysisConfig::Precision::kBf16)
.export_values(); .export_values();
py::enum_<AnalysisConfig::Backend>(analysis_config, "Backend")
.value("CPU", AnalysisConfig::Backend::kCPU)
.value("GPU", AnalysisConfig::Backend::kGPU)
.value("NPU", AnalysisConfig::Backend::kNPU)
.value("XPU", AnalysisConfig::Backend::kXPU)
.export_values();
analysis_config.def(py::init<>()) analysis_config.def(py::init<>())
.def(py::init<const AnalysisConfig &>()) .def(py::init<const AnalysisConfig &>())
.def(py::init<const std::string &>()) .def(py::init<const std::string &>())
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .wrapper import Config, DataType, PlaceType, PrecisionType, BackendType, Tensor, Predictor from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor
from .wrapper import convert_to_mixed_precision from .wrapper import convert_to_mixed_precision
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
...@@ -24,7 +24,6 @@ from typing import Set ...@@ -24,7 +24,6 @@ from typing import Set
DataType = PaddleDType DataType = PaddleDType
PlaceType = PaddlePlace PlaceType = PaddlePlace
PrecisionType = AnalysisConfig.Precision PrecisionType = AnalysisConfig.Precision
BackendType = AnalysisConfig.Backend
Config = AnalysisConfig Config = AnalysisConfig
Tensor = PaddleInferTensor Tensor = PaddleInferTensor
Predictor = PaddleInferPredictor Predictor = PaddleInferPredictor
...@@ -59,7 +58,7 @@ def convert_to_mixed_precision(model_file: str, ...@@ -59,7 +58,7 @@ def convert_to_mixed_precision(model_file: str,
mixed_model_file: str, mixed_model_file: str,
mixed_params_file: str, mixed_params_file: str,
mixed_precision: PrecisionType, mixed_precision: PrecisionType,
backend: BackendType, backend: PlaceType,
keep_io_types: bool = True, keep_io_types: bool = True,
black_list: Set = set()): black_list: Set = set()):
''' '''
...@@ -71,7 +70,7 @@ def convert_to_mixed_precision(model_file: str, ...@@ -71,7 +70,7 @@ def convert_to_mixed_precision(model_file: str,
mixed_model_file: The storage path of the converted mixed-precision model. mixed_model_file: The storage path of the converted mixed-precision model.
mixed_params_file: The storage path of the converted mixed-precision params. mixed_params_file: The storage path of the converted mixed-precision params.
mixed_precision: The precision, e.g. PrecisionType.Half. mixed_precision: The precision, e.g. PrecisionType.Half.
backend: The backend, e.g. BackendType.GPU. backend: The backend, e.g. PlaceType.GPU.
keep_io_types: Whether the model input and output dtype remains unchanged. keep_io_types: Whether the model input and output dtype remains unchanged.
black_list: Operators that do not convert precision. black_list: Operators that do not convert precision.
''' '''
......
...@@ -20,7 +20,7 @@ from paddle.vision.models import resnet50 ...@@ -20,7 +20,7 @@ from paddle.vision.models import resnet50
from paddle.jit import to_static from paddle.jit import to_static
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.inference import PrecisionType, BackendType from paddle.inference import PrecisionType, PlaceType
from paddle.inference import convert_to_mixed_precision from paddle.inference import convert_to_mixed_precision
...@@ -38,7 +38,7 @@ class TestConvertToMixedPrecision(unittest.TestCase): ...@@ -38,7 +38,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
'resnet50/inference.pdiparams', 'resnet50/inference.pdiparams',
'mixed/inference.pdmodel', 'mixed/inference.pdmodel',
'mixed/inference.pdiparams', 'mixed/inference.pdiparams',
PrecisionType.Half, BackendType.GPU, True) PrecisionType.Half, PlaceType.GPU, True)
def test_convert_to_fp16_with_fp16_input(self): def test_convert_to_fp16_with_fp16_input(self):
model = resnet50(True) model = resnet50(True)
...@@ -49,7 +49,7 @@ class TestConvertToMixedPrecision(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
'resnet50/inference.pdiparams', 'resnet50/inference.pdiparams',
'mixed1/inference.pdmodel', 'mixed1/inference.pdmodel',
'mixed1/inference.pdiparams', 'mixed1/inference.pdiparams',
PrecisionType.Half, BackendType.GPU, False) PrecisionType.Half, PlaceType.GPU, False)
def test_convert_to_fp16_with_blacklist(self): def test_convert_to_fp16_with_blacklist(self):
model = resnet50(True) model = resnet50(True)
...@@ -60,7 +60,7 @@ class TestConvertToMixedPrecision(unittest.TestCase): ...@@ -60,7 +60,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
'resnet50/inference.pdiparams', 'resnet50/inference.pdiparams',
'mixed2/inference.pdmodel', 'mixed2/inference.pdmodel',
'mixed2/inference.pdiparams', 'mixed2/inference.pdiparams',
PrecisionType.Half, BackendType.GPU, False, PrecisionType.Half, PlaceType.GPU, False,
set('conv2d')) set('conv2d'))
def test_convert_to_bf16(self): def test_convert_to_bf16(self):
...@@ -72,8 +72,7 @@ class TestConvertToMixedPrecision(unittest.TestCase): ...@@ -72,8 +72,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
'resnet50/inference.pdiparams', 'resnet50/inference.pdiparams',
'mixed3/inference.pdmodel', 'mixed3/inference.pdmodel',
'mixed3/inference.pdiparams', 'mixed3/inference.pdiparams',
PrecisionType.Bfloat16, BackendType.GPU, PrecisionType.Bfloat16, PlaceType.GPU, True)
True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,7 +16,6 @@ from ..fluid.inference import Config # noqa: F401 ...@@ -16,7 +16,6 @@ from ..fluid.inference import Config # noqa: F401
from ..fluid.inference import DataType # noqa: F401 from ..fluid.inference import DataType # noqa: F401
from ..fluid.inference import PlaceType # noqa: F401 from ..fluid.inference import PlaceType # noqa: F401
from ..fluid.inference import PrecisionType # noqa: F401 from ..fluid.inference import PrecisionType # noqa: F401
from ..fluid.inference import BackendType # noqa: F401
from ..fluid.inference import Tensor # noqa: F401 from ..fluid.inference import Tensor # noqa: F401
from ..fluid.inference import Predictor # noqa: F401 from ..fluid.inference import Predictor # noqa: F401
from ..fluid.inference import create_predictor # noqa: F401 from ..fluid.inference import create_predictor # noqa: F401
...@@ -28,8 +27,8 @@ from ..fluid.inference import get_num_bytes_of_data_type # noqa: F401 ...@@ -28,8 +27,8 @@ from ..fluid.inference import get_num_bytes_of_data_type # noqa: F401
from ..fluid.inference import PredictorPool # noqa: F401 from ..fluid.inference import PredictorPool # noqa: F401
__all__ = [ # noqa __all__ = [ # noqa
'Config', 'DataType', 'PlaceType', 'PrecisionType', 'BackendType', 'Tensor', 'Config', 'DataType', 'PlaceType', 'PrecisionType', 'Tensor', 'Predictor',
'Predictor', 'create_predictor', 'get_version', 'get_trt_compile_version', 'create_predictor', 'get_version', 'get_trt_compile_version',
'convert_to_mixed_precision', 'get_trt_runtime_version', 'convert_to_mixed_precision', 'get_trt_runtime_version',
'get_num_bytes_of_data_type', 'PredictorPool' 'get_num_bytes_of_data_type', 'PredictorPool'
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册