#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np


def convert_to_list(value, n, name, dtype=int):
    """
    Converts a single numerical value or an iterable of numerical values
    into a list of length ``n``.

    Arguments:
      value: The value to validate and convert. Could be a single instance
        of ``dtype`` (e.g. an int), or any iterable whose elements are
        convertible with ``dtype``.
      n: The size of the list to be returned.
      name: The name of the argument being validated, e.g. "stride" or
        "filter_size". This is only used to format error messages.
      dtype: The numerical type of the elements of the returned list.
        Defaults to the builtin ``int`` (the former default ``np.int`` was
        merely an alias of it and no longer exists in recent NumPy).

    Returns:
      A list of ``n`` elements. A ``value`` that is an instance of ``dtype``
      is repeated ``n`` times. For an iterable, its elements are returned
      as-is: each is only checked to be convertible with ``dtype``, not
      converted.

    Raises:
      ValueError: If ``value`` is neither an instance of ``dtype`` nor
        iterable, if the iterable does not contain exactly ``n`` elements, or
        if any element cannot be converted with ``dtype``.
    """
    if isinstance(value, dtype):
        return [value, ] * n

    try:
        value_list = list(value)
    except TypeError:
        raise ValueError("The " + name +
                         "'s type must be list or tuple. Received: " + str(
                             value))

    if len(value_list) != n:
        raise ValueError("The " + name + "'s length must be " + str(n) +
                         ". Received: " + str(value))

    for single_value in value_list:
        # Validate only: the original elements are returned unchanged.
        try:
            dtype(single_value)
        except (ValueError, TypeError):
            raise ValueError(
                "The " + name + "'s type must be a list or tuple of " +
                str(n) + " " + str(dtype) + " . Received: " + str(value) +
                " including element " + str(single_value) + " of type " +
                str(type(single_value)))
    return value_list