From 91bd5835df60fa3cd8c89f4300ee369bd82a5e6a Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Mon, 28 May 2018 10:17:58 +0800 Subject: [PATCH] Fix fill_constant_batch_size_like_op when input is LoDTensor. (#10943) --- .../fill_constant_batch_size_like_op.h | 8 +++++++ .../test_fill_constant_batch_size_like_op.py | 22 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/paddle/fluid/operators/fill_constant_batch_size_like_op.h b/paddle/fluid/operators/fill_constant_batch_size_like_op.h index 2a7df149a9..63ea60678f 100644 --- a/paddle/fluid/operators/fill_constant_batch_size_like_op.h +++ b/paddle/fluid/operators/fill_constant_batch_size_like_op.h @@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* out = ctx.Output("Out"); + auto* in = ctx.Input("Input"); + if (in->lod().size() && ctx.Attr("input_dim_idx") == 0) { + // set the correct batch size for the LoDTensor. + auto odims = out->dims(); + int output_dim_idx = ctx.Attr("output_dim_idx"); + odims[output_dim_idx] = static_cast(in->lod().back().size()) - 1; + out->mutable_data(odims, ctx.GetPlace()); + } out->mutable_data(ctx.GetPlace()); auto value = ctx.Attr("value"); diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py index 66e3e2d51d..533d8ccfac 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_batch_size_like_op.py @@ -50,5 +50,27 @@ class TestFillConstantBatchSizeLikeWhenSecondDimIsBatchSize(OpTest): self.check_output() +class TestFillConstantBatchSizeLikeWithLoDTensor(OpTest): + def setUp(self): + self.op_type = "fill_constant_batch_size_like" + self.inputs = { + 'Input': (np.random.random((31, 28)).astype("float32"), + [[0, 9, 23, 31]]) + } + self.attrs = { + 'value': 3.5, + 'shape': [-1, 16], + 'input_dim_idx': 0, + 'output_dim_idx': 0 + } + + out = np.random.random((3, 16)).astype("float32") + out.fill(3.5) + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main() -- GitLab