提交 c52f57de 编写于 作者: J JiabinYang

test=develop, refine_error_message for data type

上级 fe8f28c9
...@@ -1073,7 +1073,9 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -1073,7 +1073,9 @@ Scope* OperatorWithKernel::PrepareData(
proto::VarType::Type OperatorWithKernel::IndicateDataType( proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
int data_type = -1; proto::VarType::Type dafault_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
for (auto& input : this->inputs_) { for (auto& input : this->inputs_) {
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first); const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
for (size_t i = 0; i < vars.size(); ++i) { for (size_t i = 0; i < vars.size(); ++i) {
...@@ -1090,18 +1092,19 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( ...@@ -1090,18 +1092,19 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
if (t != nullptr) { if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized", PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu)is not initialized",
input.first, i); input.first, i);
int tmp = static_cast<int>(t->type()); proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)", "DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
Type(), data_type, tmp); Type(), DataTypeToString(data_type), DataTypeToString(tmp));
data_type = tmp; data_type = tmp;
} }
} }
} }
} }
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input"); PADDLE_ENFORCE(data_type != dafault_data_type,
return static_cast<proto::VarType::Type>(data_type); "DataType should be indicated by input");
return data_type;
} }
OpKernelType OperatorWithKernel::GetExpectedKernelType( OpKernelType OperatorWithKernel::GetExpectedKernelType(
......
...@@ -25,7 +25,8 @@ inline const T* Tensor::data() const { ...@@ -25,7 +25,8 @@ inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType; std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_); PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d",
DataTypeToString(type_));
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册