未验证 提交 fd0051b4 编写于 作者: S ShenLiang 提交者: GitHub

add set default dtype, get default dtype (#26006)

* add set/get default dtype
上级 586a6dd3
......@@ -225,6 +225,8 @@ from .framework import ExponentialDecay #DEFINE_ALIAS
from .framework import InverseTimeDecay #DEFINE_ALIAS
from .framework import PolynomialDecay #DEFINE_ALIAS
from .framework import CosineDecay #DEFINE_ALIAS
from .framework import set_default_dtype #DEFINE_ALIAS
from .framework import get_default_dtype #DEFINE_ALIAS
from .tensor.search import index_sample #DEFINE_ALIAS
from .tensor.stat import mean #DEFINE_ALIAS
......
......@@ -283,7 +283,7 @@ class Layer(core.Layer):
def create_parameter(self,
shape,
attr=None,
dtype='float32',
dtype=None,
is_bias=False,
default_initializer=None):
"""Create parameters for this layer.
......
......@@ -23,8 +23,13 @@ from .param_attr import ParamAttr, WeightNormParamAttr
from . import core
from .initializer import _global_weight_initializer, _global_bias_initializer
__all__ = ['LayerHelperBase']
class LayerHelperBase(object):
# global dtype
__dtype = "float32"
def __init__(self, name, layer_type):
self._layer_type = layer_type
self._name = name
......@@ -45,6 +50,14 @@ class LayerHelperBase(object):
def startup_program(self):
return default_startup_program()
@classmethod
def set_default_dtype(cls, dtype):
cls.__dtype = dtype
@classmethod
def get_default_dtype(cls):
return cls.__dtype
def to_variable(self, value, name=None):
"""
The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
......@@ -277,7 +290,7 @@ class LayerHelperBase(object):
def create_parameter(self,
attr,
shape,
dtype,
dtype=None,
is_bias=False,
default_initializer=None,
stop_gradient=False,
......@@ -299,6 +312,9 @@ class LayerHelperBase(object):
if not attr:
return None
assert isinstance(attr, ParamAttr)
# set global dtype
if not dtype:
dtype = self.__dtype
if is_bias:
suffix = 'b'
default_initializer = _global_bias_initializer(
......@@ -372,6 +388,9 @@ class LayerHelperBase(object):
based on operator's `VarTypeInference` implementation in
infer_var_type.
"""
# set global dtype
if not dtype:
dtype = self.__dtype
return self.main_program.current_block().create_var(
name=unique_name.generate_with_ignorable_key(".".join(
[self.name, 'tmp'])),
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.framework import set_default_dtype, get_default_dtype
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import paddle.fluid.core as core
from paddle import to_variable
class TestDefaultType(unittest.TestCase):
def check_default(self):
self.assertEqual("float32", get_default_dtype())
def test_api(self):
self.check_default()
set_default_dtype("float64")
self.assertEqual("float64", get_default_dtype())
set_default_dtype(np.int32)
self.assertEqual("int32", get_default_dtype())
if __name__ == '__main__':
unittest.main()
......@@ -15,7 +15,8 @@
# TODO: import framework api under this directory
__all__ = [
'create_global_var', 'create_parameter', 'ParamAttr', 'Variable',
'CPUPlace', 'CUDAPlace', 'CUDAPinnedPlace'
'CPUPlace', 'CUDAPlace', 'CUDAPinnedPlace', 'get_default_dtype',
'set_default_dtype'
]
__all__ += [
......@@ -30,6 +31,8 @@ __all__ += [
from . import random
from .random import manual_seed
from .framework import get_default_dtype
from .framework import set_default_dtype
from ..fluid.framework import Variable #DEFINE_ALIAS
from ..fluid.framework import ComplexVariable #DEFINE_ALIAS
......
......@@ -13,5 +13,46 @@
# limitations under the License.
# TODO: define framework api
# __all__ = ['set_default_dtype',
# 'get_default_dtype']
from paddle.fluid.layer_helper_base import LayerHelperBase
from paddle.fluid.data_feeder import convert_dtype
__all__ = ['set_default_dtype', 'get_default_dtype']
def set_default_dtype(d):
"""
Set default dtype. The default dtype is initially float32
Args:
d(string|np.dtype): the dtype to make the default
Returns:
None.
Examples:
.. code-block:: python
import paddle
paddle.set_default_dtype("float32")
"""
d = convert_dtype(d)
LayerHelperBase.set_default_dtype(d)
def get_default_dtype():
"""
Get the current default dtype. The default dtype is initially float32
Args:
None.
Returns:
The default dtype.
Examples:
.. code-block:: python
import paddle
paddle.get_default_dtype()
"""
return LayerHelperBase.get_default_dtype()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册