未验证 提交 fd0051b4 编写于 作者: S ShenLiang 提交者: GitHub

add set default dtype, get default dtype (#26006)

* add set/get default dtype
上级 586a6dd3
...@@ -225,6 +225,8 @@ from .framework import ExponentialDecay #DEFINE_ALIAS ...@@ -225,6 +225,8 @@ from .framework import ExponentialDecay #DEFINE_ALIAS
from .framework import InverseTimeDecay #DEFINE_ALIAS from .framework import InverseTimeDecay #DEFINE_ALIAS
from .framework import PolynomialDecay #DEFINE_ALIAS from .framework import PolynomialDecay #DEFINE_ALIAS
from .framework import CosineDecay #DEFINE_ALIAS from .framework import CosineDecay #DEFINE_ALIAS
from .framework import set_default_dtype #DEFINE_ALIAS
from .framework import get_default_dtype #DEFINE_ALIAS
from .tensor.search import index_sample #DEFINE_ALIAS from .tensor.search import index_sample #DEFINE_ALIAS
from .tensor.stat import mean #DEFINE_ALIAS from .tensor.stat import mean #DEFINE_ALIAS
......
...@@ -283,7 +283,7 @@ class Layer(core.Layer): ...@@ -283,7 +283,7 @@ class Layer(core.Layer):
def create_parameter(self, def create_parameter(self,
shape, shape,
attr=None, attr=None,
dtype='float32', dtype=None,
is_bias=False, is_bias=False,
default_initializer=None): default_initializer=None):
"""Create parameters for this layer. """Create parameters for this layer.
......
...@@ -23,8 +23,13 @@ from .param_attr import ParamAttr, WeightNormParamAttr ...@@ -23,8 +23,13 @@ from .param_attr import ParamAttr, WeightNormParamAttr
from . import core from . import core
from .initializer import _global_weight_initializer, _global_bias_initializer from .initializer import _global_weight_initializer, _global_bias_initializer
__all__ = ['LayerHelperBase']
class LayerHelperBase(object): class LayerHelperBase(object):
# global dtype
__dtype = "float32"
def __init__(self, name, layer_type): def __init__(self, name, layer_type):
self._layer_type = layer_type self._layer_type = layer_type
self._name = name self._name = name
...@@ -45,6 +50,14 @@ class LayerHelperBase(object): ...@@ -45,6 +50,14 @@ class LayerHelperBase(object):
def startup_program(self): def startup_program(self):
return default_startup_program() return default_startup_program()
@classmethod
def set_default_dtype(cls, dtype):
cls.__dtype = dtype
@classmethod
def get_default_dtype(cls):
return cls.__dtype
def to_variable(self, value, name=None): def to_variable(self, value, name=None):
""" """
The API will create a ``Variable`` object from numpy\.ndarray or Variable object. The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
...@@ -277,7 +290,7 @@ class LayerHelperBase(object): ...@@ -277,7 +290,7 @@ class LayerHelperBase(object):
def create_parameter(self, def create_parameter(self,
attr, attr,
shape, shape,
dtype, dtype=None,
is_bias=False, is_bias=False,
default_initializer=None, default_initializer=None,
stop_gradient=False, stop_gradient=False,
...@@ -299,6 +312,9 @@ class LayerHelperBase(object): ...@@ -299,6 +312,9 @@ class LayerHelperBase(object):
if not attr: if not attr:
return None return None
assert isinstance(attr, ParamAttr) assert isinstance(attr, ParamAttr)
# set global dtype
if not dtype:
dtype = self.__dtype
if is_bias: if is_bias:
suffix = 'b' suffix = 'b'
default_initializer = _global_bias_initializer( default_initializer = _global_bias_initializer(
...@@ -372,6 +388,9 @@ class LayerHelperBase(object): ...@@ -372,6 +388,9 @@ class LayerHelperBase(object):
based on operator's `VarTypeInference` implementation in based on operator's `VarTypeInference` implementation in
infer_var_type. infer_var_type.
""" """
# set global dtype
if not dtype:
dtype = self.__dtype
return self.main_program.current_block().create_var( return self.main_program.current_block().create_var(
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(".".join(
[self.name, 'tmp'])), [self.name, 'tmp'])),
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.framework import set_default_dtype, get_default_dtype
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
import paddle.fluid.core as core
from paddle import to_variable
class TestDefaultType(unittest.TestCase):
def check_default(self):
self.assertEqual("float32", get_default_dtype())
def test_api(self):
self.check_default()
set_default_dtype("float64")
self.assertEqual("float64", get_default_dtype())
set_default_dtype(np.int32)
self.assertEqual("int32", get_default_dtype())
if __name__ == '__main__':
unittest.main()
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# TODO: import framework api under this directory # TODO: import framework api under this directory
__all__ = [ __all__ = [
'create_global_var', 'create_parameter', 'ParamAttr', 'Variable', 'create_global_var', 'create_parameter', 'ParamAttr', 'Variable',
'CPUPlace', 'CUDAPlace', 'CUDAPinnedPlace' 'CPUPlace', 'CUDAPlace', 'CUDAPinnedPlace', 'get_default_dtype',
'set_default_dtype'
] ]
__all__ += [ __all__ += [
...@@ -30,6 +31,8 @@ __all__ += [ ...@@ -30,6 +31,8 @@ __all__ += [
from . import random from . import random
from .random import manual_seed from .random import manual_seed
from .framework import get_default_dtype
from .framework import set_default_dtype
from ..fluid.framework import Variable #DEFINE_ALIAS from ..fluid.framework import Variable #DEFINE_ALIAS
from ..fluid.framework import ComplexVariable #DEFINE_ALIAS from ..fluid.framework import ComplexVariable #DEFINE_ALIAS
......
...@@ -13,5 +13,46 @@ ...@@ -13,5 +13,46 @@
# limitations under the License. # limitations under the License.
# TODO: define framework api # TODO: define framework api
# __all__ = ['set_default_dtype', from paddle.fluid.layer_helper_base import LayerHelperBase
# 'get_default_dtype'] from paddle.fluid.data_feeder import convert_dtype
__all__ = ['set_default_dtype', 'get_default_dtype']
def set_default_dtype(d):
"""
Set default dtype. The default dtype is initially float32
Args:
d(string|np.dtype): the dtype to make the default
Returns:
None.
Examples:
.. code-block:: python
import paddle
paddle.set_default_dtype("float32")
"""
d = convert_dtype(d)
LayerHelperBase.set_default_dtype(d)
def get_default_dtype():
"""
Get the current default dtype. The default dtype is initially float32
Args:
None.
Returns:
The default dtype.
Examples:
.. code-block:: python
import paddle
paddle.get_default_dtype()
"""
return LayerHelperBase.get_default_dtype()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册