未验证 提交 a863b32e 编写于 作者: 0 0x45f 提交者: GitHub

[Dy2St]Fix cond grad error when handle tensor array (#39689)

* fix cond grad error when handle tensor array

* add UT
上级 05982c10
......@@ -69,6 +69,12 @@ class InferVarTypeContext {
return op_->Inputs().at(name).size();
}
virtual size_t OutputSize(const std::string& name) const {
PADDLE_ENFORCE_NOT_NULL(
op_, platform::errors::PreconditionNotMet("op_ should not be null"));
return op_->Outputs().at(name).size();
}
virtual const std::string& InputVarName(const std::string& name,
const int index = 0) const {
PADDLE_ENFORCE_NOT_NULL(
......
......@@ -272,8 +272,18 @@ class ConditionalBlockGradInferVarType : public framework::VarTypeInference {
// Input is {Tensor, LoDTensorArray}, we need synchronous the Input's
// VarType into Input@GRAD to avoid generating {Tensor, Tensor} as
// Input@GRAD.
auto input_size = ctx->InputSize(ConditionalOp::kInputs);
auto output_size =
ctx->OutputSize(framework::GradVarName(ConditionalOp::kInputs));
PADDLE_ENFORCE_EQ(input_size, output_size,
platform::errors::InvalidArgument(
"input_size and output_size should be equal for "
"conditional_block_grad_op."));
for (size_t i = 0; i < output_size; ++i) {
ctx->SyncTypeAndDataType(ConditionalOp::kInputs,
framework::GradVarName(ConditionalOp::kInputs));
framework::GradVarName(ConditionalOp::kInputs),
i);
}
}
};
......
......@@ -320,10 +320,12 @@ class ListWithCondNet(paddle.nn.Layer):
if index > 0:
res = a[0] * a[0]
y = y + 1
else:
res = a[-1] * a[-1]
y = y - 1
z = a[-1] * res
z = a[-1] * res * y[0]
return z
......@@ -333,7 +335,7 @@ class TestListWithCondGradInferVarType(unittest.TestCase):
x = paddle.to_tensor([2, 3, 4], dtype='float32')
index = paddle.to_tensor([1])
res = net(x, index)
self.assertEqual(res[0], 16.)
self.assertEqual(res[0], 48.)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册