wrapper.py 3.2 KB
Newer Older
W
Wilber 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import os
from typing import Set
17

18 19 20 21 22 23 24 25 26 27 28 29
import numpy as np

import paddle.fluid.core as core
from paddle.fluid.core import (
    AnalysisConfig,
    PaddleDType,
    PaddleInferPredictor,
    PaddleInferTensor,
    PaddlePlace,
    convert_to_mixed_precision_bind,
)

W
Wilber 已提交
30 31 32 33 34 35
# Short aliases for the inference types imported from paddle.fluid.core above,
# so the rest of this module can refer to them without the Paddle* prefix.
DataType = PaddleDType
PlaceType = PaddlePlace
# Precision is an attribute of AnalysisConfig, so it is reached via the class.
PrecisionType = AnalysisConfig.Precision
Config = AnalysisConfig
Tensor = PaddleInferTensor
Predictor = PaddleInferPredictor
36 37 38 39 40 41


def tensor_copy_from_cpu(self, data):
    '''
    Support input type check based on tensor.copy_from_cpu.

    Validates the type of ``data`` and then forwards it unchanged to
    ``self.copy_from_cpu_bind``.

    Args:
        data: Either a numpy ndarray, or a non-empty list whose first
            element is a str. Only the first list element is inspected.

    Raises:
        TypeError: If ``data`` is neither of the accepted kinds. An empty
            list is rejected as well, because its element type cannot be
            determined.
    '''
    is_ndarray = isinstance(data, np.ndarray)
    # Only data[0] is checked; the remaining elements are left for the
    # underlying binding to validate.
    is_str_list = (
        isinstance(data, list) and len(data) > 0 and isinstance(data[0], str)
    )
    if not (is_ndarray or is_str_list):
        raise TypeError(
            "In copy_from_cpu, we only support numpy ndarray and list[str] data type."
        )
    self.copy_from_cpu_bind(data)
50 51


52 53 54 55 56 57 58 59
def tensor_share_external_data(self, data):
    '''
    Support input type check based on tensor.share_external_data.

    Validates the type of ``data`` and then forwards it unchanged to
    ``self.share_external_data_bind``.

    Args:
        data: A ``core.LoDTensor`` instance.

    Raises:
        TypeError: If ``data`` is not a ``core.LoDTensor``.
    '''
    if not isinstance(data, core.LoDTensor):
        raise TypeError(
            "In share_external_data, we only support LoDTensor data type."
        )
    self.share_external_data_bind(data)
62 63


64 65 66 67 68 69 70 71 72 73
def convert_to_mixed_precision(
    model_file: str,
    params_file: str,
    mixed_model_file: str,
    mixed_params_file: str,
    mixed_precision: PrecisionType,
    backend: PlaceType,
    keep_io_types: bool = True,
    black_list: Set = set(),
):
    '''
    Convert a fp32 model to mixed precision model.

    The parent directories of both output files are created when they are
    missing, then the conversion itself is delegated to
    ``convert_to_mixed_precision_bind``.

    Args:
        model_file: fp32 model file, e.g. inference.pdmodel.
        params_file: fp32 params file, e.g. inference.pdiparams.
        mixed_model_file: The storage path of the converted mixed-precision model.
        mixed_params_file: The storage path of the converted mixed-precision params.
        mixed_precision: The precision, e.g. PrecisionType.Half.
        backend: The backend, e.g. PlaceType.GPU.
        keep_io_types: Whether the model input and output dtype remains unchanged.
        black_list: Operators that do not convert precision.

    Raises:
        OSError: If an output directory is missing and cannot be created.
    '''
    # NOTE(review): the default ``black_list`` set is shared across calls.
    # This function only forwards it; confirm the binding never mutates it.
    for output_path in (mixed_model_file, mixed_params_file):
        parent_dir = os.path.dirname(output_path)
        # dirname() is '' for a bare file name (output in the current
        # directory); os.makedirs('') would raise, so skip that case.
        if parent_dir:
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(parent_dir, exist_ok=True)
    convert_to_mixed_precision_bind(
        model_file,
        params_file,
        mixed_model_file,
        mixed_params_file,
        mixed_precision,
        backend,
        keep_io_types,
        black_list,
    )
103 104


105
# Attach the type-checking wrappers defined in this module to Tensor, so that
# Tensor.copy_from_cpu / Tensor.share_external_data validate their argument
# before calling the corresponding *_bind method.
Tensor.copy_from_cpu = tensor_copy_from_cpu
Tensor.share_external_data = tensor_share_external_data