未验证 提交 06be620f 编写于 作者: L Li Xinqi 提交者: GitHub

rename UserOpAttrVal to AttrValue (#3752)

Co-authored-by: Noneflow-bot <69100618+oneflow-bot@users.noreply.github.com>
上级 17fc5509
......@@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_FRAMEWORK_USER_OP_ATTR_H_
#define ONEFLOW_CORE_FRAMEWORK_USER_OP_ATTR_H_
#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_
#define ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_
#include "oneflow/core/framework/user_op_attr.pb.h"
#include "oneflow/core/common/util.h"
......@@ -26,31 +26,31 @@ namespace oneflow {
namespace user_op {
// SEQ
#define BASIC_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_int32, int32_t, UserOpAttrType::kAtInt32) \
OF_PP_MAKE_TUPLE_SEQ(at_int64, int64_t, UserOpAttrType::kAtInt64) \
OF_PP_MAKE_TUPLE_SEQ(at_bool, bool, UserOpAttrType::kAtBool) \
OF_PP_MAKE_TUPLE_SEQ(at_float, float, UserOpAttrType::kAtFloat) \
OF_PP_MAKE_TUPLE_SEQ(at_double, double, UserOpAttrType::kAtDouble) \
OF_PP_MAKE_TUPLE_SEQ(at_string, std::string, UserOpAttrType::kAtString)
#define BASIC_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_int32, int32_t, AttrType::kAtInt32) \
OF_PP_MAKE_TUPLE_SEQ(at_int64, int64_t, AttrType::kAtInt64) \
OF_PP_MAKE_TUPLE_SEQ(at_bool, bool, AttrType::kAtBool) \
OF_PP_MAKE_TUPLE_SEQ(at_float, float, AttrType::kAtFloat) \
OF_PP_MAKE_TUPLE_SEQ(at_double, double, AttrType::kAtDouble) \
OF_PP_MAKE_TUPLE_SEQ(at_string, std::string, AttrType::kAtString)
#define ENUM_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_data_type, DataType, UserOpAttrType::kAtDataType)
#define ENUM_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_data_type, DataType, AttrType::kAtDataType)
#define MESSAGE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_shape, Shape, UserOpAttrType::kAtShape)
#define MESSAGE_ATTR_SEQ OF_PP_MAKE_TUPLE_SEQ(at_shape, Shape, AttrType::kAtShape)
#define LIST_BASIC_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_int32, std::vector<int32_t>, UserOpAttrType::kAtListInt32) \
OF_PP_MAKE_TUPLE_SEQ(at_list_int64, std::vector<int64_t>, UserOpAttrType::kAtListInt64) \
OF_PP_MAKE_TUPLE_SEQ(at_list_float, std::vector<float>, UserOpAttrType::kAtListFloat)
#define LIST_BASIC_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_int32, std::vector<int32_t>, AttrType::kAtListInt32) \
OF_PP_MAKE_TUPLE_SEQ(at_list_int64, std::vector<int64_t>, AttrType::kAtListInt64) \
OF_PP_MAKE_TUPLE_SEQ(at_list_float, std::vector<float>, AttrType::kAtListFloat)
#define LIST_ENUM_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_data_type, std::vector<DataType>, UserOpAttrType::kAtListDataType)
OF_PP_MAKE_TUPLE_SEQ(at_list_data_type, std::vector<DataType>, AttrType::kAtListDataType)
#define LIST_MESSAGE_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_shape, std::vector<Shape>, UserOpAttrType::kAtListShape)
OF_PP_MAKE_TUPLE_SEQ(at_list_shape, std::vector<Shape>, AttrType::kAtListShape)
#define LIST_STRING_ATTR_SEQ \
OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector<std::string>, UserOpAttrType::kAtListString)
OF_PP_MAKE_TUPLE_SEQ(at_list_string, std::vector<std::string>, AttrType::kAtListString)
#define ATTR_SEQ \
BASIC_ATTR_SEQ \
......@@ -66,15 +66,15 @@ namespace user_op {
template<typename T>
struct GetAttrType;
template<UserOpAttrType AttrT>
template<AttrType AttrT>
struct GetCppType;
#define SPECIALIZE_GET_ATTR_TYPE(field, type_cpp, type_proto) \
template<> \
struct GetAttrType<type_cpp> : std::integral_constant<UserOpAttrType, type_proto> {}; \
template<> \
struct GetCppType<type_proto> { \
typedef type_cpp type; \
#define SPECIALIZE_GET_ATTR_TYPE(field, type_cpp, type_proto) \
template<> \
struct GetAttrType<type_cpp> : std::integral_constant<AttrType, type_proto> {}; \
template<> \
struct GetCppType<type_proto> { \
typedef type_cpp type; \
};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_ATTR_TYPE, ATTR_SEQ);
#undef SPECIALIZE_GET_ATTR_TYPE
......@@ -83,4 +83,4 @@ OF_PP_FOR_EACH_TUPLE(SPECIALIZE_GET_ATTR_TYPE, ATTR_SEQ);
} // namespace oneflow
#endif // ONEFLOW_CORE_FRAMEWORK_USER_OP_ATTR_H_
#endif // ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_
......@@ -23,15 +23,15 @@ namespace oneflow {
namespace user_op {
// Basic and Enum Attr
#define BASIC_AND_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \
template<> \
cpp_type AttrValAccessor<cpp_type>::Attr(const UserOpAttrVal& val) { \
CHECK(val.has_##field()); \
return val.field(); \
} \
template<> \
void AttrValAccessor<cpp_type>::Attr(const cpp_type& cpp_val, UserOpAttrVal* attr_val) { \
attr_val->set_##field(cpp_val); \
#define BASIC_AND_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \
template<> \
cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) { \
CHECK(val.has_##field()); \
return val.field(); \
} \
template<> \
void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \
attr_val->set_##field(cpp_val); \
}
#define BASIC_AND_ENUM_ATTR_SEQ \
......@@ -45,22 +45,22 @@ OF_PP_FOR_EACH_TUPLE(BASIC_AND_ENUM_ATTR_SEQ_ENTRY, BASIC_AND_ENUM_ATTR_SEQ)
// Customized Message Attr
template<>
Shape AttrValAccessor<Shape>::Attr(const UserOpAttrVal& val) {
Shape AttrValueAccessor<Shape>::Attr(const AttrValue& val) {
return Shape(val.at_shape());
}
template<>
void AttrValAccessor<Shape>::Attr(const Shape& cpp_val, UserOpAttrVal* attr_val) {
void AttrValueAccessor<Shape>::Attr(const Shape& cpp_val, AttrValue* attr_val) {
cpp_val.ToProto(attr_val->mutable_at_shape());
}
// List of Basic Attr
#define LIST_BASIC_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \
template<> \
cpp_type AttrValAccessor<cpp_type>::Attr(const UserOpAttrVal& val) { \
cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) { \
return PbRf2StdVec<cpp_type::value_type>(val.field().val()); \
} \
template<> \
void AttrValAccessor<cpp_type>::Attr(const cpp_type& cpp_val, UserOpAttrVal* attr_val) { \
void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \
*(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf<cpp_type::value_type>(cpp_val); \
}
......@@ -71,7 +71,7 @@ OF_PP_FOR_EACH_TUPLE(LIST_BASIC_ATTR_SEQ_ENTRY, LIST_BASIC_ATTR_SEQ)
// List of Enum Attr
#define LIST_ENUM_ATTR_SEQ_ENTRY(field, cpp_type, attr_type) \
template<> \
cpp_type AttrValAccessor<cpp_type>::Attr(const UserOpAttrVal& val) { \
cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) { \
std::vector<cpp_type::value_type> ret; \
ret.reserve(val.field().val_size()); \
for (const auto& value : val.field().val()) { \
......@@ -80,7 +80,7 @@ OF_PP_FOR_EACH_TUPLE(LIST_BASIC_ATTR_SEQ_ENTRY, LIST_BASIC_ATTR_SEQ)
return ret; \
} \
template<> \
void AttrValAccessor<cpp_type>::Attr(const cpp_type& cpp_val, UserOpAttrVal* attr_val) { \
void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \
using proto_type = std::remove_reference_t<decltype(attr_val->field().val())>::value_type; \
std::vector<proto_type> vec; \
vec.reserve(cpp_val.size()); \
......@@ -94,15 +94,15 @@ OF_PP_FOR_EACH_TUPLE(LIST_ENUM_ATTR_SEQ_ENTRY, LIST_ENUM_ATTR_SEQ)
// List of Customized Message Attr
template<>
std::vector<Shape> AttrValAccessor<std::vector<Shape>>::Attr(const UserOpAttrVal& val) {
std::vector<Shape> AttrValueAccessor<std::vector<Shape>>::Attr(const AttrValue& val) {
std::vector<Shape> ret;
ret.reserve(val.at_list_shape().val_size());
for (const auto& value : val.at_list_shape().val()) { ret.emplace_back(value); }
return ret;
}
template<>
void AttrValAccessor<std::vector<Shape>>::Attr(const std::vector<Shape>& cpp_val,
UserOpAttrVal* attr_val) {
void AttrValueAccessor<std::vector<Shape>>::Attr(const std::vector<Shape>& cpp_val,
AttrValue* attr_val) {
if (attr_val->at_list_shape().val_size() > 0) { attr_val->mutable_at_list_shape()->clear_val(); }
FOR_RANGE(int32_t, i, 0, cpp_val.size()) {
cpp_val.at(i).ToProto(attr_val->mutable_at_list_shape()->add_val());
......@@ -111,12 +111,12 @@ void AttrValAccessor<std::vector<Shape>>::Attr(const std::vector<Shape>& cpp_val
// List of String Attr
template<>
std::vector<std::string> AttrValAccessor<std::vector<std::string>>::Attr(const UserOpAttrVal& val) {
std::vector<std::string> AttrValueAccessor<std::vector<std::string>>::Attr(const AttrValue& val) {
return PbRpf2StdVec<std::string>(val.at_list_string().val());
}
template<>
void AttrValAccessor<std::vector<std::string>>::Attr(const std::vector<std::string>& cpp_val,
UserOpAttrVal* attr_val) {
void AttrValueAccessor<std::vector<std::string>>::Attr(const std::vector<std::string>& cpp_val,
AttrValue* attr_val) {
*(attr_val->mutable_at_list_string()->mutable_val()) = StdVec2PbRpf<std::string>(cpp_val);
}
......
......@@ -16,16 +16,16 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_
#define ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_
#include "oneflow/core/framework/user_op_attr.h"
#include "oneflow/core/framework/attr_value.h"
namespace oneflow {
namespace user_op {
template<typename T>
struct AttrValAccessor final {
static T Attr(const UserOpAttrVal&);
static void Attr(const T&, UserOpAttrVal*);
struct AttrValueAccessor final {
static T Attr(const AttrValue&);
static void Attr(const T&, AttrValue*);
};
} // namespace user_op
......
......@@ -27,7 +27,7 @@ ConfigDef* MutGlobalConfigDef() {
}
template<ConfigDefType config_def_type>
UserOpAttrVal* AddConfigFlagDef(const std::string& name, const std::string& description) {
AttrValue* AddConfigFlagDef(const std::string& name, const std::string& description) {
auto* name2flag_def = MutGlobalConfigDef<config_def_type>()->mutable_flag_name2flag_def();
CHECK(name2flag_def->find(name) == name2flag_def->end());
auto* flag_def = &(*name2flag_def)[name];
......
......@@ -12,7 +12,7 @@ enum ConfigDefType {
message ConfigFlagDef {
required string name = 1;
required string description = 2;
required UserOpAttrVal default_val = 3;
required AttrValue default_val = 3;
}
message ConfigDef {
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/framework/user_op_def.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/framework/user_op_attr.h"
#include "oneflow/core/framework/attr_value.h"
#include "oneflow/core/framework/user_op_def.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/attr_value_accessor.h"
......
......@@ -4,7 +4,7 @@ package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
enum UserOpAttrType {
enum AttrType {
kAtInt32 = 1;
kAtInt64 = 2;
kAtBool = 3;
......@@ -21,7 +21,7 @@ enum UserOpAttrType {
kAtListString = 14;
}
message UserOpAttrVal {
message AttrValue {
message ListInt32 {
repeated int32 val = 1;
}
......@@ -37,7 +37,7 @@ message UserOpAttrVal {
message ListShape {
repeated ShapeProto val = 1;
}
// order and naming convention of the oneof field must be consistent with the enum UserOpAttrType
// order and naming convention of the oneof field must be consistent with the enum AttrType
message ListString {
repeated string val = 1;
}
......
......@@ -27,15 +27,15 @@ namespace user_op {
UserOpConfWrapper::UserOpConfWrapper(const OperatorConf& op_conf) : op_conf_(op_conf) {
CHECK(op_conf_.has_user_conf());
for (const auto& kv : op_conf_.user_conf().attr()) {
UserOpAttrVal::ValueCase value_case = kv.second.value_case();
AttrValue::ValueCase value_case = kv.second.value_case();
switch (value_case) {
#define CASE_ENTRY(field, cpp_type, attr_type) \
/* UserOpAttrVal::ValueCase has the same order and naming convention as UserOpAttrType */ \
case (static_cast<UserOpAttrVal::ValueCase>(attr_type)): \
CHECK(attrs_ \
.emplace(kv.first, std::make_shared<TypedAttrVal<cpp_type>>( \
AttrValAccessor<cpp_type>::Attr(kv.second))) \
.second); \
#define CASE_ENTRY(field, cpp_type, attr_type) \
/* AttrValue::ValueCase has the same order and naming convention as AttrType */ \
case (static_cast<AttrValue::ValueCase>(attr_type)): \
CHECK(attrs_ \
.emplace(kv.first, std::make_shared<TypedAttrVal<cpp_type>>( \
AttrValueAccessor<cpp_type>::Attr(kv.second))) \
.second); \
break;
OF_PP_FOR_EACH_TUPLE(CASE_ENTRY, ATTR_SEQ)
#undef CASE_ENTRY
......@@ -98,15 +98,15 @@ int32_t UserOpConfWrapper::output_size(const std::string& arg_name) const {
return std::dynamic_pointer_cast<TypedAttrVal<cpp_type>>(it->second)->val(); \
} else { \
LOG(FATAL) << "Cannot find the attr: " << attr_name \
<< " with UserOpAttrType: " << static_cast<int32_t>(attr_type); \
<< " with AttrType: " << static_cast<int32_t>(attr_type); \
} \
} \
\
template<> \
UserOpConfWrapperBuilder& UserOpConfWrapperBuilder::Attr<cpp_type>(const std::string& attr_name, \
const cpp_type& val) { \
UserOpAttrVal attr_val; \
AttrValAccessor<cpp_type>::Attr(val, &attr_val); \
AttrValue attr_val; \
AttrValueAccessor<cpp_type>::Attr(val, &attr_val); \
attr_.emplace(attr_name, attr_val); \
return *this; \
}
......@@ -323,7 +323,7 @@ Maybe<void> AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, OperatorCo
<< " op_name: " << op_conf->name() << " op_type_name: " << user_conf->op_type_name()
<< " attr_name: " << attr.name()
<< " has different attr type in OpDef and OpConf, it should be with type: "
<< UserOpAttrType_Name(attr.type());
<< AttrType_Name(attr.type());
}
return Maybe<void>::Ok();
}
......@@ -344,8 +344,7 @@ Maybe<void> AddUserOpConfOutputDefaultArg(const UserOpDef& op_def, OperatorConf*
return Maybe<void>::Ok();
}
Maybe<long long> GetUserOpAttrTypeImpl(const std::string& op_type_name,
const std::string& attr_name) {
Maybe<long long> GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name) {
const user_op::OpRegistryResult* val =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);
CHECK_OR_RETURN(val) << " Cannot find op " << op_type_name;
......
......@@ -165,7 +165,7 @@ class UserOpConfWrapperBuilder final {
std::string op_type_name_;
HashMap<std::string, std::vector<std::string>> input_;
HashMap<std::string, std::vector<std::string>> output_;
HashMap<std::string, UserOpAttrVal> attr_;
HashMap<std::string, AttrValue> attr_;
OptInt64 scope_symbol_id_;
};
......@@ -190,8 +190,7 @@ class BackwardOpConfContext final {
} // namespace user_op
Maybe<long long> GetUserOpAttrTypeImpl(const std::string& op_type_name,
const std::string& attr_name);
Maybe<long long> GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name);
Maybe<OperatorConf> CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf);
} // namespace oneflow
......
......@@ -10,5 +10,5 @@ message UserOpConf {
required string op_type_name = 1;
map<string, ListString> input = 2;
map<string, ListString> output = 3;
map<string, UserOpAttrVal> attr = 4;
map<string, AttrValue> attr = 4;
}
......@@ -65,7 +65,7 @@ const UserOpDef::ArgDef* UserOpDefWrapper::GetArgPointer(const std::string& name
return nullptr;
}
UserOpAttrType UserOpDefWrapper::GetAttrType(const std::string& name) const {
AttrType UserOpDefWrapper::GetAttrType(const std::string& name) const {
return attrs_.at(name)->type();
}
......@@ -77,9 +77,9 @@ bool UserOpDefWrapper::AttrHasDefaultVal(const std::string& name) const {
template<> \
cpp_type UserOpDefWrapper::GetAttrDefaultVal<cpp_type>(const std::string& name) const { \
CHECK(AttrHasDefaultVal(name)); \
const UserOpAttrVal& default_val = attrs_.at(name)->default_val(); \
const AttrValue& default_val = attrs_.at(name)->default_val(); \
CHECK_EQ(static_cast<int>(attr_type), default_val.value_case()); \
return AttrValAccessor<cpp_type>::Attr(default_val); \
return AttrValueAccessor<cpp_type>::Attr(default_val); \
}
OF_PP_FOR_EACH_TUPLE(ATTR_TYPE_SPECIALIZATION, ATTR_SEQ)
......
......@@ -39,7 +39,7 @@ class UserOpDefWrapper final {
bool IsArgOptional(const std::string&) const;
std::pair<int32_t, bool> ArgNumAndIsMin(const std::string&) const;
UserOpAttrType GetAttrType(const std::string&) const;
AttrType GetAttrType(const std::string&) const;
bool AttrHasDefaultVal(const std::string&) const;
template<typename T>
T GetAttrDefaultVal(const std::string&) const;
......
......@@ -17,8 +17,8 @@ message UserOpDef {
message AttrDef {
required string name = 1;
required UserOpAttrType type = 2;
optional UserOpAttrVal default_val = 3;
required AttrType type = 2;
optional AttrValue default_val = 3;
}
repeated AttrDef attr = 4;
}
......@@ -16,7 +16,7 @@ limitations under the License.
#include "oneflow/core/framework/user_op_registry.h"
#include "oneflow/core/framework/infer_util.h"
#include "oneflow/core/framework/user_op_attr.h"
#include "oneflow/core/framework/attr_value.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/sbp_context.h"
#include "oneflow/core/framework/batch_axis_context.h"
......@@ -87,7 +87,7 @@ OpRegistry& OpRegistry::SetOutputBufferNum(int32_t num) {
return *this;
}
OpRegistry& OpRegistry::Attr(const std::string& name, UserOpAttrType type) {
OpRegistry& OpRegistry::Attr(const std::string& name, AttrType type) {
CHECK(InsertIfNotExists(name, &unique_names_));
UserOpDef::AttrDef attr_def;
attr_def.set_name(name);
......@@ -98,7 +98,7 @@ OpRegistry& OpRegistry::Attr(const std::string& name, UserOpAttrType type) {
namespace {
void AddAttrWithDefault(OpRegistryResult* result, const std::string& name, UserOpAttrType type,
void AddAttrWithDefault(OpRegistryResult* result, const std::string& name, AttrType type,
std::function<void(UserOpDef::AttrDef*)> handler) {
UserOpDef::AttrDef attr_def;
attr_def.set_name(name);
......@@ -111,18 +111,18 @@ void AddAttrWithDefault(OpRegistryResult* result, const std::string& name, UserO
#define ATTR_MEMBER_FUNC(field, cpp_type, attr_type) \
template<> \
OpRegistry& OpRegistry::Attr<cpp_type>(const std::string& name, UserOpAttrType type, \
OpRegistry& OpRegistry::Attr<cpp_type>(const std::string& name, AttrType type, \
const cpp_type& default_val) { \
CHECK_EQ(type, attr_type); \
return DefaultedAttr(name, type, [default_val](UserOpDef::AttrDef* attr_def) { \
AttrValAccessor<cpp_type>::Attr(default_val, attr_def->mutable_default_val()); \
AttrValueAccessor<cpp_type>::Attr(default_val, attr_def->mutable_default_val()); \
}); \
} \
template<> \
OpRegistry& OpRegistry::Attr<cpp_type>(const std::string& name, const cpp_type& default_val) { \
return DefaultedAttr( \
name, GetAttrType<cpp_type>::value, [default_val](UserOpDef::AttrDef* attr_def) { \
AttrValAccessor<cpp_type>::Attr(default_val, attr_def->mutable_default_val()); \
AttrValueAccessor<cpp_type>::Attr(default_val, attr_def->mutable_default_val()); \
}); \
} \
template<> \
......@@ -134,7 +134,7 @@ OF_PP_FOR_EACH_TUPLE(ATTR_MEMBER_FUNC, ATTR_SEQ)
#undef ATTR_MEMBER_FUNC
OpRegistry& OpRegistry::DefaultedAttr(const std::string& name, UserOpAttrType type,
OpRegistry& OpRegistry::DefaultedAttr(const std::string& name, AttrType type,
const std::function<void(UserOpDef::AttrDef*)>& SetDefault) {
CHECK(InsertIfNotExists(name, &unique_names_));
AddAttrWithDefault(&result_, name, type, SetDefault);
......
......@@ -91,9 +91,9 @@ class OpRegistry final {
OpRegistry& SupportCpuOnly();
OpRegistry& SetOutputBufferNum(int32_t num);
__attribute__((deprecated)) OpRegistry& Attr(const std::string& name, UserOpAttrType type);
__attribute__((deprecated)) OpRegistry& Attr(const std::string& name, AttrType type);
template<typename T>
__attribute__((deprecated)) OpRegistry& Attr(const std::string& name, UserOpAttrType type,
__attribute__((deprecated)) OpRegistry& Attr(const std::string& name, AttrType type,
const T& default_val);
template<typename T>
OpRegistry& Attr(const std::string& name, const T& default_val);
......@@ -115,7 +115,7 @@ class OpRegistry final {
private:
OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional, int32_t num,
bool num_as_min);
OpRegistry& DefaultedAttr(const std::string& name, UserOpAttrType type,
OpRegistry& DefaultedAttr(const std::string& name, AttrType type,
const std::function<void(UserOpDef::AttrDef*)>& SetDefault);
private:
......
......@@ -112,7 +112,7 @@ message JobConfigProto {
optional int64 concurrency_width = 1000 [default = 128];
map<string, UserOpAttrVal> flag_name2flag_value = 2000;
map<string, AttrValue> flag_name2flag_value = 2000;
optional int64 logical_object_id = 3000;
}
......@@ -84,7 +84,7 @@ void JobDesc::Init() {
CheckFunctionConfig(job_conf_);
}
const UserOpAttrVal& JobDesc::GetFunctionFlagVal(const std::string& field_name) const {
const AttrValue& JobDesc::GetFunctionFlagVal(const std::string& field_name) const {
const auto& iter = job_conf_.flag_name2flag_value().find(field_name);
if (iter != job_conf_.flag_name2flag_value().end()) { return iter->second; }
const auto& flag_name2flag_def = GlobalFunctionConfigDef().flag_name2flag_def();
......
......@@ -70,11 +70,11 @@ class JobDesc final {
bool has_xrt_config() const { return job_conf_.has_xrt_config(); }
const XrtConfig& xrt_config() const { return job_conf_.xrt_config(); }
#define DEFINE_FUNCTION_CONFIG_GETTER(T, func_name, field_name) \
T func_name(const std::string& field_name) const { \
const UserOpAttrVal& attr_val = GetFunctionFlagVal(field_name); \
CHECK(attr_val.has_##field_name()); \
return attr_val.field_name(); \
#define DEFINE_FUNCTION_CONFIG_GETTER(T, func_name, field_name) \
T func_name(const std::string& field_name) const { \
const AttrValue& attr_val = GetFunctionFlagVal(field_name); \
CHECK(attr_val.has_##field_name()); \
return attr_val.field_name(); \
}
DEFINE_FUNCTION_CONFIG_GETTER(bool, Bool, at_bool);
DEFINE_FUNCTION_CONFIG_GETTER(int64_t, Int64, at_int64);
......@@ -88,7 +88,7 @@ class JobDesc final {
private:
void Init();
const UserOpAttrVal& GetFunctionFlagVal(const std::string& field_name) const;
const AttrValue& GetFunctionFlagVal(const std::string& field_name) const;
JobConfigProto job_conf_;
int64_t job_id_;
......
......@@ -14,5 +14,5 @@ message NodeDef {
required string op = 2;
repeated string input = 3;
optional string device = 4;
map<string, UserOpAttrVal> attr = 5;
map<string, AttrValue> attr = 5;
}
......@@ -239,7 +239,7 @@ Maybe<void> WriteInt8Calibration(const std::string& path) {
}
Maybe<long long> GetUserOpAttrType(const std::string& op_type_name, const std::string& attr_name) {
return JUST(GetUserOpAttrTypeImpl(op_type_name, attr_name));
return JUST(GetAttrTypeImpl(op_type_name, attr_name));
}
Maybe<std::string> CheckAndCompleteUserOpConf(const std::string& op_conf_str) {
......
......@@ -22,7 +22,7 @@ import oneflow.python.framework.hob as hob
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.lib.core.enable_if as enable_if
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.framework.user_op_attr_pb2 as user_op_attr_util
import oneflow.core.framework.user_op_attr_pb2 as attr_value_pb
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.core.common.shape_pb2 as shape_util
import oneflow
......@@ -325,52 +325,52 @@ class UserOpConfBuilder(object):
)
print(traceback.format_stack()[-2])
attribute = user_op_attr_util.UserOpAttrVal()
attribute = attr_value_pb.AttrValue()
assert isinstance(attr_name, str)
attr_type = c_api_util.GetUserOpAttrType(
self.user_op_.op_conf_.user_conf.op_type_name, attr_name
)
if attr_type == user_op_attr_util.kAtInt32:
if attr_type == attr_value_pb.kAtInt32:
assert isinstance(attr_value, int)
attribute.at_int32 = attr_value
elif attr_type == user_op_attr_util.kAtInt64:
elif attr_type == attr_value_pb.kAtInt64:
assert isinstance(attr_value, int)
attribute.at_int64 = attr_value
elif attr_type == user_op_attr_util.kAtBool:
elif attr_type == attr_value_pb.kAtBool:
assert isinstance(attr_value, bool)
attribute.at_bool = attr_value
elif attr_type == user_op_attr_util.kAtFloat:
elif attr_type == attr_value_pb.kAtFloat:
assert isinstance(attr_value, float)
attribute.at_float = attr_value
elif attr_type == user_op_attr_util.kAtDouble:
elif attr_type == attr_value_pb.kAtDouble:
assert isinstance(attr_value, float)
attribute.at_double = attr_value
elif attr_type == user_op_attr_util.kAtString:
elif attr_type == attr_value_pb.kAtString:
assert isinstance(attr_value, str)
attribute.at_string = attr_value
elif attr_type == user_op_attr_util.kAtShape:
elif attr_type == attr_value_pb.kAtShape:
assert isinstance(attr_value, (tuple, list))
assert all(isinstance(x, int) for x in attr_value)
attribute.at_shape.dim[:] = list(attr_value)
elif attr_type == user_op_attr_util.kAtDataType:
elif attr_type == attr_value_pb.kAtDataType:
assert (
isinstance(attr_value.oneflow_proto_dtype, int)
and attr_value in oneflow.dtypes()
)
attribute.at_data_type = attr_value.oneflow_proto_dtype
elif attr_type == user_op_attr_util.kAtListInt32:
elif attr_type == attr_value_pb.kAtListInt32:
assert isinstance(attr_value, (tuple, list))
assert all(isinstance(x, int) for x in attr_value)
attribute.at_list_int32.val[:] = list(attr_value)
elif attr_type == user_op_attr_util.kAtListInt64:
elif attr_type == attr_value_pb.kAtListInt64:
assert isinstance(attr_value, (tuple, list))
assert all(isinstance(x, int) for x in attr_value)
attribute.at_list_int64.val[:] = list(attr_value)
elif attr_type == user_op_attr_util.kAtListFloat:
elif attr_type == attr_value_pb.kAtListFloat:
assert isinstance(attr_value, (tuple, list))
assert all(isinstance(x, float) for x in attr_value)
attribute.at_list_float.val[:] = list(attr_value)
elif attr_type == user_op_attr_util.kAtListDataType:
elif attr_type == attr_value_pb.kAtListDataType:
assert isinstance(attr_value, (tuple, list))
assert all(
isinstance(x.oneflow_proto_dtype, int) and x in oneflow.dtypes()
......@@ -379,14 +379,14 @@ class UserOpConfBuilder(object):
attribute.at_list_data_type.val[:] = list(
[x.oneflow_proto_dtype for x in attr_value]
)
elif attr_type == user_op_attr_util.kAtListShape:
elif attr_type == attr_value_pb.kAtListShape:
assert isinstance(attr_value, (tuple, list))
assert all(isinstance(x, tuple) or isinstance(x, list) for x in attr_value)
for i in range(len(attr_value)):
shape = shape_util.ShapeProto()
shape.dim[:] = list(attr_value[i])
attribute.at_list_shape.val.append(shape)
elif attr_type == user_op_attr_util.kAtListString:
elif attr_type == attr_value_pb.kAtListString:
assert isinstance(attr_value, (tuple, list))
assert all(isinstance(x, str) for x in attr_value)
attribute.at_list_string.val[:] = list(attr_value)
......
......@@ -23,7 +23,7 @@ REGISTER_USER_OP("ssp_variable_proxy")
.Input("var")
.Output("ref")
.Output("value")
.Attr<int64_t>("buffer_size", UserOpAttrType::kAtInt64, 1)
.Attr<int64_t>("buffer_size", 1)
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const Shape* var_shape = ctx->Shape4ArgNameAndIndex("var", 0);
*ctx->Shape4ArgNameAndIndex("ref", 0) = *var_shape;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册