Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
87b424e8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
87b424e8
编写于
1月 22, 2018
作者:
Y
Yang Yu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Follow comments
上级
9f731a60
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
9 addition
and
18 deletion
+9
-18
python/paddle/v2/fluid/layers/math_op_patch.py
python/paddle/v2/fluid/layers/math_op_patch.py
+9
-18
未找到文件。
python/paddle/v2/fluid/layers/math_op_patch.py
浏览文件 @
87b424e8
...
@@ -19,7 +19,7 @@ __all__ = ['monkey_patch_variable']
...
@@ -19,7 +19,7 @@ __all__ = ['monkey_patch_variable']
def
monkey_patch_variable
():
def
monkey_patch_variable
():
def
new
_name
():
def
unique_tmp
_name
():
return
unique_name
(
"tmp"
)
return
unique_name
(
"tmp"
)
def
safe_get_dtype
(
var
):
def
safe_get_dtype
(
var
):
...
@@ -29,21 +29,9 @@ def monkey_patch_variable():
...
@@ -29,21 +29,9 @@ def monkey_patch_variable():
raise
ValueError
(
"Cannot get data type from %s"
,
var
.
name
)
raise
ValueError
(
"Cannot get data type from %s"
,
var
.
name
)
return
dtype
return
dtype
def
create_scalar
(
block
,
value
,
dtype
):
value
=
float
(
value
)
tmp_name
=
new_name
()
var
=
block
.
create_var
(
name
=
tmp_name
,
shape
=
[
1
],
dtype
=
dtype
)
block
.
append_op
(
type
=
"fill"
,
outputs
=
{
"Out"
:
[
var
]},
attrs
=
{
"value"
:
[
value
],
"shape"
:
[
1
],
"dtype"
:
dtype
})
return
var
def
create_tensor
(
block
,
value
,
dtype
,
shape
):
def
create_tensor
(
block
,
value
,
dtype
,
shape
):
value
=
float
(
value
)
value
=
float
(
value
)
tmp_name
=
new
_name
()
tmp_name
=
unique_tmp
_name
()
var
=
block
.
create_var
(
name
=
tmp_name
,
shape
=
shape
,
dtype
=
dtype
)
var
=
block
.
create_var
(
name
=
tmp_name
,
shape
=
shape
,
dtype
=
dtype
)
block
.
append_op
(
block
.
append_op
(
type
=
"fill_constant"
,
type
=
"fill_constant"
,
...
@@ -53,10 +41,13 @@ def monkey_patch_variable():
...
@@ -53,10 +41,13 @@ def monkey_patch_variable():
'value'
:
value
})
'value'
:
value
})
return
var
return
var
def
create_scalar
(
block
,
value
,
dtype
):
return
create_tensor
(
block
,
value
,
dtype
,
shape
=
[
1
])
def
create_tensor_with_batchsize
(
ref_var
,
value
,
dtype
):
def
create_tensor_with_batchsize
(
ref_var
,
value
,
dtype
):
assert
isinstance
(
ref_var
,
Variable
)
assert
isinstance
(
ref_var
,
Variable
)
value
=
float
(
value
)
value
=
float
(
value
)
tmp_name
=
new
_name
()
tmp_name
=
unique_tmp
_name
()
var
=
ref_var
.
block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
var
=
ref_var
.
block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
ref_var
.
block
.
append_op
(
ref_var
.
block
.
append_op
(
type
=
'fill_constant_batch_size_like'
,
type
=
'fill_constant_batch_size_like'
,
...
@@ -68,7 +59,7 @@ def monkey_patch_variable():
...
@@ -68,7 +59,7 @@ def monkey_patch_variable():
def
astype
(
self
,
dtype
):
def
astype
(
self
,
dtype
):
"""
"""
Cast a variable to data type.
Cast a variable to
a specified
data type.
NOTE: The variable must be a Tensor
NOTE: The variable must be a Tensor
Args:
Args:
self(Variable): The source variable
self(Variable): The source variable
...
@@ -77,7 +68,7 @@ def monkey_patch_variable():
...
@@ -77,7 +68,7 @@ def monkey_patch_variable():
Returns:
Returns:
Variable with new dtype
Variable with new dtype
"""
"""
tmp_name
=
new
_name
()
tmp_name
=
unique_tmp
_name
()
out
=
self
.
block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
out
=
self
.
block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
self
.
block
.
append_op
(
self
.
block
.
append_op
(
type
=
"cast"
,
type
=
"cast"
,
...
@@ -120,7 +111,7 @@ def monkey_patch_variable():
...
@@ -120,7 +111,7 @@ def monkey_patch_variable():
self
=
other_var
self
=
other_var
other_var
=
tmp
other_var
=
tmp
tmp_name
=
new
_name
()
tmp_name
=
unique_tmp
_name
()
out
=
self
.
block
.
create_var
(
name
=
tmp_name
,
dtype
=
lhs_dtype
)
out
=
self
.
block
.
create_var
(
name
=
tmp_name
,
dtype
=
lhs_dtype
)
self
.
block
.
append_op
(
self
.
block
.
append_op
(
type
=
op_type
,
type
=
op_type
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录