提交 30147d7f 编写于 作者: M minqiyang

Fix expand op incorrect infer shape

test=develop
上级 9be99b14
......@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
out_shape[i] = x_dims[i] * expand_times[i];
}
// set the first dim to -1 in compile time
if (!ctx->IsRuntime()) {
out_shape[0] = x_dims[0];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out");
......
......@@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
def test_expand_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
shape = [-1, 20]
expand_times = [3, 1]
# prepare input/output
x1 = block.var(six.b("x"))
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape)
out = block.var(six.b("out"))
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
sum_op_desc = block.append_op()
sum_op_desc.set_type("expand")
sum_op_desc.set_input("X", ["x"])
sum_op_desc.set_output("Out", ["out"])
sum_op_desc._set_attr('expand_times', expand_times)
sum_op_desc.check_attrs()
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册