Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ff7e3590
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff7e3590
编写于
2月 16, 2022
作者:
A
Aurelius84
提交者:
GitHub
2月 16, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ConditionalBlockGradInferVarType (#39585)
上级
d5a0d31a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
44 addition
and
1 deletion
+44
-1
paddle/fluid/operators/controlflow/conditional_block_op.cc
paddle/fluid/operators/controlflow/conditional_block_op.cc
+14
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_list.py
+30
-0
未找到文件。
paddle/fluid/operators/controlflow/conditional_block_op.cc
浏览文件 @
ff7e3590
...
...
@@ -265,6 +265,18 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
}
};
class
ConditionalBlockGradInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
// NOTE(Aurelius84): VarType of Output is LoDTensor by default. In case of
// Input is {Tensor, LoDTensorArray}, we need synchronous the Input's
// VarType into Input@GRAD to avoid generating {Tensor, Tensor} as
// Input@GRAD.
ctx
->
SyncTypeAndDataType
(
ConditionalOp
::
kInputs
,
framework
::
GradVarName
(
ConditionalOp
::
kInputs
));
}
};
template
<
typename
T
>
class
ConditionalBlockGradMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
...
...
@@ -300,4 +312,5 @@ REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
ops
::
ConditionalBlockOpProtoMaker
,
ops
::
ConditionalBlockGradMaker
<
paddle
::
framework
::
OpDesc
>
);
REGISTER_OPERATOR
(
conditional_block_grad
,
ops
::
ConditionalBlockGradOp
,
ops
::
ConditionalBlockGradInferShape
);
ops
::
ConditionalBlockGradInferShape
,
ops
::
ConditionalBlockGradInferVarType
);
python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
浏览文件 @
ff7e3590
...
...
@@ -306,5 +306,35 @@ class TestListInForLoopWithSubscript(TestListWithoutControlFlow):
self
.
input
=
np
.
random
.
random
((
3
,
4
)).
astype
(
'float32'
)
class
ListWithCondNet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
ListWithCondNet
,
self
).
__init__
()
@
paddle
.
jit
.
to_static
def
forward
(
self
,
x
,
index
):
y
=
paddle
.
nn
.
functional
.
relu
(
x
)
a
=
[]
for
i
in
y
:
a
.
append
(
i
)
if
index
>
0
:
res
=
a
[
0
]
*
a
[
0
]
else
:
res
=
a
[
-
1
]
*
a
[
-
1
]
z
=
a
[
-
1
]
*
res
return
z
class
TestListWithCondGradInferVarType
(
unittest
.
TestCase
):
def
test_to_static
(
self
):
net
=
ListWithCondNet
()
x
=
paddle
.
to_tensor
([
2
,
3
,
4
],
dtype
=
'float32'
)
index
=
paddle
.
to_tensor
([
1
])
res
=
net
(
x
,
index
)
self
.
assertEqual
(
res
[
0
],
16.
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录