diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index d79e8078ff415f22443a9c1f302604a004819e35..49ec3088831dff415e042e1b0a632f63106eb07b 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np -def convert_to_list(value, n, name): - """Converts a single integer or iterable of integers into an integer list. +def convert_to_list(value, n, name, dtype=np.int): + """ + Converts a single numerical type or iterable of numerical + types into an numerical type list. Arguments: value: The value to validate and convert. Could an int, or any iterable @@ -22,15 +25,16 @@ def convert_to_list(value, n, name): n: The size of the list to be returned. name: The name of the argument being validated, e.g. "stride" or "filter_size". This is only used to format error messages. + dtype: the numerical type of the element of the list to be returned. Returns: - A list of n integers. + A list of n dtypes. Raises: ValueError: If something else than an int/long or iterable thereof was passed. """ - if isinstance(value, int): + if isinstance(value, dtype): return [value, ] * n else: try: @@ -44,11 +48,12 @@ def convert_to_list(value, n, name): ". Received: " + str(value)) for single_value in value_list: try: - int(single_value) + dtype(single_value) except (ValueError, TypeError): raise ValueError( "The " + name + "'s type must be a list or tuple of " + str( - n) + " integers. Received: " + str(value) + " " + n) + " " + str(dtype) + " . Received: " + str( + value) + " " "including element " + str(single_value) + " of type" + " " + str(type(single_value))) return value_list