未验证 提交 10bc9ffc 编写于 作者: J Jiabin Yang 提交者: GitHub

Merge pull request #15518 from JiabinYang/fix/refine_error_message

test=develop, refine_error_message for data type
......@@ -1073,7 +1073,9 @@ Scope* OperatorWithKernel::PrepareData(
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
int data_type = -1;
proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
for (auto& input : this->inputs_) {
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
for (size_t i = 0; i < vars.size(); ++i) {
......@@ -1090,18 +1092,19 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized",
input.first, i);
int tmp = static_cast<int>(t->type());
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == -1,
tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
Type(), data_type, tmp);
Type(), DataTypeToString(data_type), DataTypeToString(tmp));
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<proto::VarType::Type>(data_type);
PADDLE_ENFORCE(data_type != dafault_data_type,
"DataType should be indicated by input");
return data_type;
}
OpKernelType OperatorWithKernel::GetExpectedKernelType(
......
......@@ -25,7 +25,8 @@ inline const T* Tensor::data() const {
check_memory_size();
bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d",
DataTypeToString(type_));
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册