提交 6b474d08 编写于 作者: qq_22305325's avatar qq_22305325 提交者: GitHub

Dev switch error proto with cfg error proto (#3858)

* replace ErrorProto with cfg::ErrorProto

* fix macro name

* optimize error

* optimize error

* fix code style

* Organize the code

* fix code style

* fix code style

* copy cfg head file
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Former-commit-id: dc8eb372
上级 fb95f451
......@@ -25,8 +25,11 @@ execute_process(
include_directories(${CFG_INCLUDE_DIR})
list(APPEND ONEFLOW_INCLUDE_SRC_DIRS ${CFG_INCLUDE_DIR})
function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR)
list(APPEND ALL_CFG_CONVERT_PROTO
oneflow/core/common/error.proto
oneflow/core/vm/instruction.proto
oneflow/core/eager/eager_symbol.proto
oneflow/core/job/job_conf.proto
......
......@@ -380,6 +380,7 @@ foreach(of_include_src_dir ${ONEFLOW_INCLUDE_SRC_DIRS})
endforeach()
copy_files("${PROTO_HDRS}" "${PROJECT_BINARY_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy)
copy_files("${CFG_HRCS}" "${PROJECT_BINARY_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy)
set(OF_CORE_HDRS)
list(APPEND of_core_dir_name_list "common" "device" "framework" "kernel/util" "persistence")
......
......@@ -35,151 +35,151 @@ Error&& Error::AddStackFrame(const std::string& location, const std::string& fun
return std::move(*this);
}
Error::operator std::string() const { return PbMessage2TxtString(*error_proto_); }
Error::operator std::string() const { return error_proto_->DebugString(); }
Error Error::Ok() { return std::make_shared<ErrorProto>(); }
Error Error::Ok() { return std::make_shared<cfg::ErrorProto>(); }
Error Error::ProtoParseFailedError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_proto_parse_failed_error();
return error;
}
Error Error::JobSetEmptyError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_set_empty_error();
return error;
}
Error Error::DeviceTagNotFoundError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_device_tag_not_found_error();
return error;
}
Error Error::JobNameExistError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_name_exist_error();
return error;
}
Error Error::JobNameEmptyError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_name_empty_error();
return error;
}
Error Error::JobNameNotEqualError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_name_not_equal_error();
return error;
}
Error Error::NoJobBuildAndInferCtxError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_no_job_build_and_infer_ctx_error();
return error;
}
Error Error::JobConfFrozenError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_conf_frozen_error();
return error;
}
Error Error::JobConfNotSetError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_conf_not_set_error();
return error;
}
Error Error::JobConfRepeatedSetError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_conf_repeated_set_error();
return error;
}
Error Error::JobTypeNotSetError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_type_not_set_error();
return error;
}
Error Error::LogicalBlobNameNotExistError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_logical_blob_name_not_exist_error();
return error;
}
Error Error::LogicalBlobNameExistError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_logical_blob_name_exist_error();
return error;
}
Error Error::LogicalBlobNameInvalidError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_logical_blob_name_invalid_error();
return error;
}
Error Error::OpNameExistError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_op_name_exist_error();
return error;
}
Error Error::OpConfDeviceTagNoSetError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_op_conf_device_tag_no_set_error();
return error;
}
Error Error::PlacementError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_placement_error();
return error;
}
Error Error::BlobSplitAxisInferError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_blob_split_axis_infer_error();
return error;
}
Error Error::UnknownJobBuildAndInferError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_unknown_job_build_and_infer_error();
return error;
}
Error Error::CheckFailedError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_check_failed_error();
return error;
}
Error Error::Todo() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_todo_error();
return error;
}
Error Error::Unimplemented() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_unimplemented_error();
return error;
}
Error Error::BoxingNotSupportedError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_boxing_not_supported_error();
return error;
}
Error Error::OpKernelNotFoundError(const std::string& error_summary,
const std::vector<std::string>& error_msgs) {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->set_error_summary(error_summary);
auto* op_kernel_not_found_error = error->mutable_op_kernel_not_found_error();
for (const auto& msg : error_msgs) {
......@@ -190,7 +190,7 @@ Error Error::OpKernelNotFoundError(const std::string& error_summary,
Error Error::MultipleOpKernelsMatchedError(const std::string& error_summary,
const std::vector<std::string>& error_msgs) {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->set_error_summary(error_summary);
auto* multiple_op_kernels_matched_error = error->mutable_multiple_op_kernels_matched_error();
for (const auto& msg : error_msgs) {
......@@ -201,7 +201,7 @@ Error Error::MultipleOpKernelsMatchedError(const std::string& error_summary,
Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,
uint64_t available, const std::string& device_tag) {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
auto* memory_zone_out_of_memory_error = error->mutable_memory_zone_out_of_memory_error();
memory_zone_out_of_memory_error->add_machine_id(std::to_string(machine_id));
memory_zone_out_of_memory_error->add_mem_zone_id(std::to_string(mem_zone_id));
......@@ -212,20 +212,20 @@ Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id,
}
Error Error::LossBlobNotFoundError(const std::string& error_summary) {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_loss_blob_not_found_error();
error->set_error_summary(error_summary);
return error;
}
Error Error::RwMutexedObjectNotFoundError() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_rw_mutexed_object_not_found_error();
return error;
}
Error Error::GradientFunctionNotFound() {
auto error = std::make_shared<ErrorProto>();
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_gradient_function_not_found_error();
return error;
}
......
......@@ -18,13 +18,13 @@ limitations under the License.
#include <sstream>
#include <vector>
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/error.cfg.h"
namespace oneflow {
class Error final {
public:
Error(const std::shared_ptr<ErrorProto>& error_proto) : error_proto_(error_proto) {}
Error(const std::shared_ptr<cfg::ErrorProto>& error_proto) : error_proto_(error_proto) {}
Error(const Error&) = default;
~Error() = default;
......@@ -69,14 +69,14 @@ class Error final {
// gradient
static Error GradientFunctionNotFound();
std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; }
const ErrorProto* operator->() const { return error_proto_.get(); }
ErrorProto* operator->() { return error_proto_.get(); }
std::shared_ptr<cfg::ErrorProto> error_proto() const { return error_proto_; }
const cfg::ErrorProto* operator->() const { return error_proto_.get(); }
cfg::ErrorProto* operator->() { return error_proto_.get(); }
operator std::string() const;
void Assign(const Error& other) { error_proto_ = other.error_proto_; }
private:
std::shared_ptr<ErrorProto> error_proto_;
std::shared_ptr<cfg::ErrorProto> error_proto_;
};
// r-value reference is used to supporting expressions like `Error() << "invalid value"`
......
......@@ -36,7 +36,7 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || std::is
Maybe(const T& data) : data_or_error_(std::make_shared<T>(data)) {}
Maybe(const Error& error) : data_or_error_(error.error_proto()) {}
Maybe(const std::shared_ptr<T>& data) : data_or_error_(data) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : data_or_error_(error) {}
Maybe(const std::shared_ptr<cfg::ErrorProto>& error) : data_or_error_(error) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
......@@ -45,35 +45,43 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || std::is
std::shared_ptr<T> Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return data_or_error_.template Get<T>();
}
std::shared_ptr<ErrorProto> error() const { return data_or_error_.template Get<ErrorProto>(); }
std::string GetSerializedError() const {
std::string str;
google::protobuf::TextFormat::PrintToString(*error(), &str);
return str;
std::shared_ptr<cfg::ErrorProto> error() const {
return data_or_error_.template Get<cfg::ErrorProto>();
}
std::string GetSerializedError() const { return this->error()->DebugString(); }
template<typename Type = T>
Type GetDataAndSerializedErrorProto(std::string* error_str, const Type& default_for_error) const {
static_assert(std::is_same<T, Type>::value, "error type for argument 1");
if (IsOk()) {
google::protobuf::TextFormat::PrintToString(ErrorProto(), error_str);
*error_str = cfg::ErrorProto().DebugString();
return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} else {
google::protobuf::TextFormat::PrintToString(*error(), error_str);
*error_str = this->error()->DebugString();
return default_for_error;
}
}
template<typename Type = T>
std::pair<Type, std::shared_ptr<cfg::ErrorProto>> GetDataAndErrorProto(
const Type& default_for_error) const {
if (IsOk()) {
return std::make_pair(*Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::make_shared<cfg::ErrorProto>(cfg::ErrorProto()));
} else {
return std::make_pair(default_for_error, error());
}
}
private:
EitherPtr<T, ErrorProto> data_or_error_;
EitherPtr<T, cfg::ErrorProto> data_or_error_;
};
template<typename T>
class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> final {
public:
Maybe(const Error& error) : error_or_plain_(error.error_proto()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : error_or_plain_(error) { CheckError(); }
Maybe(const std::shared_ptr<cfg::ErrorProto>& error) : error_or_plain_(error) { CheckError(); }
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
......@@ -82,28 +90,36 @@ class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> fina
bool IsOk() const { return error_or_plain_.IsPlain(); }
void Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {}
std::shared_ptr<ErrorProto> error() const { return error_or_plain_.shared_ptr(); }
std::shared_ptr<cfg::ErrorProto> error() const { return error_or_plain_.shared_ptr(); }
std::string GetSerializedError() const {
CHECK(!IsOk());
std::string str;
google::protobuf::TextFormat::PrintToString(*error(), &str);
return str;
return this->error()->DebugString();
}
void GetDataAndSerializedErrorProto(std::string* error_str) const {
if (IsOk()) {
google::protobuf::TextFormat::PrintToString(ErrorProto(), error_str);
*error_str = cfg::ErrorProto().DebugString();
} else {
*error_str = this->error()->DebugString();
}
}
std::shared_ptr<cfg::ErrorProto> GetDataAndErrorProto() const {
if (IsOk()) {
return std::make_shared<cfg::ErrorProto>(cfg::ErrorProto());
} else {
google::protobuf::TextFormat::PrintToString(*error(), error_str);
return error();
}
}
private:
Maybe() : error_or_plain_(nullptr) {}
void CheckError() const { CHECK_NE(error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); }
void CheckError() const {
CHECK_NE(this->error()->error_type_case(), cfg::ErrorProto::ERROR_TYPE_NOT_SET);
}
SharedOrPlain<ErrorProto, void*> error_or_plain_;
SharedOrPlain<cfg::ErrorProto, void*> error_or_plain_;
};
template<typename T>
......@@ -111,7 +127,7 @@ class Maybe<T, typename std::enable_if<std::is_scalar<T>::value>::type> final {
public:
Maybe(T data) : error_or_plain_(data) {}
Maybe(const Error& error) : error_or_plain_(error.error_proto()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : error_or_plain_(error) { CheckError(); }
Maybe(const std::shared_ptr<cfg::ErrorProto>& error) : error_or_plain_(error) { CheckError(); }
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
......@@ -120,29 +136,39 @@ class Maybe<T, typename std::enable_if<std::is_scalar<T>::value>::type> final {
T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return error_or_plain_.plain_data();
}
std::shared_ptr<ErrorProto> error() const { return error_or_plain_.shared_ptr(); }
std::shared_ptr<cfg::ErrorProto> error() const { return error_or_plain_.shared_ptr(); }
std::string GetSerializedError() const {
CHECK(!IsOk());
std::string str;
google::protobuf::TextFormat::PrintToString(*error(), &str);
return str;
return this->error()->DebugString();
}
T GetDataAndSerializedErrorProto(std::string* error_str, const T& default_for_error) const {
if (IsOk()) {
google::protobuf::TextFormat::PrintToString(ErrorProto(), error_str);
*error_str = cfg::ErrorProto().DebugString();
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} else {
google::protobuf::TextFormat::PrintToString(*error(), error_str);
*error_str = this->error()->DebugString();
return default_for_error;
}
}
std::pair<T, std::shared_ptr<cfg::ErrorProto>> GetDataAndErrorProto(
const T& default_for_error) const {
if (IsOk()) {
return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::make_shared<cfg::ErrorProto>(cfg::ErrorProto()));
} else {
return std::make_pair(default_for_error, error());
}
}
private:
void CheckError() const { CHECK_NE(error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); }
void CheckError() const {
CHECK_NE(this->error()->error_type_case(), cfg::ErrorProto::ERROR_TYPE_NOT_SET);
}
SharedOrPlain<ErrorProto, T> error_or_plain_;
SharedOrPlain<cfg::ErrorProto, T> error_or_plain_;
};
template<typename T>
......@@ -155,7 +181,7 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || std::is
public:
Maybe(T data) : maybe_ptr_(&data) {}
Maybe(const Error& error) : maybe_ptr_(error) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : maybe_ptr_(error) {}
Maybe(const std::shared_ptr<cfg::ErrorProto>& error) : maybe_ptr_(error) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
......@@ -164,7 +190,7 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || std::is
T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return *maybe_ptr_.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
}
std::shared_ptr<ErrorProto> error() const { return maybe_ptr_.error(); }
std::shared_ptr<cfg::ErrorProto> error() const { return maybe_ptr_.error(); }
std::string GetSerializedError() const { return maybe_ptr_.GetSerializedError(); }
......@@ -172,6 +198,10 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || std::is
return *maybe_ptr_.GetDataAndSerializedErrorProto(error_str, static_cast<PtrT>(nullptr));
}
std::pair<T, std::shared_ptr<cfg::ErrorProto>> GetDataAndErrorProto() const {
return *maybe_ptr_.GetDataAndErrorProto(static_cast<PtrT>(nullptr));
}
private:
Maybe<PtrT> maybe_ptr_;
};
......
......@@ -97,7 +97,7 @@ bool SubTskGphBuilderUtil::BlobHasDynamicShape(const BlobDesc& blob_desc) {
return blob_desc.is_dynamic();
}
bool SubTskGphBuilderUtil::IsErrorBoxingNotSupported(const ErrorProto& error) {
bool SubTskGphBuilderUtil::IsErrorBoxingNotSupported(const cfg::ErrorProto& error) {
return error.has_boxing_not_supported_error();
}
......
......@@ -44,7 +44,7 @@ struct SubTskGphBuilderUtil {
static bool IsBoxingB2B(const SbpParallel& src, const SbpParallel& dst);
static bool IsBoxingB2S(const SbpParallel& src, const SbpParallel& dst);
static bool BlobHasDynamicShape(const BlobDesc& blob_desc);
static bool IsErrorBoxingNotSupported(const ErrorProto& error);
static bool IsErrorBoxingNotSupported(const cfg::ErrorProto& error);
static int64_t GetDistance(const TaskNode* src, const TaskNode* dst);
template<typename NodeType>
static int64_t FindNearestNodeIndex(const std::vector<NodeType*> from_nodes,
......
......@@ -72,7 +72,7 @@ REGISTER_USER_OP("batch_gather")
.PartialSum(user_op::OpArg("out", 0))
.Build();
} else {
std::shared_ptr<ErrorProto> err;
std::shared_ptr<cfg::ErrorProto> err;
err->set_msg("BatchGatherOp: indices_num_axes equals " + std::to_string(indices_num_axes)
+ " (should be bigger than 1).");
err->mutable_check_failed_error();
......
#ifndef CFG_PYBIND_REGISTRY_H_
#define CFG_PYBIND_REGISTRY_H_
#ifndef ONEFLOW_CFG_PYBIND_REGISTRY_H_
#define ONEFLOW_CFG_PYBIND_REGISTRY_H_
#include <pybind11/pybind11.h>
#include <map>
#include <vector>
......@@ -39,4 +39,4 @@ class Pybind11ModuleRegistry {
} \
static void OneflowCfgPythonModule##__LINE__(pybind11::module& m)
#endif // CFG_PYBIND_REGISTRY_H_
#endif // ONEFLOW_CFG_PYBIND_REGISTRY_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册