From 003f369bb2664a93b1629c9ca15488fb8463b0ef Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 16 Oct 2019 14:40:36 +0800 Subject: [PATCH] Add IndicateVarDataType interface to block tensor is not initialized problem in OP GetExceptedKernelType (#20044) * add indicate_var_data_type inferface, test=develop * add unittests & polish error message, test=develop * remove needless include, test=develop * extract public function & polish message, test=develop * delete empty var check, test=develop * change data_type to pointer parameter, test=develop * polish details, test=develop --- paddle/fluid/framework/operator.cc | 79 +++++++---- paddle/fluid/framework/operator.h | 5 + paddle/fluid/framework/operator_test.cc | 179 ++++++++++++++++++++++++ paddle/fluid/framework/variable.h | 14 +- paddle/fluid/operators/mul_op.cc | 2 +- paddle/fluid/operators/save_op.cc | 2 +- 6 files changed, 245 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 42e70d9cb0..0552dd535d 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1138,40 +1138,65 @@ Scope* OperatorWithKernel::PrepareData( return new_scope; } +void OperatorWithKernel::ParseInputDataType( + const ExecutionContext& ctx, const std::string& name, + proto::VarType::Type* data_type) const { + proto::VarType::Type dafault_data_type = + static_cast(-1); + const std::vector vars = ctx.MultiInputVar(name); + for (size_t i = 0; i < vars.size(); ++i) { + const Variable* var = vars[i]; + if (var != nullptr) { + const Tensor* t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &(var->Get().value()); + } + if (t != nullptr) { + PADDLE_ENFORCE_EQ(t->IsInitialized(), true, + "The Tensor in the %s Op's Input Variable %s(%s) is " + "not initialized.", + Type(), name, ctx.Inputs(name).at(i)); + proto::VarType::Type tmp = t->type(); + PADDLE_ENFORCE(tmp == *data_type || *data_type == dafault_data_type, + "The DataType of %s Op's duplicable Variable %s must be " + "consistent. The current variable type is (%s), but the " + "previous variable type is (%s).", + Type(), name, DataTypeToString(tmp), + DataTypeToString(*data_type)); + *data_type = tmp; + } + } + } +} + proto::VarType::Type OperatorWithKernel::IndicateDataType( const ExecutionContext& ctx) const { proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; for (auto& input : this->inputs_) { - const std::vector vars = ctx.MultiInputVar(input.first); - for (size_t i = 0; i < vars.size(); ++i) { - const Variable* var = vars[i]; - if (var != nullptr) { - const Tensor* t = nullptr; - if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &(var->Get().value()); - } - if (t != nullptr) { - PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized", - input.first, i); - proto::VarType::Type tmp = t->type(); - PADDLE_ENFORCE( - tmp == data_type || data_type == dafault_data_type, - "DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)", - Type(), input.first, DataTypeToString(data_type), - DataTypeToString(tmp)); - data_type = tmp; - } - } - } + ParseInputDataType(ctx, input.first, &data_type); } - PADDLE_ENFORCE(data_type != dafault_data_type, - "DataType should be indicated by input"); + PADDLE_ENFORCE_NE(data_type, dafault_data_type, + "DataType should be indicated by input Variable."); + return data_type; +} + +proto::VarType::Type OperatorWithKernel::IndicateVarDataType( + const ExecutionContext& ctx, const std::string& name) const { + proto::VarType::Type dafault_data_type = + static_cast(-1); + proto::VarType::Type data_type = dafault_data_type; + ParseInputDataType(ctx, name, &data_type); + PADDLE_ENFORCE_NE( + data_type, dafault_data_type, + "The Input Variable(%s) of %s Op used to determine kernel data type " + "is empty or not LoDTensor or SelectedRows.", + name, Type()); return data_type; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5899a14f50..640d4aff5e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -459,6 +459,9 @@ class OperatorWithKernel : public OperatorBase { void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const override; + proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx, + const std::string& name) const; + virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; std::vector* GetKernelConfig(const OpKernelType& key) const; @@ -470,6 +473,8 @@ class OperatorWithKernel : public OperatorBase { const OpKernelType& expected_kernel_type) const; private: + void ParseInputDataType(const ExecutionContext& ctx, const std::string& name, + proto::VarType::Type* type) const; // indicate kernel DataType by input data. By default all input data must be // same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index fe4804ac25..aeb1daa4ed 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -315,3 +315,182 @@ TEST(VarNameTest, all) { original_var_name = paddle::framework::GradOriginalVarName(original_var_name); ASSERT_EQ(original_var_name, ""); } + +namespace paddle { +namespace framework { + +class IndicateLoDTensorDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("LoDTensor", "Input of Tensor type Variable."); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateSelectedRowsDataTypeTestProtoMaker + : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("SelectedRows", "Input of SelectedRows type Variable."); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +class IndicateOtherDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("Other", "Input of Other type Variable"); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +template +class IndicateVarDataTypeKernelTest : public OpKernel { + public: + void Compute(const ExecutionContext& ctx) const {} +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT( + indicate_lod_tensor_data_type_test, + paddle::framework::IndicateLoDTensorDataTypeTest, + paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker); +REGISTER_OP_WITHOUT_GRADIENT( + indicate_selected_rows_data_type_test, + paddle::framework::IndicateSelectedRowsDataTypeTest, + paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker); +REGISTER_OP_WITHOUT_GRADIENT( + indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest, + paddle::framework::IndicateOtherDataTypeTestProtoMaker); + +REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); +REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); +REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); + +TEST(IndicateVarDataTypeTest, lodtensor) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_lod_tensor_data_type_test"); + BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("lodtensor_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE( + ex_msg.find( + "The Tensor in the indicate_lod_tensor_data_type_test Op's " + "Input Variable LoDTensor(lodtensor_1) is not initialized") != + std::string::npos); + } + ASSERT_TRUE(caught); +} + +TEST(IndicateVarDataTypeTest, selectedrows) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_selected_rows_data_type_test"); + BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("selected_rows_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE( + ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test " + "Op's Input Variable SelectedRows(selected_rows_1) is not " + "initialized") != std::string::npos); + } + ASSERT_TRUE(caught); +} + +TEST(IndicateVarDataTypeTest, other) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_other_data_type_test"); + BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("lod_tensor_array_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of " + "indicate_other_data_type_test Op used to " + "determine kernel data type " + "is empty or not LoDTensor or SelectedRows") != + std::string::npos); + } + ASSERT_TRUE(caught); +} diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index b9d07da822..5d9633a61d 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -30,9 +30,9 @@ class Variable { static_assert( IsRegisteredVarType(), "Not registered type. Please register T inside var_type_traits.h"); - PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing"); + PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized."); PADDLE_ENFORCE(holder_->Type() == VarTypeTrait::kId, - "Variable must be type %s, the holding type is %s", + "The Variable type must be %s, but the type it holds is %s.", ToTypeName(VarTypeTrait::kId), ToTypeName(holder_->Type())); return *static_cast(holder_->Ptr()); @@ -45,10 +45,10 @@ class Variable { if (!holder_) { holder_.reset(new PlaceholderImpl()); } else { - PADDLE_ENFORCE(holder_->Type() == VarTypeTrait::kId, - "Variable must be type %s, the holding type is %s", - ToTypeName(VarTypeTrait::kId), - ToTypeName(holder_->Type())); + PADDLE_ENFORCE( + holder_->Type() == VarTypeTrait::kId, + "The Variable type must be %s, but the type it holds is %s.", + ToTypeName(VarTypeTrait::kId), ToTypeName(holder_->Type())); } return static_cast(holder_->Ptr()); } @@ -61,7 +61,7 @@ class Variable { void Clear() { holder_.reset(); } int Type() const { - PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory"); + PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized."); return holder_->Type(); } diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index cfd55ef8bc..09c544ec02 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -92,7 +92,7 @@ class MulOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; int customized_type_value = framework::OpKernelType::kDefaultCustomizedTypeValue; - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index c660bbb8ed..73bac5c2fd 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -31,7 +31,7 @@ class SaveOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; -- GitLab