From 35f17ae28f292a602274a93d0398286c3b0f1afb Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Mon, 11 Nov 2019 14:44:20 +0800
Subject: [PATCH] Add the check of lod_level between compile-time and runtime.
 (#20961)

* Add the check of lod_level between compile-time and runtime.
test=develop

* Fix bug in check_compile_vs_runtime.
test=develop

* Fix the check of output when it is dispensable or intermediate.
test=develop

* Share lod of x to out in match_matrix_tensor op in compile-time.

* Implement GetLoDLevel in InferShapeContext.

* Set the default value of check_compile_vs_runtime to False and enable it in test_sequence_pad_op.
test=develop

* Enable check_compile_vs_runtime in test_match_matrix_tensor.

* Add the implementation of SetLoDLevel in InferShapeContext.

* Remove the implementation of IncreaseLoDLevel and call Get/SetLoDLevel instead.

* Remove the implementation of DecreaseLoDLevel and call Set/GetLoDLevel instead.

* Refine some ops and unittests.
test=develop

* Fix a typo.
test=develop

* Remove the check of var type, and change int to int32_t.
test=develop

* Add unittest for Get/SetLoDLevel.
test=develop
---
 paddle/fluid/framework/op_desc.cc             |  58 ++++------
 paddle/fluid/framework/operator.cc            |  11 +-
 paddle/fluid/framework/operator_test.cc       | 105 +++++++++++++++++-
 paddle/fluid/framework/shape_inference.h      |   7 +-
 .../fluid/operators/array_to_lod_tensor_op.cc |   4 +-
 .../fluid/operators/lod_tensor_to_array_op.cc |   4 +-
 .../fluid/operators/match_matrix_tensor_op.cc |   1 +
 .../operators/sequence_ops/sequence_pad_op.cc |   6 +-
 .../sequence_ops/sequence_pool_op.cc          |   6 +-
 .../sequence_ops/sequence_unpad_op.cc         |   4 +-
 .../paddle/fluid/tests/unittests/op_test.py   |  63 ++++++++++-
 .../unittests/test_match_matrix_tensor_op.py  |   2 +-
 .../tests/unittests/test_sequence_pad_op.py   |   2 +-
 .../tests/unittests/test_sequence_unpad_op.py |   2 +-
 14 files changed, 207 insertions(+), 68 deletions(-)

diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 4d20ed21ed9..4da50d6578f 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -92,45 +92,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
     out_var->SetLoDLevel(in_var->GetLoDLevel());
   }
 
-  void DecreaseLoDLevel(const std::string &in, const std::string &out,
-                        size_t i = 0, size_t j = 0) const override {
-    // When in is a LoDTensor and out is a LoDTensorArray, there may need to
-    // decrease the lod_level.
-    PADDLE_ENFORCE_LT(i, Inputs(in).size());
-    PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
-                   "The %s[%d] is @EMPTY@", in, i);
-    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
-                   "The %s[%d] is @EMPTY@", out, j);
+  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
+    PADDLE_ENFORCE_LT(i, Inputs(in).size(),
+                      "Input %s of operator %s only has %d elements.", in,
+                      op_.Type(), Inputs(in).size());
+    PADDLE_ENFORCE_NE(Inputs(in)[i], framework::kEmptyVarName,
+                      "Input %s[%d] of operator %s is @EMPTY@", in,
+                      op_.Type(), i);
     auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
-    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
-                      "The input %s should be LoDTensor.", in_var->Name());
-    PADDLE_ENFORCE_EQ(out_var->GetType(), proto::VarType::LOD_TENSOR_ARRAY,
-                      "The output %s should be LoDTensorArray.",
-                      out_var->Name());
-    if (in_var->GetLoDLevel() > 0) {
-      out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
-    }
+    PADDLE_ENFORCE_NOT_NULL(
+        in_var, "Input %s[%d] of operator %s should not be nullptr.", in,
+        op_.Type(), i);
+    return in_var->GetLoDLevel();
   }
 
-  void IncreaseLoDLevel(const std::string &in, const std::string &out,
-                        size_t i = 0, size_t j = 0) const override {
-    // When in is a LoDTensorArray and out is a LoDTensor, there may need to
-    // increase the lod_level.
-    PADDLE_ENFORCE_LT(i, Inputs(in).size());
-    PADDLE_ENFORCE_LT(j, Outputs(out).size());
-    PADDLE_ENFORCE_NE(Inputs(in)[i], framework::kEmptyVarName,
-                      "The %s[%d] is @EMPTY@", in, i);
+  void SetLoDLevel(const std::string &out, int32_t lod_level,
+                   size_t j = 0) const override {
+    PADDLE_ENFORCE_LT(j, Outputs(out).size(),
+                      "Output %s of operator %s only has %d elements.", out,
+                      op_.Type(), Outputs(out).size());
     PADDLE_ENFORCE_NE(Outputs(out)[j], framework::kEmptyVarName,
-                      "The %s[%d] is @EMPTY@", out, j);
-    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
+                      "Output %s[%d] of operator %s is @EMPTY@", out,
+                      op_.Type(), j);
     auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
-    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR_ARRAY,
-                      "The input %s should be LoDTensorArray.", in_var->Name());
-    PADDLE_ENFORCE_EQ(out_var->GetType(), proto::VarType::LOD_TENSOR,
-                      "The output %s should be LoDTensor.", out_var->Name());
-    out_var->SetLoDLevel(in_var->GetLoDLevel() + 1);
+    PADDLE_ENFORCE_NOT_NULL(
+        out_var, "Output %s[%d] of operator %s should not be nullptr.", out,
+        op_.Type(), j);
+    if (lod_level >= 0) {
+      out_var->SetLoDLevel(lod_level);
+    }
   }
 
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 8785b618dc8..424a8fb7d54 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -657,18 +657,17 @@ class RuntimeInferShapeContext : public InferShapeContext {
     out_tensor->set_layout(in_tensor.layout());
   }
 
-  void DecreaseLoDLevel(const std::string& in, const std::string& out,
-                        size_t i = 0, size_t j = 0) const override {
+  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
     PADDLE_THROW(
-        "DecreaseLoDLevel is only used in compile time. The calculation of "
+        "GetLoDLevel is only used in compile time. The calculation of "
         "output's actual lod is different among operators so that should be "
         "set in the runtime kernel.");
   }
 
-  void IncreaseLoDLevel(const std::string& in, const std::string& out,
-                        size_t i = 0, size_t j = 0) const override {
+  void SetLoDLevel(const std::string& out, int32_t lod_level,
+                   size_t j = 0) const override {
     PADDLE_THROW(
-        "IncreaseLoDLevel is only used in compile time. The calculation of "
+        "SetLoDLevel is only used in compile time. The calculation of "
         "output's actual lod is different among operators so that should be "
         "set in the runtime kernel.");
   }
diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc
index aeb1daa4ed9..2a35a0392c6 100644
--- a/paddle/fluid/framework/operator_test.cc
+++ b/paddle/fluid/framework/operator_test.cc
@@ -331,6 +331,7 @@ class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
+
 class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
  public:
   void Make() {
@@ -382,7 +383,7 @@ class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
 };
 
 template <typename DeviceContext, typename T>
-class IndicateVarDataTypeKernelTest : public OpKernel<T> {
+class EmptyTestKernel : public OpKernel<T> {
  public:
   void Compute(const ExecutionContext& ctx) const {}
 };
@@ -403,13 +404,13 @@ REGISTER_OP_WITHOUT_GRADIENT(
     paddle::framework::IndicateOtherDataTypeTestProtoMaker);
 
 REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test,
-                       paddle::framework::IndicateVarDataTypeKernelTest<
+                       paddle::framework::EmptyTestKernel<
                            paddle::platform::CPUDeviceContext, int>);
 REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test,
-                       paddle::framework::IndicateVarDataTypeKernelTest<
+                       paddle::framework::EmptyTestKernel<
                            paddle::platform::CPUDeviceContext, int>);
 REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test,
-                       paddle::framework::IndicateVarDataTypeKernelTest<
+                       paddle::framework::EmptyTestKernel<
                            paddle::platform::CPUDeviceContext, int>);
 
 TEST(IndicateVarDataTypeTest, lodtensor) {
@@ -494,3 +495,99 @@ TEST(IndicateVarDataTypeTest, other) {
   }
   ASSERT_TRUE(caught);
 }
+
+namespace paddle {
+namespace framework {
+
+class GetLoDLevelTest : public OperatorWithKernel {
+ public:
+  using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
+                      "Input(X) should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) should not be null.");
+    PADDLE_ENFORCE_GT(ctx->GetLoDLevel("X"), 0,
+                      "The LoD level of Input(X) should be larger than 0.");
+  }
+};
+
+class SetLoDLevelTest : public OperatorWithKernel {
+ public:
+  using OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
+                      "Input(X) should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) should not be null.");
+    ctx->SetLoDLevel("Out", 1);
+  }
+};
+
+class GetSetLoDLevelTestMaker : public OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(LoDTensor) Input Variable.");
+    AddOutput("Out", "(LoDTensor) Output Variable.");
+    AddComment("This Op is only for Get/SetLoDLevel interface test.");
+  }
+};
+
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_OP_WITHOUT_GRADIENT(get_lod_level_test,
+                             paddle::framework::GetLoDLevelTest,
+                             paddle::framework::GetSetLoDLevelTestMaker);
+REGISTER_OP_CPU_KERNEL(get_lod_level_test,
+                       paddle::framework::EmptyTestKernel<
+                           paddle::platform::CPUDeviceContext, float>);
+
+REGISTER_OP_WITHOUT_GRADIENT(set_lod_level_test,
+                             paddle::framework::SetLoDLevelTest,
+                             paddle::framework::GetSetLoDLevelTestMaker);
+REGISTER_OP_CPU_KERNEL(set_lod_level_test,
+                       paddle::framework::EmptyTestKernel<
+                           paddle::platform::CPUDeviceContext, float>);
+
+void SetGetLoDLevelTestMain(std::string op_type) {
+  paddle::framework::InitDevices(false, {});
+  paddle::framework::proto::OpDesc op_desc;
+  op_desc.set_type(op_type);
+  BuildVar("X", {"x.0"}, op_desc.add_inputs());
+  BuildVar("Out", {"out.0"}, op_desc.add_outputs());
+
+  paddle::platform::CPUPlace place;
+  paddle::framework::Scope scope;
+
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto* x_var = scope.Var("x.0");
+  auto* x = x_var->GetMutable<paddle::framework::LoDTensor>();
+  x->mutable_data<float>(paddle::framework::make_ddim({64}), place);
+  auto* out_var = scope.Var("out.0");
+  out_var->GetMutable<paddle::framework::LoDTensor>();
+
+  bool caught = false;
+  std::string err_str =
+      (op_type == "get_lod_level_test") ? "GetLoDLevel" : "SetLoDLevel";
+  err_str +=
+      " is only used in compile time. The calculation of output's actual lod "
+      "is different among operators so that should be set in the runtime "
+      "kernel.";
+  try {
+    op->Run(scope, place);
+  } catch (const paddle::platform::EnforceNotMet& err) {
+    caught = true;
+    std::string ex_msg = err.what();
+    EXPECT_TRUE(ex_msg.find(err_str) != std::string::npos);
+  }
+  ASSERT_TRUE(caught);
+}
+
+TEST(GetLoDLevelTest, base) { SetGetLoDLevelTestMain("get_lod_level_test"); }
+
+TEST(SetLoDLevelTest, base) { SetGetLoDLevelTestMain("set_lod_level_test"); }
diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h
index 533422cacc1..73dd621a27e 100644
--- a/paddle/fluid/framework/shape_inference.h
+++ b/paddle/fluid/framework/shape_inference.h
@@ -65,11 +65,10 @@ class InferShapeContext {
   virtual void ShareLoD(const std::string &in, const std::string &out,
                         size_t i = 0, size_t j = 0) const = 0;
 
-  virtual void DecreaseLoDLevel(const std::string &in, const std::string &out,
-                                size_t i = 0, size_t j = 0) const = 0;
+  virtual int32_t GetLoDLevel(const std::string &in, size_t i = 0) const = 0;
 
-  virtual void IncreaseLoDLevel(const std::string &in, const std::string &out,
-                                size_t i = 0, size_t j = 0) const = 0;
+  virtual void SetLoDLevel(const std::string &out, int32_t lod_level,
+                           size_t j = 0) const = 0;
 
   virtual bool IsRuntime() const = 0;
diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc
index 0f258f4a5e7..ea65cf4e62c 100644
--- a/paddle/fluid/operators/array_to_lod_tensor_op.cc
+++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc
@@ -199,13 +199,13 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
     context->SetOutputDim("Out", context->GetInputDim("X"));
 
     // The output LoDTensor's lod_level should be input X's lod_level + 1.
-    // For compile-time, we call IncreaseLoDLevel to set output's lod_level.
+    // For compile-time, we call SetLoDLevel to set output's lod_level.
     // For runtime, output LoDTensor's lod is determined by input X's lod and
     // the level specified by input RankTable.
     // We cannot get X's detail lod and RankTable's level in this function, so
    // leave this work to the detail kernel implementation.
     if (!context->IsRuntime()) {
-      context->IncreaseLoDLevel("X", /*->*/ "Out");
+      context->SetLoDLevel("Out", context->GetLoDLevel("X") + 1);
     }
   }
 };
diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc
index fbabd358306..224069ec42e 100644
--- a/paddle/fluid/operators/lod_tensor_to_array_op.cc
+++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc
@@ -206,13 +206,13 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
     context->SetOutputDim("Out", x_dim);
 
     // The output LoDTensor's lod_level should be input X's lod_level - 1.
-    // For compile time, we call DecreaseLoDLevel to set output's lod_level.
+    // For compile time, we call SetLoDLevel to set output's lod_level.
     // For runtime, output LoDTensor's lod is determined by input X's lod and
     // the level specified by input RankTable.
     // We cannot get X's detail lod and RankTable's level in this function, so
     // leave this work to the detail kernel implementation.
     if (!context->IsRuntime()) {
-      context->DecreaseLoDLevel("X", /*->*/ "Out");
+      context->SetLoDLevel("Out", context->GetLoDLevel("X") - 1);
     }
   }
 };
diff --git a/paddle/fluid/operators/match_matrix_tensor_op.cc b/paddle/fluid/operators/match_matrix_tensor_op.cc
index 2aeb68ed901..3f28895e67a 100644
--- a/paddle/fluid/operators/match_matrix_tensor_op.cc
+++ b/paddle/fluid/operators/match_matrix_tensor_op.cc
@@ -101,6 +101,7 @@ void MatchMatrixTensorOP::InferShape(framework::InferShapeContext* ctx) const {
     framework::VarDesc* y_desc =
         boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Y")[0]);
     PADDLE_ENFORCE_GE(y_desc->GetLoDLevel(), 1);
+    ctx->ShareLoD("X", "Out");
   }
 
   std::vector<int64_t> out_dims_vec{out_dim_0};
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
index 65b8839a2b1..3c6d36a0a61 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
@@ -76,9 +76,9 @@ class SequencePadOp : public framework::OperatorWithKernel {
       if (padded_length == -1) {
         padded_length = 1;
       }
-      framework::VarDesc* x_desc =
-          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
-      PADDLE_ENFORCE_GE(x_desc->GetLoDLevel(), 1);
+      PADDLE_ENFORCE_GT(
+          ctx->GetLoDLevel("X"), 0,
+          "The LoD level of Input(X) of sequence_pad should be larger than 0.");
     }
 
     std::vector<int64_t> out_dims_vec{out_dim_0, padded_length};
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
index d2459570c92..1ece8bf937a 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
@@ -31,11 +31,9 @@ class SequencePoolOp : public framework::OperatorWithKernel {
 
     if (!ctx->IsRuntime()) {
       // Check the lod_level for compile-time.
-      framework::VarDesc* x_desc =
-          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
       PADDLE_ENFORCE_GT(
-          x_desc->GetLoDLevel(), 0,
-          "The LoD level Input(X) of sequence_pool should be larger than 0");
+          ctx->GetLoDLevel("X"), 0,
+          "The LoD level of Input(X) of sequence_pool should be larger than 0.");
     }
 
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
index dbb01a5c367..2b3c5a09406 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
@@ -58,9 +58,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel {
     }
     ctx->SetOutputDim("Out", framework::make_ddim(out_dims_vec));
     if (!ctx->IsRuntime()) {
-      framework::VarDesc* out_desc =
-          boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
-      out_desc->SetLoDLevel(1);
+      ctx->SetLoDLevel("Out", 1);
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index b78679f0e4b..1cfeed7cb6d 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -369,6 +369,8 @@ class OpTest(unittest.TestCase):
             feed=feed_map,
             fetch_list=fetch_list,
             return_numpy=False)
+        self.op = op
+        self.program = original_program
         if for_inplace_test:
             return outs, fetch_list, feed_map, original_program, op.desc
         else:
@@ -833,6 +835,54 @@ class OpTest(unittest.TestCase):
             self.check_inplace_output_with_place(
                 place, no_check_set=no_check_set, inplace_atol=inplace_atol)
 
+        if check_dygraph:
+            return outs, dygraph_outs, fetch_list
+        else:
+            return outs, fetch_list
+
+    def check_compile_vs_runtime(self, fetch_list, fetch_outs):
+        def find_fetch_index(target_name, fetch_list):
+            found = [
+                i for i, var_name in enumerate(fetch_list)
+                if var_name == target_name
+            ]
+            if len(found) == 0:
+                return -1
+            else:
+                self.assertTrue(
+                    len(found) == 1,
+                    "Found {} {}".format(len(found), target_name))
+                return found[0]
+
+        for name in self.op.desc.output_names():
+            var_names = self.op.desc.output(name)
+            for var_name in var_names:
+                i = find_fetch_index(var_name, fetch_list)
+                if i == -1:
+                    # The output is dispensable or intermediate.
+                    break
+                out = fetch_outs[i]
+                if isinstance(out, core.LoDTensor):
+                    lod_level_runtime = len(out.lod())
+                else:
+                    if isinstance(out, core.LoDTensorArray):
+                        warnings.warn(
+                            "The check of LoDTensorArray's lod_level is not implemented yet!"
+ ) + lod_level_runtime = 0 + + var = self.program.global_block().var(var_name) + if var.type == core.VarDesc.VarType.LOD_TENSOR: + lod_level_compile = var.lod_level + else: + lod_level_compile = 0 + self.assertEqual( + lod_level_compile, lod_level_runtime, + "The lod_level of Output (" + name + + ") is different between compile-time and runtime (" + + str(lod_level_compile) + " vs " + str(lod_level_runtime) + + ")") + def _get_places(self): if self.dtype == np.float16: if core.is_compiled_with_cuda() and core.op_support_gpu( @@ -860,11 +910,18 @@ class OpTest(unittest.TestCase): no_check_set=None, equal_nan=False, check_dygraph=False, - inplace_atol=None): + inplace_atol=None, + check_compile_vs_runtime=False): places = self._get_places() for place in places: - self.check_output_with_place(place, atol, no_check_set, equal_nan, - check_dygraph) + res = self.check_output_with_place(place, atol, no_check_set, + equal_nan, check_dygraph) + if check_dygraph: + outs, dygraph_outs, fetch_list = res + else: + outs, fetch_list = res + if check_compile_vs_runtime: + self.check_compile_vs_runtime(fetch_list, outs) def check_output_customized(self, checker): places = self._get_places() diff --git a/python/paddle/fluid/tests/unittests/test_match_matrix_tensor_op.py b/python/paddle/fluid/tests/unittests/test_match_matrix_tensor_op.py index 49f630ba8f6..9487f6ed1d3 100644 --- a/python/paddle/fluid/tests/unittests/test_match_matrix_tensor_op.py +++ b/python/paddle/fluid/tests/unittests/test_match_matrix_tensor_op.py @@ -71,7 +71,7 @@ class TestMatchMatrixTensorOp(OpTest): self.outputs = {'Out': (out, out_lod), 'Tmp': tmp} def test_check_output(self): - self.check_output() + self.check_output(check_compile_vs_runtime=True) def test_check_grad(self): self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) diff --git a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py index 01ed53471fe..1791df350c1 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_pad_op.py @@ -72,7 +72,7 @@ class TestSequencePadOp(OpTest): self.compute() def test_check_output(self): - self.check_output() + self.check_output(check_compile_vs_runtime=True) def test_check_grad(self): self.check_grad(["X"], "Out") diff --git a/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py index 19ef00ba83c..ec63b87bfae 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_unpad_op.py @@ -48,7 +48,7 @@ class TestSequenceUnpadOp(OpTest): self.compute() def test_check_output(self): - self.check_output() + self.check_output(check_compile_vs_runtime=True) def test_check_grad(self): self.check_grad(["X"], "Out") -- GitLab