提交 5b35740b 编写于 作者: Y yejianwu

refactor arg_helper

上级 e8d613ef
......@@ -20,112 +20,80 @@
namespace mace {
ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
ProtoArgHelper::ProtoArgHelper(const OperatorDef &def) {
for (auto &arg : def.arg()) {
if (arg_map_.find(arg.name()) != arg_map_.end()) {
LOG(WARNING) << "Duplicated argument name found in operator def: "
<< def.name() << " " << arg.name();
if (arg_map_.count(arg.name())) {
LOG(WARNING) << "Duplicated argument " << arg.name()
<< " found in operator " << def.name();
}
arg_map_[arg.name()] = arg;
}
}
ArgumentHelper::ArgumentHelper(const NetDef &netdef) {
ProtoArgHelper::ProtoArgHelper(const NetDef &netdef) {
for (auto &arg : netdef.arg()) {
MACE_CHECK(arg_map_.count(arg.name()) == 0,
"Duplicated argument name found in net def.");
"Duplicated argument found in net def.");
arg_map_[arg.name()] = arg;
}
}
bool ArgumentHelper::HasArgument(const std::string &name) const {
return arg_map_.count(name);
}
namespace {
// Helper function to verify that a conversion between types won't lose any
// significant bit.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType &value) {
inline bool IsCastLossless(const InputType &value) {
return static_cast<InputType>(static_cast<TargetType>(value)) == value;
}
}
#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, \
enforce_lossless_conversion) \
template <> \
T ArgumentHelper::GetSingleArgument<T>(const std::string &name, \
const T &default_value) const { \
if (arg_map_.count(name) == 0) { \
VLOG(3) << "Using default parameter value " << default_value \
<< " for parameter " << name; \
return default_value; \
} \
MACE_CHECK(arg_map_.at(name).has_##fieldname(), "Argument ", name, \
" does not have the right field: expected field " #fieldname); \
auto value = arg_map_.at(name).fieldname(); \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(value), T>(value); \
MACE_CHECK(supportsConversion, "Value", value, " of argument ", name, \
"cannot be represented correctly in a target type"); \
} \
return value; \
} \
template <> \
bool ArgumentHelper::HasSingleArgumentOfType<T>( \
const std::string &name) const { \
if (arg_map_.count(name) == 0) { \
return false; \
} \
return arg_map_.at(name).has_##fieldname(); \
// Specializes ProtoArgHelper::GetOptionalArg<T> for a scalar argument stored
// in proto field `fieldname`.  Returns `default_value` when the argument is
// absent; when `lossless_conversion` is set, verifies the stored value
// survives a round-trip cast to T before returning it.
#define MACE_GET_OPTIONAL_ARGUMENT_FUNC(T, fieldname, lossless_conversion)  \
  template <>                                                               \
  T ProtoArgHelper::GetOptionalArg<T>(const std::string &arg_name,          \
                                      const T &default_value) const {       \
    if (arg_map_.count(arg_name) == 0) {                                    \
      VLOG(3) << "Using default parameter " << default_value << " for "     \
              << arg_name;                                                  \
      return default_value;                                                 \
    }                                                                       \
    /* The argument exists but may carry a different proto field type. */   \
    MACE_CHECK(arg_map_.at(arg_name).has_##fieldname(), "Argument ",        \
               arg_name, " does not have the expected field " #fieldname);  \
    auto value = arg_map_.at(arg_name).fieldname();                         \
    if (lossless_conversion) {                                              \
      const bool cast_lossless = IsCastLossless<decltype(value), T>(value); \
      MACE_CHECK(cast_lossless, "Value ", value, " of argument ", arg_name, \
                 " cannot be cast losslessly to the target type");          \
    }                                                                       \
    return value;                                                           \
  }
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(std::string, s, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT
// Instantiate scalar accessors, mapping each C++ type to the proto Argument
// field it is stored in (f / i / s).  The third parameter requests a
// lossless-cast check where the proto field is wider than the C++ type.
MACE_GET_OPTIONAL_ARGUMENT_FUNC(float, f, false)
MACE_GET_OPTIONAL_ARGUMENT_FUNC(bool, i, false)
MACE_GET_OPTIONAL_ARGUMENT_FUNC(int, i, true)
MACE_GET_OPTIONAL_ARGUMENT_FUNC(std::string, s, false)
#undef MACE_GET_OPTIONAL_ARGUMENT_FUNC
#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, \
enforce_lossless_conversion) \
template <> \
std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
const std::string &name, const std::vector<T> &default_value) const { \
if (arg_map_.count(name) == 0) { \
return default_value; \
} \
std::vector<T> values; \
for (const auto &v : arg_map_.at(name).fieldname()) { \
if (enforce_lossless_conversion) { \
auto supportsConversion = \
SupportsLosslessConversion<decltype(v), T>(v); \
MACE_CHECK(supportsConversion, "Value", v, " of argument ", name, \
"cannot be represented correctly in a target type"); \
} \
values.push_back(v); \
} \
return values; \
// Specializes ProtoArgHelper::GetRepeatedArgs<T> for a repeated argument
// stored in proto field `fieldname`.  Returns `default_value` when the
// argument is absent; when `lossless_conversion` is set, every element is
// verified to survive a round-trip cast to T.
#define MACE_GET_REPEATED_ARGUMENT_FUNC(T, fieldname, lossless_conversion) \
  template <>                                                              \
  std::vector<T> ProtoArgHelper::GetRepeatedArgs<T>(                       \
      const std::string &arg_name, const std::vector<T> &default_value)    \
      const {                                                              \
    if (arg_map_.count(arg_name) == 0) {                                   \
      return default_value;                                                \
    }                                                                      \
    std::vector<T> values;                                                 \
    for (const auto &v : arg_map_.at(arg_name).fieldname()) {              \
      if (lossless_conversion) {                                           \
        const bool cast_lossless = IsCastLossless<decltype(v), T>(v);      \
        MACE_CHECK(cast_lossless, "Value ", v, " of argument ", arg_name,  \
                   " cannot be cast losslessly to the target type");       \
      }                                                                    \
      values.push_back(v);                                                 \
    }                                                                      \
    return values;                                                         \
  }
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
// Instantiate repeated-value accessors against the proto repeated fields
// (floats / ints).  int uses a lossless-cast check since the proto stores
// a wider integer type.
MACE_GET_REPEATED_ARGUMENT_FUNC(float, floats, false)
MACE_GET_REPEATED_ARGUMENT_FUNC(int, ints, true)
MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true)
#undef MACE_GET_REPEATED_ARGUMENT_FUNC
} // namespace mace
......@@ -15,61 +15,41 @@
#ifndef MACE_CORE_ARG_HELPER_H_
#define MACE_CORE_ARG_HELPER_H_
#include <map>
#include <string>
#include <vector>
#include <map>
#include "mace/proto/mace.pb.h"
#include "mace/public/mace.h"
namespace mace {
/**
* @brief A helper class to index into arguments.
*
* This helper helps us to more easily index into a set of arguments
* that are present in the operator. To save memory, the argument helper
* does not copy the operator def, so one would need to make sure that the
* lifetime of the OperatorDef object outlives that of the ArgumentHelper.
*/
class ArgumentHelper {
// Refer to caffe2
class ProtoArgHelper {
public:
template <typename Def>
static bool HasArgument(const Def &def, const std::string &name) {
return ArgumentHelper(def).HasArgument(name);
}
template <typename Def, typename T>
static T GetSingleArgument(const Def &def,
const std::string &name,
const T &default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
static T GetOptionalArg(const Def &def,
const std::string &arg_name,
const T &default_value) {
return ProtoArgHelper(def).GetOptionalArg<T>(arg_name, default_value);
}
template <typename Def, typename T>
static bool HasSingleArgumentOfType(const Def &def, const std::string &name) {
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
}
template <typename Def, typename T>
static std::vector<T> GetRepeatedArgument(
static std::vector<T> GetRepeatedArgs(
const Def &def,
const std::string &name,
const std::string &arg_name,
const std::vector<T> &default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
return ProtoArgHelper(def).GetRepeatedArgs<T>(arg_name, default_value);
}
explicit ArgumentHelper(const OperatorDef &def);
explicit ArgumentHelper(const NetDef &netdef);
bool HasArgument(const std::string &name) const;
explicit ProtoArgHelper(const OperatorDef &def);
explicit ProtoArgHelper(const NetDef &netdef);
template <typename T>
T GetSingleArgument(const std::string &name, const T &default_value) const;
template <typename T>
bool HasSingleArgumentOfType(const std::string &name) const;
T GetOptionalArg(const std::string &arg_name, const T &default_value) const;
template <typename T>
std::vector<T> GetRepeatedArgument(
const std::string &name,
std::vector<T> GetRepeatedArgs(
const std::string &arg_name,
const std::vector<T> &default_value = std::vector<T>()) const;
private:
......
......@@ -213,7 +213,7 @@ class Buffer : public BufferBase {
void *mapped_buf_;
bool is_data_owner_;
DISABLE_COPY_AND_ASSIGN(Buffer);
MACE_DISABLE_COPY_AND_ASSIGN(Buffer);
};
class Image : public BufferBase {
......@@ -330,7 +330,7 @@ class Image : public BufferBase {
void *buf_;
void *mapped_buf_;
DISABLE_COPY_AND_ASSIGN(Image);
MACE_DISABLE_COPY_AND_ASSIGN(Image);
};
class BufferSlice : public BufferBase {
......
......@@ -110,7 +110,7 @@ class MaceEngine::Impl {
std::unique_ptr<HexagonControlWrapper> hexagon_controller_;
#endif
DISABLE_COPY_AND_ASSIGN(Impl);
MACE_DISABLE_COPY_AND_ASSIGN(Impl);
};
MaceEngine::Impl::Impl(DeviceType device_type)
......@@ -146,7 +146,7 @@ MaceStatus MaceEngine::Impl::Init(
hexagon_controller_->SetDebugLevel(
static_cast<int>(mace::logging::LogMessage::MinVLogLevel()));
int dsp_mode =
ArgumentHelper::GetSingleArgument<NetDef, int>(*net_def, "dsp_mode", 0);
ProtoArgHelper::GetOptionalArg<NetDef, int>(*net_def, "dsp_mode", 0);
hexagon_controller_->SetGraphMode(dsp_mode);
MACE_CHECK(hexagon_controller_->SetupGraph(*net_def, model_data),
"hexagon setup graph error");
......
......@@ -42,7 +42,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
const auto &operator_def = net_def->op(idx);
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "device", static_cast<int>(device_type_));
if (op_device == type) {
VLOG(3) << "Creating operator " << operator_def.name() << "("
......@@ -97,12 +97,12 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
type.compare("FusedConv2D") == 0 ||
type.compare("DepthwiseConv2d") == 0 ||
type.compare("Pooling") == 0) {
strides = op->GetRepeatedArgument<int>("strides");
padding_type = op->GetSingleArgument<int>("padding", -1);
paddings = op->GetRepeatedArgument<int>("padding_values");
dilations = op->GetRepeatedArgument<int>("dilations");
strides = op->GetRepeatedArgs<int>("strides");
padding_type = op->GetOptionalArg<int>("padding", -1);
paddings = op->GetRepeatedArgs<int>("padding_values");
dilations = op->GetRepeatedArgs<int>("dilations");
if (type.compare("Pooling") == 0) {
kernels = op->GetRepeatedArgument<index_t>("kernels");
kernels = op->GetRepeatedArgs<index_t>("kernels");
} else {
kernels = op->Input(1)->shape();
}
......
......@@ -44,7 +44,7 @@ class NetBase {
std::string name_;
const std::shared_ptr<const OperatorRegistry> op_registry_;
DISABLE_COPY_AND_ASSIGN(NetBase);
MACE_DISABLE_COPY_AND_ASSIGN(NetBase);
};
class SerialNet : public NetBase {
......@@ -61,7 +61,7 @@ class SerialNet : public NetBase {
std::vector<std::unique_ptr<OperatorBase> > operators_;
DeviceType device_type_;
DISABLE_COPY_AND_ASSIGN(SerialNet);
MACE_DISABLE_COPY_AND_ASSIGN(SerialNet);
};
std::unique_ptr<NetBase> CreateNet(
......
......@@ -55,9 +55,9 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
Workspace *ws,
DeviceType type,
const NetMode mode) const {
const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
const int dtype = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "T", static_cast<int>(DT_FLOAT));
const int op_mode_i = ArgumentHelper::GetSingleArgument<OperatorDef, int>(
const int op_mode_i = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "mode", static_cast<int>(NetMode::NORMAL));
const NetMode op_mode = static_cast<NetMode>(op_mode_i);
if (op_mode == mode) {
......
......@@ -35,28 +35,18 @@ class OperatorBase {
explicit OperatorBase(const OperatorDef &operator_def, Workspace *ws);
virtual ~OperatorBase() noexcept {}
inline bool HasArgument(const std::string &name) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
}
template <typename T>
inline T GetSingleArgument(const std::string &name,
const T &default_value) const {
inline T GetOptionalArg(const std::string &name,
const T &default_value) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
return ProtoArgHelper::GetOptionalArg<OperatorDef, T>(
*operator_def_, name, default_value);
}
template <typename T>
inline bool HasSingleArgumentOfType(const std::string &name) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name);
}
template <typename T>
inline std::vector<T> GetRepeatedArgument(
inline std::vector<T> GetRepeatedArgs(
const std::string &name, const std::vector<T> &default_value = {}) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
return ProtoArgHelper::GetRepeatedArgs<OperatorDef, T>(
*operator_def_, name, default_value);
}
......@@ -93,7 +83,7 @@ class OperatorBase {
std::vector<const Tensor *> inputs_;
std::vector<Tensor *> outputs_;
DISABLE_COPY_AND_ASSIGN(OperatorBase);
MACE_DISABLE_COPY_AND_ASSIGN(OperatorBase);
};
template <DeviceType D, class T>
......@@ -188,7 +178,7 @@ class OperatorRegistry {
private:
RegistryType registry_;
DISABLE_COPY_AND_ASSIGN(OperatorRegistry);
MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistry);
};
MACE_DECLARE_REGISTRY(OpRegistry,
......
......@@ -51,7 +51,7 @@ class Registry {
std::map<SrcType, Creator> registry_;
std::mutex register_mutex_;
DISABLE_COPY_AND_ASSIGN(Registry);
MACE_DISABLE_COPY_AND_ASSIGN(Registry);
};
template <class SrcType, class ObjectType, class... Args>
......
......@@ -61,7 +61,7 @@ class HexagonControlWrapper {
uint32_t num_inputs_;
uint32_t num_outputs_;
DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper);
MACE_DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper);
};
} // namespace mace
......
......@@ -47,7 +47,7 @@ class Quantizer {
float *stepsize,
float *recip_stepsize);
DISABLE_COPY_AND_ASSIGN(Quantizer);
MACE_DISABLE_COPY_AND_ASSIGN(Quantizer);
};
} // namespace mace
......
......@@ -348,7 +348,7 @@ class Tensor {
const Tensor *tensor_;
std::vector<size_t> mapped_image_pitch_;
DISABLE_COPY_AND_ASSIGN(MappingGuard);
MACE_DISABLE_COPY_AND_ASSIGN(MappingGuard);
};
private:
......@@ -361,7 +361,7 @@ class Tensor {
bool is_buffer_owner_;
std::string name_;
DISABLE_COPY_AND_ASSIGN(Tensor);
MACE_DISABLE_COPY_AND_ASSIGN(Tensor);
};
} // namespace mace
......
......@@ -136,11 +136,11 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
for (auto &op : net_def.op()) {
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT)));
if (op_dtype != DataType::DT_INVALID) {
dtype = op_dtype;
......@@ -182,7 +182,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
for (auto &op : net_def.op()) {
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) {
auto mem_ids = op.mem_id();
......
......@@ -65,7 +65,7 @@ class Workspace {
std::unique_ptr<ScratchBuffer> host_scratch_buffer_;
DISABLE_COPY_AND_ASSIGN(Workspace);
MACE_DISABLE_COPY_AND_ASSIGN(Workspace);
};
} // namespace mace
......
......@@ -29,9 +29,9 @@ class ActivationOp : public Operator<D, T> {
ActivationOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
static_cast<T>(OperatorBase::GetSingleArgument<float>(
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
static_cast<T>(OperatorBase::GetOptionalArg<float>(
"max_limit", 0.0f))) {}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -28,8 +28,8 @@ class BatchNormOp : public Operator<D, T> {
BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(false, kernels::ActivationType::NOOP, 0.0f) {
epsilon_ = OperatorBase::GetSingleArgument<float>("epsilon",
static_cast<float>(1e-4));
epsilon_ = OperatorBase::GetOptionalArg<float>("epsilon",
static_cast<float>(1e-4));
}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -29,8 +29,8 @@ class BatchToSpaceNDOp : public Operator<D, T> {
public:
BatchToSpaceNDOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
functor_(OperatorBase::GetRepeatedArgs<int>("crops", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgs<int>("block_shape", {1, 1}),
true) {}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -31,7 +31,7 @@ class BufferToImageOp : public Operator<D, T> {
const Tensor *input_tensor = this->Input(INPUT);
kernels::BufferType type =
static_cast<kernels::BufferType>(OperatorBase::GetSingleArgument<int>(
static_cast<kernels::BufferType>(OperatorBase::GetOptionalArg<int>(
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
Tensor *output = this->Output(OUTPUT);
......
......@@ -28,7 +28,7 @@ class ChannelShuffleOp : public Operator<D, T> {
public:
ChannelShuffleOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
group_(OperatorBase::GetSingleArgument<int>("group", 1)),
group_(OperatorBase::GetOptionalArg<int>("group", 1)),
functor_(this->group_) {}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -28,13 +28,13 @@ class ConcatOp : public Operator<D, T> {
public:
ConcatOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
MaceStatus Run(StatsFuture *future) override {
MACE_CHECK(this->InputSize() >= 2)
<< "There must be at least two inputs to concat";
const std::vector<const Tensor *> input_list = this->Inputs();
const int32_t concat_axis = OperatorBase::GetSingleArgument<int>("axis", 3);
const int32_t concat_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
const int32_t input_dims = input_list[0]->dim_size();
const int32_t axis =
concat_axis < 0 ? concat_axis + input_dims : concat_axis;
......
......@@ -35,10 +35,10 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_,
this->dilations_.data(),
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
static_cast<bool>(OperatorBase::GetSingleArgument<int>(
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f),
static_cast<bool>(OperatorBase::GetOptionalArg<int>(
"is_filter_transformed", false)),
ws->GetScratchBuffer(D)) {}
......
......@@ -28,12 +28,12 @@ class ConvPool2dOpBase : public Operator<D, T> {
public:
ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
strides_(OperatorBase::GetRepeatedArgument<int>("strides")),
padding_type_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
strides_(OperatorBase::GetRepeatedArgs<int>("strides")),
padding_type_(static_cast<Padding>(OperatorBase::GetOptionalArg<int>(
"padding", static_cast<int>(SAME)))),
paddings_(OperatorBase::GetRepeatedArgument<int>("padding_values")),
paddings_(OperatorBase::GetRepeatedArgs<int>("padding_values")),
dilations_(
OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {}
OperatorBase::GetRepeatedArgs<int>("dilations", {1, 1})) {}
protected:
std::vector<int> strides_;
......
......@@ -32,7 +32,7 @@ class Deconv2dOp : public ConvPool2dOpBase<D, T> {
functor_(this->strides_.data(),
this->padding_type_,
this->paddings_,
OperatorBase::GetRepeatedArgument<index_t>("output_shape"),
OperatorBase::GetRepeatedArgs<index_t>("output_shape"),
kernels::ActivationType::NOOP,
0.0f) {}
......
......@@ -29,7 +29,7 @@ class DepthToSpaceOp : public Operator<D, T> {
public:
DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
block_size_(OperatorBase::GetSingleArgument<int>("block_size", 1)),
block_size_(OperatorBase::GetOptionalArg<int>("block_size", 1)),
functor_(this->block_size_, true) {}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -36,9 +36,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_,
this->dilations_.data(),
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -27,10 +27,10 @@ class EltwiseOp : public Operator<D, T> {
EltwiseOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(static_cast<kernels::EltwiseType>(
OperatorBase::GetSingleArgument<int>(
OperatorBase::GetOptionalArg<int>(
"type", static_cast<int>(kernels::EltwiseType::NONE))),
OperatorBase::GetRepeatedArgument<float>("coeff"),
OperatorBase::GetSingleArgument<float>("x", 1.0)) {}
OperatorBase::GetRepeatedArgs<float>("coeff"),
OperatorBase::GetOptionalArg<float>("x", 1.0)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor* input0 = this->Input(0);
......
......@@ -30,9 +30,9 @@ class FoldedBatchNormOp : public Operator<D, T> {
: Operator<D, T>(operator_def, ws),
functor_(true,
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -29,9 +29,9 @@ class FullyConnectedOp : public Operator<D, T> {
FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -32,7 +32,7 @@ class ImageToBufferOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
kernels::BufferType type =
static_cast<kernels::BufferType>(OperatorBase::GetSingleArgument<int>(
static_cast<kernels::BufferType>(OperatorBase::GetOptionalArg<int>(
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
return functor_(input, type, output, future);
}
......
......@@ -27,10 +27,10 @@ class LocalResponseNormOp : public Operator<D, T> {
LocalResponseNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_() {
depth_radius_ = OperatorBase::GetSingleArgument<int>("depth_radius", 5);
bias_ = OperatorBase::GetSingleArgument<float>("bias", 1.0f);
alpha_ = OperatorBase::GetSingleArgument<float>("alpha", 1.0f);
beta_ = OperatorBase::GetSingleArgument<float>("beta", 0.5f);
depth_radius_ = OperatorBase::GetOptionalArg<int>("depth_radius", 5);
bias_ = OperatorBase::GetOptionalArg<float>("bias", 1.0f);
alpha_ = OperatorBase::GetOptionalArg<float>("alpha", 1.0f);
beta_ = OperatorBase::GetOptionalArg<float>("beta", 0.5f);
}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -28,8 +28,8 @@ class PadOp : public Operator<D, T> {
public:
PadOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetRepeatedArgument<int>("paddings"),
OperatorBase::GetSingleArgument<float>("constant_value", 0.0))
functor_(OperatorBase::GetRepeatedArgs<int>("paddings"),
OperatorBase::GetOptionalArg<float>("constant_value", 0.0))
{}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -29,9 +29,9 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
public:
PoolingOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws),
kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")),
kernels_(OperatorBase::GetRepeatedArgs<int>("kernels")),
pooling_type_(
static_cast<PoolingType>(OperatorBase::GetSingleArgument<int>(
static_cast<PoolingType>(OperatorBase::GetOptionalArg<int>(
"pooling_type", static_cast<int>(AVG)))),
functor_(pooling_type_,
kernels_.data(),
......
......@@ -26,14 +26,14 @@ class ProposalOp : public Operator<D, T> {
public:
ProposalOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("min_size", 16),
OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7),
OperatorBase::GetSingleArgument<int>("pre_nms_top_n", 6000),
OperatorBase::GetSingleArgument<int>("post_nms_top_n", 300),
OperatorBase::GetSingleArgument<int>("feat_stride", 0),
OperatorBase::GetSingleArgument<int>("base_size", 12),
OperatorBase::GetRepeatedArgument<int>("scales"),
OperatorBase::GetRepeatedArgument<float>("ratios")) {}
functor_(OperatorBase::GetOptionalArg<int>("min_size", 16),
OperatorBase::GetOptionalArg<float>("nms_thresh", 0.7),
OperatorBase::GetOptionalArg<int>("pre_nms_top_n", 6000),
OperatorBase::GetOptionalArg<int>("post_nms_top_n", 300),
OperatorBase::GetOptionalArg<int>("feat_stride", 0),
OperatorBase::GetOptionalArg<int>("base_size", 12),
OperatorBase::GetRepeatedArgs<int>("scales"),
OperatorBase::GetRepeatedArgs<float>("ratios")) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *rpn_cls_prob = this->Input(RPN_CLS_PROB);
......
......@@ -26,9 +26,9 @@ class PSROIAlignOp : public Operator<D, T> {
public:
PSROIAlignOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<T>("spatial_scale", 0),
OperatorBase::GetSingleArgument<int>("output_dim", 0),
OperatorBase::GetSingleArgument<int>("group_size", 0)) {}
functor_(OperatorBase::GetOptionalArg<T>("spatial_scale", 0),
OperatorBase::GetOptionalArg<int>("output_dim", 0),
OperatorBase::GetOptionalArg<int>("group_size", 0)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -28,7 +28,7 @@ class ReshapeOp : public Operator<D, T> {
public:
ReshapeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {}
shape_(OperatorBase::GetRepeatedArgs<int64_t>("shape")) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -27,8 +27,8 @@ class ResizeBilinearOp : public Operator<D, T> {
ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(
OperatorBase::GetRepeatedArgument<index_t>("size", {-1, -1}),
OperatorBase::GetSingleArgument<bool>("align_corners", false)) {}
OperatorBase::GetRepeatedArgs<index_t>("size", {-1, -1}),
OperatorBase::GetOptionalArg<bool>("align_corners", false)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(0);
......
......@@ -28,14 +28,14 @@ class SliceOp : public Operator<D, T> {
public:
SliceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
MaceStatus Run(StatsFuture *future) override {
MACE_CHECK(this->OutputSize() >= 2)
<< "There must be at least two outputs for slicing";
const Tensor *input = this->Input(INPUT);
const std::vector<Tensor *> output_list = this->Outputs();
const int32_t slice_axis = OperatorBase::GetSingleArgument<int>("axis", 3);
const int32_t slice_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0)
<< "Outputs do not split input equally.";
......
......@@ -30,8 +30,8 @@ class SpaceToBatchNDOp : public Operator<D, T> {
SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(
OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
OperatorBase::GetRepeatedArgs<int>("paddings", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgs<int>("block_shape", {1, 1}),
false) {}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -29,7 +29,7 @@ class SpaceToDepthOp : public Operator<D, T> {
public:
SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), false) {
functor_(OperatorBase::GetOptionalArg<int>("block_size", 1), false) {
}
MaceStatus Run(StatsFuture *future) override {
......@@ -37,7 +37,7 @@ class SpaceToDepthOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
const int block_size =
OperatorBase::GetSingleArgument<int>("block_size", 1);
OperatorBase::GetOptionalArg<int>("block_size", 1);
index_t input_height;
index_t input_width;
index_t input_depth;
......
......@@ -28,7 +28,7 @@ class TransposeOp : public Operator<D, T> {
public:
TransposeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
dims_(OperatorBase::GetRepeatedArgument<int>("dims")),
dims_(OperatorBase::GetRepeatedArgs<int>("dims")),
functor_(dims_) {}
MaceStatus Run(StatsFuture *future) override {
......
......@@ -30,13 +30,13 @@ class WinogradInverseTransformOp : public Operator<D, T> {
public:
WinogradInverseTransformOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("batch", 1),
OperatorBase::GetSingleArgument<int>("height", 0),
OperatorBase::GetSingleArgument<int>("width", 0),
functor_(OperatorBase::GetOptionalArg<int>("batch", 1),
OperatorBase::GetOptionalArg<int>("height", 0),
OperatorBase::GetOptionalArg<int>("width", 0),
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")),
OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
......
......@@ -28,9 +28,9 @@ class WinogradTransformOp : public Operator<D, T> {
public:
WinogradTransformOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>(
functor_(static_cast<Padding>(OperatorBase::GetOptionalArg<int>(
"padding", static_cast<int>(VALID))),
OperatorBase::GetRepeatedArgument<int>("padding_values")) {}
OperatorBase::GetRepeatedArgs<int>("padding_values")) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
......
......@@ -4,8 +4,11 @@ package mace;
option optimize_for = LITE_RUNTIME;
// For better compatibility,
// mace.proto is adapted from the tensorflow and caffe2 protos.
enum NetMode {
INIT = 0;
INIT = 0;
NORMAL = 1;
}
......@@ -64,7 +67,7 @@ message OperatorDef {
optional uint32 op_id = 101;
optional uint32 padding = 102;
repeated NodeInput node_input = 103;
repeated int32 out_max_byte_size = 104; // only support 32-bit len
repeated int32 out_max_byte_size = 104; // only support 32-bit len
}
// for memory optimization
......@@ -82,14 +85,14 @@ message InputInfo {
optional string name = 1;
optional int32 node_id = 2;
repeated int32 dims = 3;
optional int32 max_byte_size = 4; // only support 32-bit len
optional int32 max_byte_size = 4; // only support 32-bit len
optional DataType data_type = 5 [default = DT_FLOAT];
}
message OutputInfo {
optional string name = 1;
optional int32 node_id = 2;
repeated int32 dims = 3;
optional int32 max_byte_size = 4; // only support 32-bit len
optional int32 max_byte_size = 4; // only support 32-bit len
optional DataType data_type = 5 [default = DT_FLOAT];
}
......
......@@ -117,7 +117,7 @@ class LatencyLogger {
const std::string message_;
int64_t start_micros_;
DISABLE_COPY_AND_ASSIGN(LatencyLogger);
MACE_DISABLE_COPY_AND_ASSIGN(LatencyLogger);
};
#define MACE_LATENCY_LOGGER(vlog_level, ...) \
......
......@@ -58,7 +58,7 @@ class WallClockTimer : public Timer {
double stop_micros_;
double accumulated_micros_;
DISABLE_COPY_AND_ASSIGN(WallClockTimer);
MACE_DISABLE_COPY_AND_ASSIGN(WallClockTimer);
};
} // namespace mace
......
......@@ -24,11 +24,11 @@
namespace mace {
// Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \
private: \
classname(const classname &) = delete; \
classname &operator=(const classname &) = delete
// Deletes CLASSNAME's copy constructor and copy assignment operator.
// NOTE: intentionally leaves the class in a `private:` access section,
// so it should be the last entry in the class body.
#ifndef MACE_DISABLE_COPY_AND_ASSIGN
#define MACE_DISABLE_COPY_AND_ASSIGN(CLASSNAME) \
private: \
CLASSNAME(const CLASSNAME &) = delete; \
CLASSNAME &operator=(const CLASSNAME &) = delete
#endif
template <typename Integer>
......@@ -132,7 +132,7 @@ inline std::vector<std::string> Split(const std::string &str, char delims) {
}
inline bool ReadBinaryFile(std::vector<unsigned char> *data,
const std::string &filename) {
const std::string &filename) {
std::ifstream ifs(filename, std::ios::in | std::ios::binary);
if (!ifs.is_open()) {
return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册