未验证 提交 dfb3ae1b 编写于 作者: C Chen Weihang 提交者: GitHub

Polish some error message in framework holder - Part 1 (#25509)

* polish some error message in framework, test=develop

* fix unittest error, test=develop

* replace PADDLE_ENFORCE, test=develop

* polish details based review comment, test=develop
上级 1ab4101d
......@@ -63,7 +63,8 @@ class Array {
HOSTDEVICE inline const T &at(size_t i) const {
#ifndef __CUDA_ARCH__
PADDLE_ENFORCE_LT(i, N, "Array index out of bounds");
PADDLE_ENFORCE_LT(
i, N, platform::errors::OutOfRange("Array index out of bounds."));
#endif
return (*this)[i];
}
......@@ -106,7 +107,7 @@ class Array<T, 0> {
static T obj();
return obj;
#else
PADDLE_THROW("Array<T, 0> has no element");
PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif
}
......@@ -115,7 +116,7 @@ class Array<T, 0> {
static const T obj();
return obj;
#else
PADDLE_THROW("Array<T, 0> has no element");
PADDLE_THROW(platform::errors::Unavailable("Array<T, 0> has no element."));
#endif
}
......
......@@ -77,11 +77,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto var_name : fetch_var_names) {
auto var_desc = block.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var_desc, platform::errors::NotFound("%s is not found.", var_name));
var_desc, platform::errors::NotFound(
"Variable %s is not found in main program.", var_name));
auto shapes = var_desc->GetShape();
PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1,
"var %s: Fetched var has wrong shape, "
"only variables with the last dimension size 1 supported",
PADDLE_ENFORCE_EQ(shapes[shapes.size() - 1], 1,
platform::errors::InvalidArgument(
"Fetched variable %s has wrong shape, "
"only variables whose last dimension is 1 are supported",
var_name);
}
......@@ -95,7 +97,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
actual_thread_num_ = thread_num;
int file_cnt = filelist.size();
PADDLE_ENFORCE_GT(file_cnt, 0,
platform::errors::NotFound("Input file list is empty"));
platform::errors::NotFound("Input file list is empty."));
if (actual_thread_num_ > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
......
......@@ -72,7 +72,8 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
return val;
}
default:
PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
PADDLE_THROW(platform::errors::Unavailable("Unsupported attribute type %d.",
attr_desc.type()));
}
return boost::blank();
}
......
......@@ -37,9 +37,10 @@ struct ExtractAttribute {
try {
attr_value = &boost::get<T>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
attr_name_, paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name()));
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type %s, its type is %s.", attr_name_,
paddle::platform::demangle(typeid(T).name()),
paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
......@@ -70,8 +71,9 @@ struct ExtractAttribute<bool> {
try {
attr_value = &boost::get<bool>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type bool, its type is %s.", attr_name_,
paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
......@@ -96,8 +98,9 @@ struct ExtractAttribute<int64_t> {
try {
attr_value = &boost::get<int64_t>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type int64_t, its type is %s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
......@@ -124,8 +127,10 @@ struct ExtractAttribute<std::vector<int64_t>> {
try {
attr_value = &boost::get<std::vector<int64_t>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type int64_t, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type std::vector<int64_t>, its type is "
"%s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
......@@ -150,8 +155,9 @@ struct ExtractAttribute<float> {
try {
attr_value = &boost::get<float>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type float, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type float, its type is %s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
......@@ -173,8 +179,9 @@ class AttrReader {
template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
PADDLE_ENFORCE_NE(attrs_.count(name), 0,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
ExtractAttribute<T> extract_attr(name);
......@@ -192,8 +199,10 @@ class GreaterThanChecker {
public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(value, lower_bound_,
platform::errors::OutOfRange("larger_than check fails."));
PADDLE_ENFORCE_GT(
value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
}
private:
......@@ -205,7 +214,10 @@ class EqualGreaterThanChecker {
public:
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const {
PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
PADDLE_ENFORCE_GE(
value, lower_bound_,
      platform::errors::OutOfRange("Check for attribute value equal or "
"greater than a certain value failed."));
}
private:
......@@ -231,9 +243,10 @@ class EnumInContainer {
public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(const T& val) const {
PADDLE_ENFORCE(container_.find(val) != container_.end(),
"Value %s is not in enum container %s", val,
ContainerDebugString());
PADDLE_ENFORCE_NE(
container_.find(val), container_.end(),
platform::errors::NotFound("Value %s is not in enum container %s.", val,
ContainerDebugString()));
}
private:
......@@ -284,8 +297,11 @@ class TypedAttrChecker {
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE(default_value_setter_.empty(),
"%s can't have more than one default value!", attr_name_);
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
......@@ -308,8 +324,10 @@ class TypedAttrChecker {
auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(),
"Attribute '%s' is required!", attr_name_);
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
......
......@@ -23,7 +23,8 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
PADDLE_ENFORCE_NE(
in.place().which(), dst_place.which(),
"Currently, model parallelism is only supported between CPU and CUDA");
platform::errors::Unavailable("Currently, model parallelism is only "
"supported between CPU and CUDA."));
// NOTE(yy): TransDataDevice should wait for computation of input.
platform::DeviceContextPool::Instance().Get(in.place())->Wait();
......
......@@ -133,11 +133,14 @@ bool DataFeed::PickOneFile(std::string* filename) {
}
void DataFeed::CheckInit() {
PADDLE_ENFORCE(finish_init_, "Initialization did not succeed.");
PADDLE_ENFORCE_EQ(finish_init_, true, platform::errors::PreconditionNotMet(
"DataFeed initialization failed."));
}
void DataFeed::CheckSetFileList() {
PADDLE_ENFORCE(finish_set_filelist_, "Set filelist did not succeed.");
PADDLE_ENFORCE_EQ(
finish_set_filelist_, true,
platform::errors::PreconditionNotMet("DataFeed set filelist failed."));
}
void DataFeed::CheckStart() {
......@@ -160,14 +163,18 @@ void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) {
#ifdef PADDLE_WITH_CUDA
cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
#else
PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
PADDLE_THROW(platform::errors::Unimplemented(
"Not supported GPU, please compile with option WITH_GPU=ON."));
#endif
}
}
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
PADDLE_ENFORCE_GT(
queue_size, 0,
platform::errors::InvalidArgument(
"Queue size %d is illegal in PrivateQueueDataFeed.", queue_size));
queue_size_ = queue_size;
queue_ = paddle::framework::MakeChannel<T>();
queue_->SetCapacity(queue_size);
......@@ -418,8 +425,10 @@ void MultiSlotDataFeed::Init(
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
PADDLE_ENFORCE_EQ(
data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in MultiSlotDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
......@@ -668,13 +677,14 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
PADDLE_ENFORCE_NE(
num, 0,
platform::errors::InvalidArgument(
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s.",
str));
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
......@@ -765,8 +775,10 @@ void MultiSlotInMemoryDataFeed::Init(
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
PADDLE_ENFORCE_EQ(
data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in MultiSlotInMemoryDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
......@@ -898,13 +910,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
PADDLE_ENFORCE_NE(
num, 0,
platform::errors::InvalidArgument(
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s.",
str));
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
for (int j = 0; j < num; ++j) {
......@@ -963,13 +976,14 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
PADDLE_ENFORCE_NE(
num, 0,
platform::errors::InvalidArgument(
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s.",
str));
if (idx != -1) {
if (all_slots_type_[i][0] == 'f') { // float
......@@ -1085,7 +1099,7 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
PADDLE_ENFORCE_EQ(slot_offset.size(), 2,
platform::errors::InvalidArgument(
"In batch reader, the sparse tensor lod size "
"must be 2, but received %d",
"must be 2, but received %d.",
slot_offset.size()));
const auto& max_size = slot_offset[1];
tmp_offset.reserve(max_size + 1);
......@@ -1137,10 +1151,13 @@ void PrivateInstantDataFeed<T>::PutToFeedVec() {
for (const auto e : use_slots_shape_[i]) {
total_dims *= e;
}
PADDLE_ENFORCE(
total_dims == total_instance,
"The actual data size of slot[%s] doesn't match its declaration",
use_slots_[i].c_str());
PADDLE_ENFORCE_EQ(
total_dims, total_instance,
platform::errors::InvalidArgument(
"The actual data size of slot[%s] doesn't match its declaration. "
"The actual data size of slot is %lld"
", and its declaration is %lld.",
use_slots_[i].c_str(), total_dims, total_instance));
feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
}
}
......@@ -1162,7 +1179,9 @@ int PrivateInstantDataFeed<T>::Next() {
return -1;
}
PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data");
PADDLE_ENFORCE_EQ(
true, ParseOneMiniBatch(),
platform::errors::InvalidArgument("Fail to parse mini-batch data."));
PutToFeedVec();
return ins_vec_[0].GetBatchSize();
}
......@@ -1173,8 +1192,10 @@ void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
PADDLE_ENFORCE_EQ(
data_feed_desc.has_multi_slot_desc(), true,
platform::errors::PreconditionNotMet(
"Multi_slot_desc has not been set in PrivateInstantDataFeed."));
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
......@@ -1217,7 +1238,10 @@ template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());
PADDLE_ENFORCE_NE(
fd_, -1, platform::errors::Unavailable(
"Fail to open file: %s in MultiSlotFileInstantDataFeed.",
filename.c_str()));
struct stat sb;
fstat(fd_, &sb);
......@@ -1225,7 +1249,11 @@ bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno));
PADDLE_ENFORCE_NE(
buffer_, MAP_FAILED,
platform::errors::Unavailable(
"Memory map failed when create shared memory, error number is %s.",
strerror(errno)));
offset_ = 0;
return true;
......@@ -1257,12 +1285,13 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.");
PADDLE_ENFORCE_NE(
num, 0,
platform::errors::InvalidArgument(
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters."));
offset_ += sizeof(uint16_t);
if (idx != -1) {
......@@ -1304,7 +1333,12 @@ bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
}
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_");
platform::errors::InvalidArgument(
"The batch size is not equal to default batch size, or "
"the offset is not equal to end index."
"The batch size is %d, default batch size is %d, offset "
"is %d, end index is %d.",
batch_size_, default_batch_size_, offset_, end_));
return true;
}
#endif
......
......@@ -116,7 +116,8 @@ class DataFeed {
virtual ~DataFeed() {}
virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
virtual bool CheckFile(const char* filename) {
PADDLE_THROW("This function(CheckFile) is not implemented.");
PADDLE_THROW(platform::errors::Unimplemented(
"This function(CheckFile) is not implemented."));
}
// Set filelist for DataFeed.
// Pay attention that it must init all readers before call this function.
......@@ -179,7 +180,8 @@ class DataFeed {
}
virtual int GetCurBatchSize() { return batch_size_; }
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
PADDLE_THROW(platform::errors::Unimplemented(
"This function(LoadIntoMemory) is not implemented."));
}
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
......@@ -438,14 +440,23 @@ class MultiSlotType {
private:
void CheckType(const std::string& type) const {
PADDLE_ENFORCE((type == "uint64") || (type == "float"),
"There is no this type<%s>.", type);
PADDLE_ENFORCE_EQ((type == "uint64" || type == "float"), true,
platform::errors::InvalidArgument(
"MultiSlotType error, expect type is uint64 or "
"float, but received type is %s.",
type));
}
void CheckFloat() const {
PADDLE_ENFORCE(type_[0] == 'f', "Add %s value to float slot.", type_);
PADDLE_ENFORCE_EQ(
type_[0], 'f',
platform::errors::InvalidArgument(
"MultiSlotType error, add %s value to float slot.", type_));
}
void CheckUint64() const {
PADDLE_ENFORCE(type_[0] == 'u', "Add %s value to uint64 slot.", type_);
PADDLE_ENFORCE_EQ(
type_[0], 'u',
platform::errors::InvalidArgument(
"MultiSlotType error, add %s value to uint64 slot.", type_));
}
std::vector<float> float_feasign_;
std::vector<uint64_t> uint64_feasign_;
......
......@@ -34,8 +34,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file(
const char* filename) {
paddle::framework::DataFeedDesc data_feed_desc;
int file_descriptor = open(filename, O_RDONLY);
PADDLE_ENFORCE_NE(file_descriptor, -1, platform::errors::Unavailable(
"Cannot open file %s.", filename));
PADDLE_ENFORCE_NE(
file_descriptor, -1,
platform::errors::Unavailable(
"Cannot open file %s c load datafeed param from file.", filename));
google::protobuf::io::FileInputStream fileInput(file_descriptor);
google::protobuf::TextFormat::Parse(&fileInput, &data_feed_desc);
close(file_descriptor);
......@@ -45,8 +47,10 @@ paddle::framework::DataFeedDesc load_datafeed_param_from_file(
const std::vector<std::string> load_filelist_from_file(const char* filename) {
std::vector<std::string> filelist;
std::ifstream fin(filename);
PADDLE_ENFORCE_EQ(fin.good(), true, platform::errors::Unavailable(
"Cannot open file %s.", filename));
PADDLE_ENFORCE_EQ(
fin.good(), true,
platform::errors::Unavailable(
"Cannot open file %s when load filelist from file.", filename));
std::string line;
while (getline(fin, line)) {
filelist.push_back(line);
......@@ -196,7 +200,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
}
}
} else {
PADDLE_THROW("Error type in proto file.");
PADDLE_THROW(platform::errors::InvalidArgument(
"Error type in proto file."));
}
} else { // sparse branch
if (slot.type() == "uint64") {
......@@ -218,7 +223,8 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
}
}
} else {
PADDLE_THROW("Error type in proto file.");
PADDLE_THROW(platform::errors::InvalidArgument(
"Error type in proto file."));
}
} // end sparse branch
++index;
......@@ -272,7 +278,10 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set,
file_elem_set->resize(used_slot_num);
for (const auto& file : filelist) {
std::ifstream fin(file.c_str());
PADDLE_ENFORCE(fin.good(), "Can not open %s.", file.c_str());
PADDLE_ENFORCE_EQ(
fin.good(), true,
platform::errors::Unavailable(
"Can not open %s when get element set from file.", file.c_str()));
while (1) {
bool end_flag = false;
int index = 0;
......@@ -298,7 +307,8 @@ void GetElemSetFromFile(std::vector<MultiTypeSet>* file_elem_set,
}
}
} else {
PADDLE_THROW("Error type in proto file.");
PADDLE_THROW(
platform::errors::InvalidArgument("Error type in proto file."));
}
if (slot.is_used()) {
++index;
......
......@@ -45,7 +45,8 @@ inline DataLayout StringToDataLayout(const std::string& str) {
} else if (s == "MKLDNNLAYOUT") {
return DataLayout::kMKLDNN;
} else {
PADDLE_THROW("Unknown storage order string: %s", s);
PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown data layout type string: %s.", s));
}
}
......@@ -60,7 +61,8 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) {
case DataLayout::kMKLDNN:
return "MKLDNNLAYOUT";
default:
PADDLE_THROW("unknown DataLayout %d", data_layout);
PADDLE_THROW(platform::errors::InvalidArgument(
"Unknown Data Layout type %d.", data_layout));
}
}
......
......@@ -25,14 +25,17 @@ namespace paddle {
namespace framework {
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to) {
PADDLE_ENFORCE_NE(from, to,
"layout transform should transform different layout");
PADDLE_ENFORCE_NE(
from, to,
platform::errors::InvalidArgument(
"Layout transform should transform between different layout."));
if (from == DataLayout::kNCHW && to == DataLayout::kNHWC) {
return {0, 2, 3, 1};
} else if (from == DataLayout::kNHWC && to == DataLayout::kNCHW) {
return {0, 3, 1, 2};
} else {
PADDLE_THROW("unsupported transform");
PADDLE_THROW(
platform::errors::InvalidArgument("Unsupported layout transform."));
}
}
......@@ -55,7 +58,8 @@ struct CastDataLayout {
auto* context = static_cast<const platform::CPUDeviceContext*>(ctx_);
trans4(*context, in_, out_, axis_);
} else {
PADDLE_THROW("Unsupport CPU <-> GPU!");
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Unsupported data layout cast from CPU to GPU."));
}
}
};
......@@ -66,9 +70,14 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
PADDLE_ENFORCE(
platform::places_are_same_class(kernel_type_for_var.place_,
expected_kernel_type.place_),
"TransDataLayout only support DataLayout transform on same place!");
platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place."));
PADDLE_ENFORCE(arity(in.dims()) == 4, "Input Arity only support 4!");
PADDLE_ENFORCE_EQ(
arity(in.dims()), 4,
platform::errors::InvalidArgument(
"Input dimension arity only can be 4, the input dimension is %s.",
in.dims()));
auto& pool = platform::DeviceContextPool::Instance();
......@@ -108,7 +117,8 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>());
default:
PADDLE_THROW("wrong mkldnn type provided");
PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));
}
}
......@@ -121,8 +131,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
PADDLE_ENFORCE(
in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN");
platform::errors::InvalidArgument(
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"));
innerTransDataLayoutFromMKLDNN(
in_layout,
......@@ -155,7 +166,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE_NE(in_type, memory::data_type::undef,
"Input tensor type is not supported: %s", in.type());
platform::errors::InvalidArgument(
"Input tensor type (%s) is not supported.",
DataTypeToString(in.type())));
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
auto out_format =
......
......@@ -38,8 +38,9 @@ inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
case DataLayout::kNCHW:
return MKLDNNMemoryFormat::nchw;
default:
PADDLE_THROW("Fail to convert layout %s to MKLDNN format",
DataLayoutToString(layout));
PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert layout %s to MKLDNN format.",
DataLayoutToString(layout)));
}
}
......@@ -50,7 +51,8 @@ inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
case MKLDNNMemoryFormat::nchw:
return DataLayout::kNCHW;
default:
PADDLE_THROW("Fail to convert MKLDNN format to paddle layout");
PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert MKLDNN format to paddle layout."));
}
}
......
......@@ -117,7 +117,7 @@ TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string msg = "larger_than check fail";
std::string msg = "OutOfRangeError";
std::string err_msg = err.what();
ASSERT_TRUE(err_msg.find(msg) != std::string::npos);
}
......@@ -151,7 +151,7 @@ TEST(OpRegistry, CustomChecker) {
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet& err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
std::string msg = "InvalidArgumentError";
std::string err_msg = err.what();
ASSERT_TRUE(err_msg.find(msg) != std::string::npos);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册