提交 264a0160 编写于 作者: 李寅

Merge branch 'refactor_proto_and_relevant_macros' into 'master'

refactor mace.proto and relevant macros

See merge request !523
......@@ -126,7 +126,6 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(std::string, strings, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
} // namespace mace
......@@ -26,13 +26,9 @@ static cl_channel_type DataTypeToCLChannelType(const DataType t) {
return CL_HALF_FLOAT;
case DT_FLOAT:
return CL_FLOAT;
case DT_INT8:
case DT_INT16:
case DT_INT32:
return CL_SIGNED_INT32;
case DT_UINT8:
case DT_UINT16:
case DT_UINT32:
return CL_UNSIGNED_INT32;
default:
LOG(FATAL) << "Image doesn't support the data type: " << t;
......
......@@ -37,60 +37,49 @@
namespace mace {
#define SINGLE_ARG(...) __VA_ARGS__
#define CASE(TYPE, STMTS) \
#define MACE_SINGLE_ARG(...) __VA_ARGS__
#define MACE_CASE(TYPE, STATEMENTS) \
case DataTypeToEnum<TYPE>::value: { \
typedef TYPE T; \
STMTS; \
STATEMENTS; \
break; \
}
#ifdef MACE_ENABLE_OPENCL
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
switch (TYPE_ENUM) { \
CASE(half, SINGLE_ARG(STMTS)) \
CASE(float, SINGLE_ARG(STMTS)) \
CASE(double, SINGLE_ARG(STMTS)) \
CASE(int32_t, SINGLE_ARG(STMTS)) \
CASE(uint8_t, SINGLE_ARG(STMTS)) \
CASE(uint16_t, SINGLE_ARG(STMTS)) \
CASE(int16_t, SINGLE_ARG(STMTS)) \
CASE(int8_t, SINGLE_ARG(STMTS)) \
CASE(std::string, SINGLE_ARG(STMTS)) \
CASE(int64_t, SINGLE_ARG(STMTS)) \
CASE(bool, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
INVALID; \
break; \
default: \
DEFAULT; \
break; \
#define MACE_TYPE_ENUM_SWITCH( \
TYPE_ENUM, STATEMENTS, INVALID_STATEMENTS, DEFAULT_STATEMENTS) \
switch (TYPE_ENUM) { \
MACE_CASE(half, MACE_SINGLE_ARG(STATEMENTS)) \
MACE_CASE(float, MACE_SINGLE_ARG(STATEMENTS)) \
MACE_CASE(uint8_t, MACE_SINGLE_ARG(STATEMENTS)) \
MACE_CASE(int32_t, MACE_SINGLE_ARG(STATEMENTS)) \
case DT_INVALID: \
INVALID_STATEMENTS; \
break; \
default: \
DEFAULT_STATEMENTS; \
break; \
}
#else
#define CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, INVALID, DEFAULT) \
switch (TYPE_ENUM) { \
CASE(float, SINGLE_ARG(STMTS)) \
CASE(double, SINGLE_ARG(STMTS)) \
CASE(int32_t, SINGLE_ARG(STMTS)) \
CASE(uint8_t, SINGLE_ARG(STMTS)) \
CASE(uint16_t, SINGLE_ARG(STMTS)) \
CASE(int16_t, SINGLE_ARG(STMTS)) \
CASE(int8_t, SINGLE_ARG(STMTS)) \
CASE(std::string, SINGLE_ARG(STMTS)) \
CASE(int64_t, SINGLE_ARG(STMTS)) \
CASE(bool, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
INVALID; \
break; \
default: \
DEFAULT; \
break; \
#define MACE_TYPE_ENUM_SWITCH( \
TYPE_ENUM, STATEMENTS, INVALID_STATEMENTS, DEFAULT_STATEMENTS) \
switch (TYPE_ENUM) { \
MACE_CASE(float, MACE_SINGLE_ARG(STATEMENTS)) \
MACE_CASE(uint8_t, MACE_SINGLE_ARG(STATEMENTS)) \
MACE_CASE(int32_t, MACE_SINGLE_ARG(STATEMENTS)) \
case DT_INVALID: \
INVALID_STATEMENTS; \
break; \
default: \
DEFAULT_STATEMENTS; \
break; \
}
#endif
#define CASES(TYPE_ENUM, STMTS) \
CASES_WITH_DEFAULT(TYPE_ENUM, STMTS, LOG(FATAL) << "Type not set"; \
, LOG(FATAL) << "Unexpected type: " << TYPE_ENUM;)
// `TYPE_ENUM` will be converted to template `T` in `STATEMENTS`
#define MACE_RUN_WITH_TYPE_ENUM(TYPE_ENUM, STATEMENTS) \
MACE_TYPE_ENUM_SWITCH(TYPE_ENUM, STATEMENTS, LOG(FATAL) << "Invalid type"; \
, LOG(FATAL) << "Unknown type: " << TYPE_ENUM;)
namespace numerical_chars {
inline std::ostream &operator<<(std::ostream &os, char c) {
......@@ -307,7 +296,7 @@ class Tensor {
inline size_t SizeOfType() const {
size_t type_size = 0;
CASES(dtype_, type_size = sizeof(T));
MACE_RUN_WITH_TYPE_ENUM(dtype_, type_size = sizeof(T));
return type_size;
}
......@@ -328,7 +317,7 @@ class Tensor {
if (i != 0 && i % shape_.back() == 0) {
os << "\n";
}
CASES(dtype_, (os << (this->data<T>()[i]) << ", "));
MACE_RUN_WITH_TYPE_ENUM(dtype_, (os << (this->data<T>()[i]) << ", "));
}
LOG(INFO) << os.str();
}
......
......@@ -23,15 +23,8 @@ namespace mace {
bool DataTypeCanUseMemcpy(DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_DOUBLE:
case DT_INT32:
case DT_INT64:
case DT_UINT32:
case DT_UINT16:
case DT_UINT8:
case DT_INT16:
case DT_INT8:
case DT_BOOL:
case DT_INT32:
return true;
default:
return false;
......@@ -44,15 +37,8 @@ std::string DataTypeToString(const DataType dt) {
#ifdef MACE_ENABLE_OPENCL
{DT_HALF, "DT_HALF"},
#endif
{DT_DOUBLE, "DT_DOUBLE"},
{DT_UINT8, "DT_UINT8"},
{DT_INT8, "DT_INT8"},
{DT_INT32, "DT_INT32"},
{DT_UINT32, "DT_UINT32"},
{DT_UINT16, "DT_UINT16"},
{DT_INT64, "DT_INT64"},
{DT_BOOL, "DT_BOOL"},
{DT_STRING, "DT_STRING"}};
{DT_INT32, "DT_UINT32"}};
MACE_CHECK(dt != DT_INVALID) << "Not support Invalid data type";
return dtype_string_map[dt];
}
......@@ -67,22 +53,10 @@ size_t GetEnumTypeSize(const DataType dt) {
#endif
case DT_UINT8:
return sizeof(uint8_t);
case DT_INT8:
return sizeof(int8_t);
case DT_DOUBLE:
return sizeof(double);
case DT_INT32:
return sizeof(int32_t);
case DT_UINT32:
return sizeof(uint32_t);
case DT_UINT16:
return sizeof(uint16_t);
case DT_INT16:
return sizeof(int16_t);
case DT_INT64:
return sizeof(int64_t);
default:
LOG(FATAL) << "Unsupported data type";
LOG(FATAL) << "Unsupported data type: " << dt;
return 0;
}
}
......
......@@ -38,50 +38,28 @@ size_t GetEnumTypeSize(const DataType dt);
std::string DataTypeToString(const DataType dt);
template <class T>
struct IsValidDataType;
struct DataTypeToEnum;
template <class T>
struct DataTypeToEnum {
static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
};
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
// EnumToDataType<DT_FLOAT>::Type is float.
template <DataType VALUE>
struct EnumToDataType {}; // Specializations below
// Template specialization for both DataTypeToEnum and EnumToDataType.
#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \
template <> \
struct DataTypeToEnum<TYPE> { \
static DataType v() { return ENUM; } \
static constexpr DataType value = ENUM; \
}; \
template <> \
struct IsValidDataType<TYPE> { \
static constexpr bool value = true; \
}; \
template <> \
struct EnumToDataType<ENUM> { \
typedef TYPE Type; \
}
struct EnumToDataType;
#define MACE_MAPPING_DATA_TYPE_AND_ENUM(DATA_TYPE, ENUM_VALUE) \
template <> \
struct DataTypeToEnum<DATA_TYPE> { \
static DataType v() { return ENUM_VALUE; } \
static constexpr DataType value = ENUM_VALUE; \
}; \
template <> \
struct EnumToDataType<ENUM_VALUE> { \
typedef DATA_TYPE Type; \
};
#ifdef MACE_ENABLE_OPENCL
MATCH_TYPE_AND_ENUM(half, DT_HALF);
MACE_MAPPING_DATA_TYPE_AND_ENUM(half, DT_HALF);
#endif
MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
MATCH_TYPE_AND_ENUM(int32_t, DT_INT32);
MATCH_TYPE_AND_ENUM(uint16_t, DT_UINT16);
MATCH_TYPE_AND_ENUM(uint8_t, DT_UINT8);
MATCH_TYPE_AND_ENUM(int16_t, DT_INT16);
MATCH_TYPE_AND_ENUM(int8_t, DT_INT8);
MATCH_TYPE_AND_ENUM(std::string, DT_STRING);
MATCH_TYPE_AND_ENUM(int64_t, DT_INT64);
MATCH_TYPE_AND_ENUM(uint32_t, DT_UINT32);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
static const int32_t kint32_tmax = ((int32_t)0x7FFFFFFF);
MACE_MAPPING_DATA_TYPE_AND_ENUM(float, DT_FLOAT);
MACE_MAPPING_DATA_TYPE_AND_ENUM(uint8_t, DT_UINT8);
MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32);
} // namespace mace
#endif // MACE_CORE_TYPES_H_
......@@ -100,16 +100,6 @@ class OpDefBuilder {
return *this;
}
OpDefBuilder AddStringsArg(const std::string &name,
const std::vector<const char *> &values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_strings(value);
}
return *this;
}
void Finalize(OperatorDef *op_def) const {
MACE_CHECK(op_def != nullptr, "input should not be null.");
*op_def = op_def_;
......
......@@ -44,7 +44,7 @@ void ResizeBilinearBenchmark(int iters,
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddInputFromArray<D, index_t>("OutSize", {2},
net.AddInputFromArray<D, int>("OutSize", {2},
{output_height, output_width});
if (D == DeviceType::CPU) {
......
......@@ -11,45 +11,20 @@ enum NetMode {
enum DataType {
DT_INVALID = 0;
// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_INT64 = 8;
DT_UINT16 = 9;
DT_BOOL = 10;
DT_HALF = 19;
DT_UINT32 = 22;
DT_UINT8 = 2;
DT_HALF = 3;
DT_INT32 = 4;
}
message ConstTensor {
// The dimensions in the tensor.
repeated int64 dims = 1;
optional DataType data_type = 2 [default = DT_FLOAT];
// For float
repeated float float_data = 3 [packed = true];
// For int32, uint8, int8, uint16, int16, bool, and float16
// Note about float16: in storage we will basically convert float16 byte-wise
// to unsigned short and then store them in the int32_data field.
repeated int32 int32_data = 4 [packed = true];
// For bytes
optional bytes byte_data = 5;
// For strings
repeated bytes string_data = 6;
// For double
repeated double double_data = 9 [packed = true];
// For int64
repeated int64 int64_data = 10 [packed = true];
// Optionally, a name for the tensor.
optional string name = 7;
optional int64 offset = 11;
optional int64 data_size = 12;
optional string name = 5;
optional int64 offset = 6;
optional int64 data_size = 7;
optional uint32 node_id = 100;
}
......@@ -61,7 +36,6 @@ message Argument {
optional bytes s = 4;
repeated float floats = 5;
repeated int64 ints = 6;
repeated bytes strings = 7;
}
// for hexagon mace-nnlib
......
......@@ -18,21 +18,11 @@ from mace.proto import mace_pb2
TF_DTYPE_2_MACE_DTYPE_MAP = {
tf.float32: mace_pb2.DT_FLOAT,
tf.double: mace_pb2.DT_DOUBLE,
tf.half: mace_pb2.DT_HALF,
tf.int64: mace_pb2.DT_INT64,
tf.int32: mace_pb2.DT_INT32,
tf.qint32: mace_pb2.DT_INT32,
tf.int16: mace_pb2.DT_INT16,
tf.qint16: mace_pb2.DT_INT16,
tf.int8: mace_pb2.DT_INT8,
tf.qint8: mace_pb2.DT_INT8,
tf.quint16: mace_pb2.DT_UINT16,
tf.uint16: mace_pb2.DT_UINT16,
tf.quint8: mace_pb2.DT_UINT8,
tf.uint8: mace_pb2.DT_UINT8,
tf.string: mace_pb2.DT_STRING,
tf.bool: mace_pb2.DT_BOOL,
}
......
......@@ -68,10 +68,6 @@ void CreateNetArg(NetDef *net_def) {
{% for int_value in net.arg[i].ints %}
arg->add_ints({{ int_value }});
{% endfor %}
arg->mutable_strings()->Reserve({{ net.arg[i].strings|length }});
{% for str_value in net.arg[i].strings %}
arg->add_strings({{ str_value }});
{% endfor %}
{% endfor %}
}
......
......@@ -91,10 +91,6 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
{% for int_value in arg.ints %}
arg->add_ints({{ int_value }});
{% endfor %}
arg->mutable_strings()->Reserve({{ arg.strings|length }});
{% for str_value in arg.strings %}
arg->add_strings({{ str_value }});
{% endfor %}
{% endfor %}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册