未验证 提交 a373aa76 编写于 作者: L lilong12 提交者: GitHub

fix the bug in expand_v2 op (#30984)

* update, test=develop
上级 c4f279fe
......@@ -66,6 +66,9 @@ class ExpandV2Op : public framework::OperatorWithKernel {
out_shape[i] = -1;
} else if (expand_shape[i] == -1) {
out_shape[i] = x_dims[i];
} else if (expand_shape[i] == -2) {
// We use -2 to represent the element in expand_shape is a var.
out_shape[i] = -1;
} else {
PADDLE_ENFORCE_GT(
expand_shape[i], 0,
......@@ -174,7 +177,7 @@ class ExpandV2GradOp : public framework::OperatorWithKernel {
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (expand_shape[i] == -1 || x_dim_vec[i] == -1) {
if (expand_shape[i] < 0 || x_dim_vec[i] == -1) {
continue;
} else {
if (ctx->IsRuntime()) {
......
......@@ -1448,7 +1448,7 @@ def expand(x, shape, name=None):
attrs_expand_shape = []
for idx, shape in enumerate(list_expand_shape):
if isinstance(shape, Variable):
attrs_expand_shape.append(-1)
attrs_expand_shape.append(-2)
else:
attrs_expand_shape.append(shape)
assert shape > 0 or shape == -1, (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册