未验证 提交 ff7e3590 编写于 作者: A Aurelius84 提交者: GitHub

Add ConditionalBlockGradInferVarType (#39585)

上级 d5a0d31a
...@@ -265,6 +265,18 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase { ...@@ -265,6 +265,18 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
} }
}; };
class ConditionalBlockGradInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
// NOTE(Aurelius84): VarType of Output is LoDTensor by default. In case of
// Input is {Tensor, LoDTensorArray}, we need synchronous the Input's
// VarType into Input@GRAD to avoid generating {Tensor, Tensor} as
// Input@GRAD.
ctx->SyncTypeAndDataType(ConditionalOp::kInputs,
framework::GradVarName(ConditionalOp::kInputs));
}
};
template <typename T> template <typename T>
class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> { class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
...@@ -300,4 +312,5 @@ REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp, ...@@ -300,4 +312,5 @@ REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
ops::ConditionalBlockOpProtoMaker, ops::ConditionalBlockOpProtoMaker,
ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>); ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp, REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
ops::ConditionalBlockGradInferShape); ops::ConditionalBlockGradInferShape,
ops::ConditionalBlockGradInferVarType);
...@@ -306,5 +306,35 @@ class TestListInForLoopWithSubscript(TestListWithoutControlFlow): ...@@ -306,5 +306,35 @@ class TestListInForLoopWithSubscript(TestListWithoutControlFlow):
self.input = np.random.random((3, 4)).astype('float32') self.input = np.random.random((3, 4)).astype('float32')
class ListWithCondNet(paddle.nn.Layer):
def __init__(self):
super(ListWithCondNet, self).__init__()
@paddle.jit.to_static
def forward(self, x, index):
y = paddle.nn.functional.relu(x)
a = []
for i in y:
a.append(i)
if index > 0:
res = a[0] * a[0]
else:
res = a[-1] * a[-1]
z = a[-1] * res
return z
class TestListWithCondGradInferVarType(unittest.TestCase):
def test_to_static(self):
net = ListWithCondNet()
x = paddle.to_tensor([2, 3, 4], dtype='float32')
index = paddle.to_tensor([1])
res = net(x, index)
self.assertEqual(res[0], 16.)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册