diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc
index 5ad0ec251328cc1ba580026bb47bf05316e7dc77..40f7c1c54c861abebc84428f55e2769ac8969f0f 100644
--- a/paddle/fluid/operators/expand_op.cc
+++ b/paddle/fluid/operators/expand_op.cc
@@ -47,6 +47,13 @@ class ExpandOp : public framework::OperatorWithKernel {
       out_shape[i] = x_dims[i] * expand_times[i];
     }
 
+    // At compile time an unknown first dim (-1) must stay unknown, so keep it
+    // as-is instead of multiplying it by expand_times[0]. A first dim that is
+    // already known keeps the product computed above.
+    if (!ctx->IsRuntime() && x_dims[0] < 0) {
+      out_shape[0] = x_dims[0];
+    }
+
     ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
     if (out_shape[0] == x_dims[0]) {
       ctx->ShareLoD("X", "Out");
@@ -109,7 +116,18 @@ class ExpandGradOp : public framework::OperatorWithKernel {
         ctx->Attrs().Get<std::vector<int>>("expand_times");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
+    // When the first dim is still unknown (-1) at compile time the forward op
+    // passes it through unchanged, so only require equality for that dim.
+    size_t start_pos = 0u;
+    if (!ctx->IsRuntime() && x_dims[0] < 0) {
+      PADDLE_ENFORCE_EQ(
+          x_dims[0], out_dims[0],
+          "The first dimension size of Input(Out@GRAD) should be "
+          "equal to the crroresponding dimension size of Input(X)");
+      start_pos = 1u;
+    }
+
-    for (size_t i = 0; i < expand_times.size(); ++i) {
+    for (size_t i = start_pos; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                         "Each dimension size of Input(Out@GRAD) should be "
                         "equal to multiplication of crroresponding dimension "
diff --git a/python/paddle/fluid/tests/unittests/test_infer_shape.py b/python/paddle/fluid/tests/unittests/test_infer_shape.py
index fdff22cacc28731a91ff4fd17407bd9edbdd9d8b..9d5e064e6adabe09094350db2976f83d835520eb 100644
--- a/python/paddle/fluid/tests/unittests/test_infer_shape.py
+++ b/python/paddle/fluid/tests/unittests/test_infer_shape.py
@@ -83,6 +83,39 @@ class TestInferShape(unittest.TestCase):
         mul_op_desc.infer_shape(block)
         self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
 
+    def test_expand_op(self):
+        # Each case: (input shape, expand_times, expected inferred shape).
+        cases = [
+            # An unknown (-1) first dim is kept as-is at compile time.
+            ([-1, 20], [3, 1], [-1, 20]),
+            # Known dims are multiplied by expand_times.
+            ([2, 20], [3, 1], [6, 20]),
+        ]
+        for shape, expand_times, expected_shape in cases:
+            prog = core.ProgramDesc()
+            self.assertIsNotNone(prog)
+            block = prog.block(0)
+            self.assertIsNotNone(block)
+
+            # prepare input/output
+            x = block.var(six.b("x"))
+            x.set_type(core.VarDesc.VarType.LOD_TENSOR)
+            x.set_shape(shape)
+
+            out = block.var(six.b("out"))
+            out.set_type(core.VarDesc.VarType.LOD_TENSOR)
+
+            # prepare the operator
+            expand_op_desc = block.append_op()
+            expand_op_desc.set_type("expand")
+            expand_op_desc.set_input("X", ["x"])
+            expand_op_desc.set_output("Out", ["out"])
+            expand_op_desc._set_attr('expand_times', expand_times)
+
+            expand_op_desc.check_attrs()
+            expand_op_desc.infer_shape(block)
+            self.assertEqual(out.shape(), expected_shape)
+
 
 if __name__ == '__main__':
     unittest.main()