Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fd0051b4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fd0051b4
编写于
8月 18, 2020
作者:
S
ShenLiang
提交者:
GitHub
8月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add set default dtype, get default dtype (#26006)
* add set/get default dtype
上级
586a6dd3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
111 addition
and
5 deletion
+111
-5
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/dygraph/layers.py
python/paddle/fluid/dygraph/layers.py
+1
-1
python/paddle/fluid/layer_helper_base.py
python/paddle/fluid/layer_helper_base.py
+20
-1
python/paddle/fluid/tests/unittests/test_default_dtype.py
python/paddle/fluid/tests/unittests/test_default_dtype.py
+41
-0
python/paddle/framework/__init__.py
python/paddle/framework/__init__.py
+4
-1
python/paddle/framework/framework.py
python/paddle/framework/framework.py
+43
-2
未找到文件。
python/paddle/__init__.py
浏览文件 @
fd0051b4
...
...
@@ -225,6 +225,8 @@ from .framework import ExponentialDecay #DEFINE_ALIAS
from
.framework
import
InverseTimeDecay
#DEFINE_ALIAS
from
.framework
import
PolynomialDecay
#DEFINE_ALIAS
from
.framework
import
CosineDecay
#DEFINE_ALIAS
from
.framework
import
set_default_dtype
#DEFINE_ALIAS
from
.framework
import
get_default_dtype
#DEFINE_ALIAS
from
.tensor.search
import
index_sample
#DEFINE_ALIAS
from
.tensor.stat
import
mean
#DEFINE_ALIAS
...
...
python/paddle/fluid/dygraph/layers.py
浏览文件 @
fd0051b4
...
...
@@ -283,7 +283,7 @@ class Layer(core.Layer):
def
create_parameter
(
self
,
shape
,
attr
=
None
,
dtype
=
'float32'
,
dtype
=
None
,
is_bias
=
False
,
default_initializer
=
None
):
"""Create parameters for this layer.
...
...
python/paddle/fluid/layer_helper_base.py
浏览文件 @
fd0051b4
...
...
@@ -23,8 +23,13 @@ from .param_attr import ParamAttr, WeightNormParamAttr
from
.
import
core
from
.initializer
import
_global_weight_initializer
,
_global_bias_initializer
__all__
=
[
'LayerHelperBase'
]
class
LayerHelperBase
(
object
):
# global dtype
__dtype
=
"float32"
def
__init__
(
self
,
name
,
layer_type
):
self
.
_layer_type
=
layer_type
self
.
_name
=
name
...
...
@@ -45,6 +50,14 @@ class LayerHelperBase(object):
def
startup_program
(
self
):
return
default_startup_program
()
@
classmethod
def
set_default_dtype
(
cls
,
dtype
):
cls
.
__dtype
=
dtype
@
classmethod
def
get_default_dtype
(
cls
):
return
cls
.
__dtype
def
to_variable
(
self
,
value
,
name
=
None
):
"""
The API will create a ``Variable`` object from numpy\.ndarray or Variable object.
...
...
@@ -277,7 +290,7 @@ class LayerHelperBase(object):
def
create_parameter
(
self
,
attr
,
shape
,
dtype
,
dtype
=
None
,
is_bias
=
False
,
default_initializer
=
None
,
stop_gradient
=
False
,
...
...
@@ -299,6 +312,9 @@ class LayerHelperBase(object):
if
not
attr
:
return
None
assert
isinstance
(
attr
,
ParamAttr
)
# set global dtype
if
not
dtype
:
dtype
=
self
.
__dtype
if
is_bias
:
suffix
=
'b'
default_initializer
=
_global_bias_initializer
(
...
...
@@ -372,6 +388,9 @@ class LayerHelperBase(object):
based on operator's `VarTypeInference` implementation in
infer_var_type.
"""
# set global dtype
if
not
dtype
:
dtype
=
self
.
__dtype
return
self
.
main_program
.
current_block
().
create_var
(
name
=
unique_name
.
generate_with_ignorable_key
(
"."
.
join
(
[
self
.
name
,
'tmp'
])),
...
...
python/paddle/fluid/tests/unittests/test_default_dtype.py
0 → 100644
浏览文件 @
fd0051b4
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
paddle.framework
import
set_default_dtype
,
get_default_dtype
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Linear
import
paddle.fluid.core
as
core
from
paddle
import
to_variable
class
TestDefaultType
(
unittest
.
TestCase
):
def
check_default
(
self
):
self
.
assertEqual
(
"float32"
,
get_default_dtype
())
def
test_api
(
self
):
self
.
check_default
()
set_default_dtype
(
"float64"
)
self
.
assertEqual
(
"float64"
,
get_default_dtype
())
set_default_dtype
(
np
.
int32
)
self
.
assertEqual
(
"int32"
,
get_default_dtype
())
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/framework/__init__.py
浏览文件 @
fd0051b4
...
...
@@ -15,7 +15,8 @@
# TODO: import framework api under this directory
__all__
=
[
'create_global_var'
,
'create_parameter'
,
'ParamAttr'
,
'Variable'
,
'CPUPlace'
,
'CUDAPlace'
,
'CUDAPinnedPlace'
'CPUPlace'
,
'CUDAPlace'
,
'CUDAPinnedPlace'
,
'get_default_dtype'
,
'set_default_dtype'
]
__all__
+=
[
...
...
@@ -30,6 +31,8 @@ __all__ += [
from
.
import
random
from
.random
import
manual_seed
from
.framework
import
get_default_dtype
from
.framework
import
set_default_dtype
from
..fluid.framework
import
Variable
#DEFINE_ALIAS
from
..fluid.framework
import
ComplexVariable
#DEFINE_ALIAS
...
...
python/paddle/framework/framework.py
浏览文件 @
fd0051b4
...
...
@@ -13,5 +13,46 @@
# limitations under the License.
# TODO: define framework api
# __all__ = ['set_default_dtype',
# 'get_default_dtype']
from
paddle.fluid.layer_helper_base
import
LayerHelperBase
from
paddle.fluid.data_feeder
import
convert_dtype
__all__
=
[
'set_default_dtype'
,
'get_default_dtype'
]
def
set_default_dtype
(
d
):
"""
Set default dtype. The default dtype is initially float32
Args:
d(string|np.dtype): the dtype to make the default
Returns:
None.
Examples:
.. code-block:: python
import paddle
paddle.set_default_dtype("float32")
"""
d
=
convert_dtype
(
d
)
LayerHelperBase
.
set_default_dtype
(
d
)
def
get_default_dtype
():
"""
Get the current default dtype. The default dtype is initially float32
Args:
None.
Returns:
The default dtype.
Examples:
.. code-block:: python
import paddle
paddle.get_default_dtype()
"""
return
LayerHelperBase
.
get_default_dtype
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录