未验证 提交 d971d5b8 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #14431 from velconia/fix_expand_op_dim_in_compile_time

Fix expand op incorrect infer shape
......@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
out_shape[i] = x_dims[i] * expand_times[i];
}
// set the first dim to -1 in compile time
if (!ctx->IsRuntime()) {
out_shape[0] = x_dims[0];
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
if (out_shape[0] == x_dims[0]) {
ctx->ShareLoD("X", "Out");
......@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel {
ctx->Attrs().Get<std::vector<int>>("expand_times");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
for (size_t i = 0; i < expand_times.size(); ++i) {
size_t start_pos = 0u;
if (!ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
x_dims[0], out_dims[0],
"The first dimension size of Input(Out@GRAD) should be "
"equal to the crroresponding dimension size of Input(X)");
start_pos = 1u;
}
for (size_t i = start_pos; i < expand_times.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
"Each dimension size of Input(Out@GRAD) should be "
"equal to multiplication of crroresponding dimension "
......
......@@ -83,6 +83,34 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
def test_expand_op(self):
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
shape = [-1, 20]
expand_times = [3, 1]
# prepare input/output
x1 = block.var(six.b("x"))
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
x1.set_shape(shape)
out = block.var(six.b("out"))
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
# prepare the operator
sum_op_desc = block.append_op()
sum_op_desc.set_type("expand")
sum_op_desc.set_input("X", ["x"])
sum_op_desc.set_output("Out", ["out"])
sum_op_desc._set_attr('expand_times', expand_times)
sum_op_desc.check_attrs()
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册