diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.h b/paddle/fluid/operators/fill_constant_batch_size_like_op.h index 2a7df149a9f4b03676f172da980c927d7fa5e8a4..63ea60678f80708f5a8340edd22588553b9ec139 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.h +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.h @@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* out = ctx.Output("Out"); + auto* in = ctx.Input("Input"); + if (in->lod().size() && ctx.Attr("input_dim_idx") == 0) { + // set the correct batch size for the LoDTensor. + auto odims = out->dims(); + int output_dim_idx = ctx.Attr("output_dim_idx"); + odims[output_dim_idx] = static_cast(in->lod().back().size()) - 1; + out->mutable_data(odims, ctx.GetPlace()); + } out->mutable_data(ctx.GetPlace()); auto value = ctx.Attr("value"); diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py index 66e3e2d51d118756d4881955b4df8eb4d2bbc094..533d8ccfac82a2e298af16181ab16bf7aa3db282 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py @@ -50,5 +50,27 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest): self.check_output() +class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest): + def setUp(self): + self.op_type = "fill_constant_batch_size_like" + self.inputs = { + 'Input': (np.random.random((31, 28)).astype("float32"), + [[0, 9, 23, 31]]) + } + self.attrs = { + 'value': 3.5, + 'shape': [-1, 16], + 'input_dim_idx': 0, + 'output_dim_idx': 0 + } + + out = np.random.random((3, 16)).astype("float32") + out.fill(3.5) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main()