Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
686f0ecb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
686f0ecb
编写于
12月 11, 2019
作者:
M
mapingshuo
提交者:
GitHub
12月 11, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add `no_need_buffer_slots` interface to pybind (#21575)
* add no_need_buffer_slots interface to pybind
上级
6828f368
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
86 addition
and
1 deletion
+86
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+14
-1
python/paddle/fluid/tests/unittests/test_infer_no_need_buffer_slots.py
.../fluid/tests/unittests/test_infer_no_need_buffer_slots.py
+72
-0
未找到文件。
paddle/fluid/pybind/pybind.cc
浏览文件 @
686f0ecb
...
...
@@ -1103,7 +1103,20 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"has_infer_inplace"
,
[](
const
std
::
string
op_type
)
{
return
framework
::
OpInfoMap
::
Instance
().
Get
(
op_type
).
HasInferInplace
();
});
m
.
def
(
"infer_no_need_buffer_slots"
,
[](
const
std
::
string
op_type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
{
auto
infer_func
=
framework
::
OpInfoMap
::
Instance
()
.
Get
(
op_type
)
.
NoNeedBufferVarsInferer
();
if
(
infer_func
)
{
return
infer_func
(
inputs
,
outputs
,
attrs
);
}
else
{
std
::
unordered_set
<
std
::
string
>
empty
=
{};
return
empty
;
}
});
m
.
def
(
"prune"
,
[](
const
ProgramDesc
&
origin
,
const
std
::
set
<
std
::
string
>
&
feeded_var_names
,
const
std
::
vector
<
std
::
array
<
size_t
,
2
>>
&
targets
)
{
...
...
python/paddle/fluid/tests/unittests/test_infer_no_need_buffer_slots.py
0 → 100644
浏览文件 @
686f0ecb
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.framework
as
framework
import
paddle.compat
as
cpt
import
paddle.fluid.core
as
core
class
TestInferNoNeedBufferSlots
(
unittest
.
TestCase
):
def
net
(
self
):
x1
=
fluid
.
default_main_program
().
global_block
().
create_var
(
dtype
=
"float32"
,
shape
=
[
1
],
lod_level
=
0
,
name
=
"x1"
)
x2
=
fluid
.
default_main_program
().
global_block
().
create_var
(
dtype
=
"float32"
,
shape
=
[
1
],
lod_level
=
0
,
name
=
"x2"
)
x
=
fluid
.
layers
.
elementwise_add
(
x1
,
x2
)
return
x
def
test_infer_no_need_buffer_slots
(
self
):
program
=
framework
.
Program
()
startup_program
=
framework
.
Program
()
with
fluid
.
program_guard
(
program
,
startup_program
):
loss
=
self
.
net
()
sgd
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.01
)
sgd
.
minimize
(
loss
)
block
=
program
.
global_block
()
for
idx
,
op
in
enumerate
(
block
.
ops
):
op_desc
=
op
.
desc
inputs
=
{}
for
input_name
in
op_desc
.
input_names
():
inputs
[
input_name
]
=
op_desc
.
input
(
input_name
)
outputs
=
{}
for
output_name
in
op_desc
.
output_names
():
outputs
[
output_name
]
=
op_desc
.
output
(
output_name
)
attrs
=
{}
for
attr_name
in
op_desc
.
attr_names
():
attrs
[
attr_name
]
=
op_desc
.
attr
(
attr_name
)
if
idx
==
0
:
# elementwise_add op
self
.
assertEqual
(
core
.
infer_no_need_buffer_slots
(
op
.
type
,
inputs
,
outputs
,
attrs
),
set
([]))
elif
idx
==
1
:
# fill constant op
self
.
assertEqual
(
core
.
infer_no_need_buffer_slots
(
op
.
type
,
inputs
,
outputs
,
attrs
),
set
([]))
else
:
# elementwise_add_grad op
self
.
assertEqual
(
core
.
infer_no_need_buffer_slots
(
op
.
type
,
inputs
,
outputs
,
attrs
),
set
([
'Y'
,
'X'
]))
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录