From 4267a81afcab6ccc4d84eab8ffad0dff24fd8d65 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Sat, 13 Apr 2019 13:15:08 +0800 Subject: [PATCH] Correct the lod level of compiled time in lod_reset (#16790) test=develop --- paddle/fluid/framework/var_type_inference.h | 8 +++++-- paddle/fluid/operators/lod_reset_op.cc | 24 ++++++++++++++++--- paddle/fluid/operators/lod_reset_op.h | 2 +- .../fluid/tests/unittests/test_layers.py | 12 +++++++++- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h index 2e9c64d3e..66e6ac816 100644 --- a/paddle/fluid/framework/var_type_inference.h +++ b/paddle/fluid/framework/var_type_inference.h @@ -45,12 +45,16 @@ class InferVarTypeContext { virtual bool HasInput(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); - return op_->Inputs().count(name) > 0; + auto& inputs = op_->Inputs(); + auto input = inputs.find(name); + return input != inputs.end() && !input->second.empty(); } virtual bool HasOutput(const std::string& name) const { PADDLE_ENFORCE_NOT_NULL(op_); - return op_->Outputs().count(name) > 0; + auto& outputs = op_->Outputs(); + auto output = outputs.find(name); + return output != outputs.end() && !output->second.empty(); } virtual const std::vector& Input(const std::string& name) const { diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index e0ab02cd9..458037c5a 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -30,10 +30,10 @@ class LoDResetOp : public framework::OperatorWithKernel { if (!ctx->HasInput("Y")) { auto level0 = ctx->Attrs().Get>("target_lod"); - PADDLE_ENFORCE_GT(level0.size(), 1, + PADDLE_ENFORCE_GT(level0.size(), 0, "If Input(Y) not provided, the target lod should be " "specified by attribute `target_lod`."); - } else { + } else if (ctx->IsRuntime()) { ctx->ShareLoD("Y", "Out"); } @@ -48,6 +48,23 @@ class LoDResetOp : public framework::OperatorWithKernel { } }; +class LoDResetOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto x_var_name = ctx->Input("X").front(); + auto out_var_name = ctx->Output("Out").front(); + if (ctx->HasInput("Y")) { + auto y_var_name = ctx->Input("Y").front(); + auto y_lod_level = std::max(ctx->GetLoDLevel(y_var_name), 1); + ctx->SetLoDLevel(out_var_name, y_lod_level); + } else { + ctx->SetLoDLevel(out_var_name, 1); + } + ctx->SetDataType(out_var_name, ctx->GetDataType(x_var_name)); + ctx->SetType(out_var_name, paddle::framework::proto::VarType::LOD_TENSOR); + } +}; + class LoDResetOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -177,9 +194,10 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LoDResetGradNoNeedBufferVarInference, namespace ops = paddle::operators; REGISTER_OPERATOR(lod_reset, ops::LoDResetOp, ops::LoDResetOpMaker, - ops::LoDResetGradDescMaker); + ops::LoDResetGradDescMaker, ops::LoDResetOpVarTypeInference); REGISTER_OPERATOR(lod_reset_grad, ops::LoDResetGradOp, ops::LoDResetGradNoNeedBufferVarInference); + REGISTER_OP_CPU_KERNEL( lod_reset, ops::LoDResetKernel, ops::LoDResetKernel, diff --git a/paddle/fluid/operators/lod_reset_op.h b/paddle/fluid/operators/lod_reset_op.h index d36aa0ce0..1c2f0b0ac 100644 --- a/paddle/fluid/operators/lod_reset_op.h +++ b/paddle/fluid/operators/lod_reset_op.h @@ -63,7 +63,7 @@ class LoDResetKernel : public framework::OpKernel { "Target LoD should be a vector end with the " "first dimension of Input(X)."); for (size_t i = 0; i < level0.size() - 1; ++i) { - PADDLE_ENFORCE(level0[i + 1] > level0[i], + PADDLE_ENFORCE(level0[i + 1] >= level0[i], "Target LoD should be an ascending vector."); } diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 6630fb26a..91f8bc5fd 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1759,10 +1759,20 @@ class TestBook(LayerTest): def test_lod_reset(self): # TODO(minqiyang): dygraph do not support lod now with self.static_graph(): + # case 1 x = layers.data(name='x', shape=[10], dtype='float32') y = layers.data( name='y', shape=[10, 20], dtype='float32', lod_level=2) - return (layers.lod_reset(x=x, y=y)) + z = layers.lod_reset(x=x, y=y) + self.assertTrue(z.lod_level == 2) + # case 2 + lod_tensor_in = layers.data(name='lod_in', shape=[1], dtype='int64') + z = layers.lod_reset(x=x, y=lod_tensor_in) + self.assertTrue(z.lod_level == 1) + # case 3 + z = layers.lod_reset(x=x, target_lod=[1, 2, 3]) + self.assertTrue(z.lod_level == 1) + return z def test_affine_grid(self): with self.static_graph(): -- GitLab