Unverified commit 82db4993 authored by Wilber, committed by GitHub

cherry-pick 46942 (#47015)

Parent 84333cf5
@@ -129,17 +129,19 @@ phi::DataType ConvertPrecision(AnalysisConfig::Precision precision) {
   }
 }
 
-phi::Backend ConvertBackend(AnalysisConfig::Backend backend) {
+phi::Backend ConvertBackend(paddle_infer::PlaceType backend) {
   switch (backend) {
-    case AnalysisConfig::Backend::kGPU:
+    case paddle_infer::PlaceType::kGPU:
       // NOTE: phi also support phi::Backend::GPUDNN.
       return phi::Backend::GPU;
-    case AnalysisConfig::Backend::kNPU:
+    case paddle_infer::PlaceType::kNPU:
       return phi::Backend::NPU;
-    case AnalysisConfig::Backend::kXPU:
+    case paddle_infer::PlaceType::kXPU:
       return phi::Backend::XPU;
-    case AnalysisConfig::Backend::kCPU:
+    case paddle_infer::PlaceType::kCPU:
       return phi::Backend::CPU;
+    case paddle_infer::PlaceType::kIPU:
+      return phi::Backend::IPU;
     default:
       PADDLE_THROW(paddle::platform::errors::InvalidArgument(
           "Paddle Inference not support backend, we now only support GPU, XPU, "
@@ -2320,7 +2322,7 @@ void ConvertToMixedPrecision(const std::string &model_file,
                              const std::string &mixed_model_file,
                              const std::string &mixed_params_file,
                              PrecisionType mixed_precision,
-                             BackendType backend,
+                             paddle_infer::PlaceType backend,
                              bool keep_io_types,
                              std::unordered_set<std::string> black_list) {
   auto phi_backend = paddle::ConvertBackend(backend);
@@ -170,13 +170,6 @@ struct PD_INFER_DECL AnalysisConfig {
     kBf16,  ///< bf16
   };
 
-  enum class Backend {
-    kCPU = 0,
-    kGPU,
-    kXPU,
-    kNPU,
-  };
-
   ///
   /// \brief Set the no-combined model dir path.
   ///
@@ -47,7 +47,6 @@ namespace paddle_infer {
 using PrecisionType = paddle::AnalysisConfig::Precision;
 using Config = paddle::AnalysisConfig;
 using DistConfig = paddle::DistConfig;
-using BackendType = paddle::AnalysisConfig::Backend;
 
 ///
 /// \class Predictor
@@ -198,7 +197,7 @@ PD_INFER_DECL void ConvertToMixedPrecision(
     const std::string& mixed_model_file,
     const std::string& mixed_params_file,
     PrecisionType mixed_precision,
-    BackendType backend,
+    PlaceType backend,
     bool keep_io_types = true,
     std::unordered_set<std::string> black_list = {});
@@ -616,13 +616,6 @@ void BindAnalysisConfig(py::module *m) {
       .value("Bfloat16", AnalysisConfig::Precision::kBf16)
       .export_values();
 
-  py::enum_<AnalysisConfig::Backend>(analysis_config, "Backend")
-      .value("CPU", AnalysisConfig::Backend::kCPU)
-      .value("GPU", AnalysisConfig::Backend::kGPU)
-      .value("NPU", AnalysisConfig::Backend::kNPU)
-      .value("XPU", AnalysisConfig::Backend::kXPU)
-      .export_values();
-
   analysis_config.def(py::init<>())
       .def(py::init<const AnalysisConfig &>())
       .def(py::init<const std::string &>())
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .wrapper import Config, DataType, PlaceType, PrecisionType, BackendType, Tensor, Predictor
+from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor
 from .wrapper import convert_to_mixed_precision
 
 from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
@@ -24,7 +24,6 @@ from typing import Set
 DataType = PaddleDType
 PlaceType = PaddlePlace
 PrecisionType = AnalysisConfig.Precision
-BackendType = AnalysisConfig.Backend
 Config = AnalysisConfig
 Tensor = PaddleInferTensor
 Predictor = PaddleInferPredictor
@@ -59,7 +58,7 @@ def convert_to_mixed_precision(model_file: str,
                                mixed_model_file: str,
                                mixed_params_file: str,
                                mixed_precision: PrecisionType,
-                               backend: BackendType,
+                               backend: PlaceType,
                                keep_io_types: bool = True,
                                black_list: Set = set()):
     '''
@@ -71,7 +70,7 @@ def convert_to_mixed_precision(model_file: str,
         mixed_model_file: The storage path of the converted mixed-precision model.
         mixed_params_file: The storage path of the converted mixed-precision params.
         mixed_precision: The precision, e.g. PrecisionType.Half.
-        backend: The backend, e.g. BackendType.GPU.
+        backend: The backend, e.g. PlaceType.GPU.
         keep_io_types: Whether the model input and output dtype remains unchanged.
         black_list: Operators that do not convert precision.
     '''
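
For reference, a minimal sketch of the renamed API in use after this change. The file paths are placeholders for any exported inference model, and the positional argument order follows the updated tests below:

    from paddle.inference import PlaceType, PrecisionType
    from paddle.inference import convert_to_mixed_precision

    # Convert an exported FP32 inference model to FP16 on GPU.
    convert_to_mixed_precision(
        'resnet50/inference.pdmodel',    # source model file
        'resnet50/inference.pdiparams',  # source params file
        'mixed/inference.pdmodel',       # output mixed-precision model
        'mixed/inference.pdiparams',     # output mixed-precision params
        PrecisionType.Half,              # target precision
        PlaceType.GPU,                   # backend (formerly BackendType.GPU)
        keep_io_types=True)              # keep FP32 model inputs/outputs
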
@@ -20,7 +20,7 @@ from paddle.vision.models import resnet50
 from paddle.jit import to_static
 from paddle.static import InputSpec
 
-from paddle.inference import PrecisionType, BackendType
+from paddle.inference import PrecisionType, PlaceType
 from paddle.inference import convert_to_mixed_precision
@@ -38,7 +38,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
                                    'resnet50/inference.pdiparams',
                                    'mixed/inference.pdmodel',
                                    'mixed/inference.pdiparams',
-                                   PrecisionType.Half, BackendType.GPU, True)
+                                   PrecisionType.Half, PlaceType.GPU, True)
 
     def test_convert_to_fp16_with_fp16_input(self):
         model = resnet50(True)
@@ -49,7 +49,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
                                    'resnet50/inference.pdiparams',
                                    'mixed1/inference.pdmodel',
                                    'mixed1/inference.pdiparams',
-                                   PrecisionType.Half, BackendType.GPU, False)
+                                   PrecisionType.Half, PlaceType.GPU, False)
 
     def test_convert_to_fp16_with_blacklist(self):
         model = resnet50(True)
@@ -60,7 +60,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
                                    'resnet50/inference.pdiparams',
                                    'mixed2/inference.pdmodel',
                                    'mixed2/inference.pdiparams',
-                                   PrecisionType.Half, BackendType.GPU, False,
+                                   PrecisionType.Half, PlaceType.GPU, False,
                                    set('conv2d'))
 
     def test_convert_to_bf16(self):
@@ -72,8 +72,7 @@ class TestConvertToMixedPrecision(unittest.TestCase):
                                    'resnet50/inference.pdiparams',
                                    'mixed3/inference.pdmodel',
                                    'mixed3/inference.pdiparams',
-                                   PrecisionType.Bfloat16, BackendType.GPU,
-                                   True)
+                                   PrecisionType.Bfloat16, PlaceType.GPU, True)
 
 
 if __name__ == '__main__':
@@ -16,7 +16,6 @@ from ..fluid.inference import Config  # noqa: F401
 from ..fluid.inference import DataType  # noqa: F401
 from ..fluid.inference import PlaceType  # noqa: F401
 from ..fluid.inference import PrecisionType  # noqa: F401
-from ..fluid.inference import BackendType  # noqa: F401
 from ..fluid.inference import Tensor  # noqa: F401
 from ..fluid.inference import Predictor  # noqa: F401
 from ..fluid.inference import create_predictor  # noqa: F401
@@ -28,8 +27,8 @@ from ..fluid.inference import get_num_bytes_of_data_type  # noqa: F401
 from ..fluid.inference import PredictorPool  # noqa: F401
 
 __all__ = [  # noqa
-    'Config', 'DataType', 'PlaceType', 'PrecisionType', 'BackendType', 'Tensor',
-    'Predictor', 'create_predictor', 'get_version', 'get_trt_compile_version',
+    'Config', 'DataType', 'PlaceType', 'PrecisionType', 'Tensor', 'Predictor',
+    'create_predictor', 'get_version', 'get_trt_compile_version',
     'convert_to_mixed_precision', 'get_trt_runtime_version',
     'get_num_bytes_of_data_type', 'PredictorPool'
 ]