未验证 提交 003f369b 编写于 作者: C Chen Weihang 提交者: GitHub

Add IndicateVarDataType interface to block tensor is not initialized problem...

Add IndicateVarDataType interface to block tensor is not initialized problem in OP GetExceptedKernelType (#20044)

* add indicate_var_data_type inferface, test=develop

* add unittests & polish error message, test=develop

* remove needless include, test=develop

* extract public function & polish message, test=develop

* delete empty var check, test=develop

* change data_type to pointer parameter, test=develop

* polish details, test=develop
上级 dfa23925
......@@ -1138,40 +1138,65 @@ Scope* OperatorWithKernel::PrepareData(
return new_scope;
}
void OperatorWithKernel::ParseInputDataType(
const ExecutionContext& ctx, const std::string& name,
proto::VarType::Type* data_type) const {
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
const std::vector<const Variable*> vars = ctx.MultiInputVar(name);
for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
"The Tensor in the %s Op's Input Variable %s(%s) is "
"not initialized.",
Type(), name, ctx.Inputs(name).at(i));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(tmp == *data_type || *data_type == dafault_data_type,
"The DataType of %s Op's duplicable Variable %s must be "
"consistent. The current variable type is (%s), but the "
"previous variable type is (%s).",
Type(), name, DataTypeToString(tmp),
DataTypeToString(*data_type));
*data_type = tmp;
}
}
}
}
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
for (auto& input : this->inputs_) {
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized",
input.first, i);
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)",
Type(), input.first, DataTypeToString(data_type),
DataTypeToString(tmp));
data_type = tmp;
}
}
}
ParseInputDataType(ctx, input.first, &data_type);
}
PADDLE_ENFORCE(data_type != dafault_data_type,
"DataType should be indicated by input");
PADDLE_ENFORCE_NE(data_type, dafault_data_type,
"DataType should be indicated by input Variable.");
return data_type;
}
proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
const ExecutionContext& ctx, const std::string& name) const {
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
ParseInputDataType(ctx, name, &data_type);
PADDLE_ENFORCE_NE(
data_type, dafault_data_type,
"The Input Variable(%s) of %s Op used to determine kernel data type "
"is empty or not LoDTensor or SelectedRows.",
name, Type());
return data_type;
}
......
......@@ -459,6 +459,9 @@ class OperatorWithKernel : public OperatorBase {
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
const std::string& name) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
......@@ -470,6 +473,8 @@ class OperatorWithKernel : public OperatorBase {
const OpKernelType& expected_kernel_type) const;
private:
void ParseInputDataType(const ExecutionContext& ctx, const std::string& name,
proto::VarType::Type* type) const;
// indicate kernel DataType by input data. By default all input data must be
// same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
......
......@@ -315,3 +315,182 @@ TEST(VarNameTest, all) {
original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "");
}
namespace paddle {
namespace framework {
class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("LoDTensor", "Input of Tensor type Variable.");
AddComment("This Op is only for IndicateVarDataType inferface test.");
}
};
class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndicateSelectedRowsDataTypeTestProtoMaker
: public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("SelectedRows", "Input of SelectedRows type Variable.");
AddComment("This Op is only for IndicateVarDataType inferface test.");
}
};
class IndicateOtherDataTypeTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("Other", "Input of Other type Variable");
AddComment("This Op is only for IndicateVarDataType inferface test.");
}
};
template <typename DeviceContext, typename T>
class IndicateVarDataTypeKernelTest : public OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const {}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
indicate_lod_tensor_data_type_test,
paddle::framework::IndicateLoDTensorDataTypeTest,
paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
indicate_selected_rows_data_type_test,
paddle::framework::IndicateSelectedRowsDataTypeTest,
paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest,
paddle::framework::IndicateOtherDataTypeTestProtoMaker);
REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test,
paddle::framework::IndicateVarDataTypeKernelTest<
paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test,
paddle::framework::IndicateVarDataTypeKernelTest<
paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test,
paddle::framework::IndicateVarDataTypeKernelTest<
paddle::platform::CPUDeviceContext, int>);
TEST(IndicateVarDataTypeTest, lodtensor) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_lod_tensor_data_type_test");
BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lodtensor_1");
var->GetMutable<paddle::framework::LoDTensor>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find(
"The Tensor in the indicate_lod_tensor_data_type_test Op's "
"Input Variable LoDTensor(lodtensor_1) is not initialized") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, selectedrows) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_selected_rows_data_type_test");
BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("selected_rows_1");
var->GetMutable<paddle::framework::SelectedRows>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test "
"Op's Input Variable SelectedRows(selected_rows_1) is not "
"initialized") != std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, other) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_other_data_type_test");
BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lod_tensor_array_1");
var->GetMutable<paddle::framework::LoDTensorArray>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of "
"indicate_other_data_type_test Op used to "
"determine kernel data type "
"is empty or not LoDTensor or SelectedRows") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}
......@@ -30,9 +30,9 @@ class Variable {
static_assert(
IsRegisteredVarType<T>(),
"Not registered type. Please register T inside var_type_traits.h");
PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
"The Variable type must be %s, but the type it holds is %s.",
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
return *static_cast<const T*>(holder_->Ptr());
......@@ -45,10 +45,10 @@ class Variable {
if (!holder_) {
holder_.reset(new PlaceholderImpl<T>());
} else {
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
PADDLE_ENFORCE(
holder_->Type() == VarTypeTrait<T>::kId,
"The Variable type must be %s, but the type it holds is %s.",
ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type()));
}
return static_cast<T*>(holder_->Ptr());
}
......@@ -61,7 +61,7 @@ class Variable {
void Clear() { holder_.reset(); }
int Type() const {
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
return holder_->Type();
}
......
......@@ -92,7 +92,7 @@ class MulOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
......
......@@ -31,7 +31,7 @@ class SaveOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册