提交 5b35740b 编写于 作者: Y yejianwu

refactor arg_helper

上级 e8d613ef
...@@ -20,112 +20,80 @@ ...@@ -20,112 +20,80 @@
namespace mace { namespace mace {
ArgumentHelper::ArgumentHelper(const OperatorDef &def) { ProtoArgHelper::ProtoArgHelper(const OperatorDef &def) {
for (auto &arg : def.arg()) { for (auto &arg : def.arg()) {
if (arg_map_.find(arg.name()) != arg_map_.end()) { if (arg_map_.count(arg.name())) {
LOG(WARNING) << "Duplicated argument name found in operator def: " LOG(WARNING) << "Duplicated argument " << arg.name()
<< def.name() << " " << arg.name(); << " found in operator " << def.name();
} }
arg_map_[arg.name()] = arg; arg_map_[arg.name()] = arg;
} }
} }
ArgumentHelper::ArgumentHelper(const NetDef &netdef) { ProtoArgHelper::ProtoArgHelper(const NetDef &netdef) {
for (auto &arg : netdef.arg()) { for (auto &arg : netdef.arg()) {
MACE_CHECK(arg_map_.count(arg.name()) == 0, MACE_CHECK(arg_map_.count(arg.name()) == 0,
"Duplicated argument name found in net def."); "Duplicated argument found in net def.");
arg_map_[arg.name()] = arg; arg_map_[arg.name()] = arg;
} }
} }
bool ArgumentHelper::HasArgument(const std::string &name) const {
return arg_map_.count(name);
}
namespace { namespace {
// Helper function to verify that conversion between types won't loose any
// significant bit.
template <typename InputType, typename TargetType> template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType &value) { inline bool IsCastLossless(const InputType &value) {
return static_cast<InputType>(static_cast<TargetType>(value)) == value; return static_cast<InputType>(static_cast<TargetType>(value)) == value;
} }
} }
#define INSTANTIATE_GET_SINGLE_ARGUMENT(T, fieldname, \ #define MACE_GET_OPTIONAL_ARGUMENT_FUNC(T, fieldname, lossless_conversion) \
enforce_lossless_conversion) \
template <> \ template <> \
T ArgumentHelper::GetSingleArgument<T>(const std::string &name, \ T ProtoArgHelper::GetOptionalArg<T>(const std::string &arg_name, \
const T &default_value) const { \ const T &default_value) const { \
if (arg_map_.count(name) == 0) { \ if (arg_map_.count(arg_name) == 0) { \
VLOG(3) << "Using default parameter value " << default_value \ VLOG(3) << "Using default parameter " << default_value << " for " \
<< " for parameter " << name; \ << arg_name; \
return default_value; \ return default_value; \
} \ } \
MACE_CHECK(arg_map_.at(name).has_##fieldname(), "Argument ", name, \ MACE_CHECK(arg_map_.at(arg_name).has_##fieldname(), "Argument ", arg_name, \
" does not have the right field: expected field " #fieldname); \ " not found!"); \
auto value = arg_map_.at(name).fieldname(); \ auto value = arg_map_.at(arg_name).fieldname(); \
if (enforce_lossless_conversion) { \ if (lossless_conversion) { \
auto supportsConversion = \ const bool castLossless = IsCastLossless<decltype(value), T>(value); \
SupportsLosslessConversion<decltype(value), T>(value); \ MACE_CHECK(castLossless, "Value", value, " of argument ", arg_name, \
MACE_CHECK(supportsConversion, "Value", value, " of argument ", name, \ "cannot be casted losslessly to a target type"); \
"cannot be represented correctly in a target type"); \
} \ } \
return value; \ return value; \
} \
template <> \
bool ArgumentHelper::HasSingleArgumentOfType<T>( \
const std::string &name) const { \
if (arg_map_.count(name) == 0) { \
return false; \
} \
return arg_map_.at(name).has_##fieldname(); \
} }
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false) MACE_GET_OPTIONAL_ARGUMENT_FUNC(float, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false) MACE_GET_OPTIONAL_ARGUMENT_FUNC(bool, i, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false) MACE_GET_OPTIONAL_ARGUMENT_FUNC(int, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true) MACE_GET_OPTIONAL_ARGUMENT_FUNC(std::string, s, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true) #undef MACE_GET_OPTIONAL_ARGUMENT_FUNC
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(std::string, s, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT
#define INSTANTIATE_GET_REPEATED_ARGUMENT(T, fieldname, \ #define MACE_GET_REPEATED_ARGUMENT_FUNC(T, fieldname, lossless_conversion) \
enforce_lossless_conversion) \
template <> \ template <> \
std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \ std::vector<T> ProtoArgHelper::GetRepeatedArgs<T>( \
const std::string &name, const std::vector<T> &default_value) const { \ const std::string &arg_name, const std::vector<T> &default_value) \
if (arg_map_.count(name) == 0) { \ const { \
if (arg_map_.count(arg_name) == 0) { \
return default_value; \ return default_value; \
} \ } \
std::vector<T> values; \ std::vector<T> values; \
for (const auto &v : arg_map_.at(name).fieldname()) { \ for (const auto &v : arg_map_.at(arg_name).fieldname()) { \
if (enforce_lossless_conversion) { \ if (lossless_conversion) { \
auto supportsConversion = \ const bool castLossless = IsCastLossless<decltype(v), T>(v); \
SupportsLosslessConversion<decltype(v), T>(v); \ MACE_CHECK(castLossless, "Value", v, " of argument ", arg_name, \
MACE_CHECK(supportsConversion, "Value", v, " of argument ", name, \ "cannot be casted losslessly to a target type"); \
"cannot be represented correctly in a target type"); \
} \ } \
values.push_back(v); \ values.push_back(v); \
} \ } \
return values; \ return values; \
} }
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false) MACE_GET_REPEATED_ARGUMENT_FUNC(float, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false) MACE_GET_REPEATED_ARGUMENT_FUNC(int, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false) MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true) #undef MACE_GET_REPEATED_ARGUMENT_FUNC
INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
} // namespace mace } // namespace mace
...@@ -15,61 +15,41 @@ ...@@ -15,61 +15,41 @@
#ifndef MACE_CORE_ARG_HELPER_H_ #ifndef MACE_CORE_ARG_HELPER_H_
#define MACE_CORE_ARG_HELPER_H_ #define MACE_CORE_ARG_HELPER_H_
#include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <map>
#include "mace/proto/mace.pb.h" #include "mace/proto/mace.pb.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
namespace mace { namespace mace {
/** // Refer to caffe2
* @brief A helper class to index into arguments. class ProtoArgHelper {
*
* This helper helps us to more easily index into a set of arguments
* that are present in the operator. To save memory, the argument helper
* does not copy the operator def, so one would need to make sure that the
* lifetime of the OperatorDef object outlives that of the ArgumentHelper.
*/
class ArgumentHelper {
public: public:
template <typename Def>
static bool HasArgument(const Def &def, const std::string &name) {
return ArgumentHelper(def).HasArgument(name);
}
template <typename Def, typename T> template <typename Def, typename T>
static T GetSingleArgument(const Def &def, static T GetOptionalArg(const Def &def,
const std::string &name, const std::string &arg_name,
const T &default_value) { const T &default_value) {
return ArgumentHelper(def).GetSingleArgument<T>(name, default_value); return ProtoArgHelper(def).GetOptionalArg<T>(arg_name, default_value);
} }
template <typename Def, typename T> template <typename Def, typename T>
static bool HasSingleArgumentOfType(const Def &def, const std::string &name) { static std::vector<T> GetRepeatedArgs(
return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
}
template <typename Def, typename T>
static std::vector<T> GetRepeatedArgument(
const Def &def, const Def &def,
const std::string &name, const std::string &arg_name,
const std::vector<T> &default_value = std::vector<T>()) { const std::vector<T> &default_value = std::vector<T>()) {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value); return ProtoArgHelper(def).GetRepeatedArgs<T>(arg_name, default_value);
} }
explicit ArgumentHelper(const OperatorDef &def); explicit ProtoArgHelper(const OperatorDef &def);
explicit ArgumentHelper(const NetDef &netdef); explicit ProtoArgHelper(const NetDef &netdef);
bool HasArgument(const std::string &name) const;
template <typename T> template <typename T>
T GetSingleArgument(const std::string &name, const T &default_value) const; T GetOptionalArg(const std::string &arg_name, const T &default_value) const;
template <typename T>
bool HasSingleArgumentOfType(const std::string &name) const;
template <typename T> template <typename T>
std::vector<T> GetRepeatedArgument( std::vector<T> GetRepeatedArgs(
const std::string &name, const std::string &arg_name,
const std::vector<T> &default_value = std::vector<T>()) const; const std::vector<T> &default_value = std::vector<T>()) const;
private: private:
......
...@@ -213,7 +213,7 @@ class Buffer : public BufferBase { ...@@ -213,7 +213,7 @@ class Buffer : public BufferBase {
void *mapped_buf_; void *mapped_buf_;
bool is_data_owner_; bool is_data_owner_;
DISABLE_COPY_AND_ASSIGN(Buffer); MACE_DISABLE_COPY_AND_ASSIGN(Buffer);
}; };
class Image : public BufferBase { class Image : public BufferBase {
...@@ -330,7 +330,7 @@ class Image : public BufferBase { ...@@ -330,7 +330,7 @@ class Image : public BufferBase {
void *buf_; void *buf_;
void *mapped_buf_; void *mapped_buf_;
DISABLE_COPY_AND_ASSIGN(Image); MACE_DISABLE_COPY_AND_ASSIGN(Image);
}; };
class BufferSlice : public BufferBase { class BufferSlice : public BufferBase {
......
...@@ -110,7 +110,7 @@ class MaceEngine::Impl { ...@@ -110,7 +110,7 @@ class MaceEngine::Impl {
std::unique_ptr<HexagonControlWrapper> hexagon_controller_; std::unique_ptr<HexagonControlWrapper> hexagon_controller_;
#endif #endif
DISABLE_COPY_AND_ASSIGN(Impl); MACE_DISABLE_COPY_AND_ASSIGN(Impl);
}; };
MaceEngine::Impl::Impl(DeviceType device_type) MaceEngine::Impl::Impl(DeviceType device_type)
...@@ -146,7 +146,7 @@ MaceStatus MaceEngine::Impl::Init( ...@@ -146,7 +146,7 @@ MaceStatus MaceEngine::Impl::Init(
hexagon_controller_->SetDebugLevel( hexagon_controller_->SetDebugLevel(
static_cast<int>(mace::logging::LogMessage::MinVLogLevel())); static_cast<int>(mace::logging::LogMessage::MinVLogLevel()));
int dsp_mode = int dsp_mode =
ArgumentHelper::GetSingleArgument<NetDef, int>(*net_def, "dsp_mode", 0); ProtoArgHelper::GetOptionalArg<NetDef, int>(*net_def, "dsp_mode", 0);
hexagon_controller_->SetGraphMode(dsp_mode); hexagon_controller_->SetGraphMode(dsp_mode);
MACE_CHECK(hexagon_controller_->SetupGraph(*net_def, model_data), MACE_CHECK(hexagon_controller_->SetupGraph(*net_def, model_data),
"hexagon setup graph error"); "hexagon setup graph error");
......
...@@ -42,7 +42,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry, ...@@ -42,7 +42,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
const auto &operator_def = net_def->op(idx); const auto &operator_def = net_def->op(idx);
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
const int op_device = const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "device", static_cast<int>(device_type_)); operator_def, "device", static_cast<int>(device_type_));
if (op_device == type) { if (op_device == type) {
VLOG(3) << "Creating operator " << operator_def.name() << "(" VLOG(3) << "Creating operator " << operator_def.name() << "("
...@@ -97,12 +97,12 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) { ...@@ -97,12 +97,12 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
type.compare("FusedConv2D") == 0 || type.compare("FusedConv2D") == 0 ||
type.compare("DepthwiseConv2d") == 0 || type.compare("DepthwiseConv2d") == 0 ||
type.compare("Pooling") == 0) { type.compare("Pooling") == 0) {
strides = op->GetRepeatedArgument<int>("strides"); strides = op->GetRepeatedArgs<int>("strides");
padding_type = op->GetSingleArgument<int>("padding", -1); padding_type = op->GetOptionalArg<int>("padding", -1);
paddings = op->GetRepeatedArgument<int>("padding_values"); paddings = op->GetRepeatedArgs<int>("padding_values");
dilations = op->GetRepeatedArgument<int>("dilations"); dilations = op->GetRepeatedArgs<int>("dilations");
if (type.compare("Pooling") == 0) { if (type.compare("Pooling") == 0) {
kernels = op->GetRepeatedArgument<index_t>("kernels"); kernels = op->GetRepeatedArgs<index_t>("kernels");
} else { } else {
kernels = op->Input(1)->shape(); kernels = op->Input(1)->shape();
} }
......
...@@ -44,7 +44,7 @@ class NetBase { ...@@ -44,7 +44,7 @@ class NetBase {
std::string name_; std::string name_;
const std::shared_ptr<const OperatorRegistry> op_registry_; const std::shared_ptr<const OperatorRegistry> op_registry_;
DISABLE_COPY_AND_ASSIGN(NetBase); MACE_DISABLE_COPY_AND_ASSIGN(NetBase);
}; };
class SerialNet : public NetBase { class SerialNet : public NetBase {
...@@ -61,7 +61,7 @@ class SerialNet : public NetBase { ...@@ -61,7 +61,7 @@ class SerialNet : public NetBase {
std::vector<std::unique_ptr<OperatorBase> > operators_; std::vector<std::unique_ptr<OperatorBase> > operators_;
DeviceType device_type_; DeviceType device_type_;
DISABLE_COPY_AND_ASSIGN(SerialNet); MACE_DISABLE_COPY_AND_ASSIGN(SerialNet);
}; };
std::unique_ptr<NetBase> CreateNet( std::unique_ptr<NetBase> CreateNet(
......
...@@ -55,9 +55,9 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator( ...@@ -55,9 +55,9 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const NetMode mode) const { const NetMode mode) const {
const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>( const int dtype = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "T", static_cast<int>(DT_FLOAT)); operator_def, "T", static_cast<int>(DT_FLOAT));
const int op_mode_i = ArgumentHelper::GetSingleArgument<OperatorDef, int>( const int op_mode_i = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
operator_def, "mode", static_cast<int>(NetMode::NORMAL)); operator_def, "mode", static_cast<int>(NetMode::NORMAL));
const NetMode op_mode = static_cast<NetMode>(op_mode_i); const NetMode op_mode = static_cast<NetMode>(op_mode_i);
if (op_mode == mode) { if (op_mode == mode) {
......
...@@ -35,28 +35,18 @@ class OperatorBase { ...@@ -35,28 +35,18 @@ class OperatorBase {
explicit OperatorBase(const OperatorDef &operator_def, Workspace *ws); explicit OperatorBase(const OperatorDef &operator_def, Workspace *ws);
virtual ~OperatorBase() noexcept {} virtual ~OperatorBase() noexcept {}
inline bool HasArgument(const std::string &name) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
}
template <typename T> template <typename T>
inline T GetSingleArgument(const std::string &name, inline T GetOptionalArg(const std::string &name,
const T &default_value) const { const T &default_value) const {
MACE_CHECK(operator_def_, "operator_def was null!"); MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>( return ProtoArgHelper::GetOptionalArg<OperatorDef, T>(
*operator_def_, name, default_value); *operator_def_, name, default_value);
} }
template <typename T> template <typename T>
inline bool HasSingleArgumentOfType(const std::string &name) const { inline std::vector<T> GetRepeatedArgs(
MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
*operator_def_, name);
}
template <typename T>
inline std::vector<T> GetRepeatedArgument(
const std::string &name, const std::vector<T> &default_value = {}) const { const std::string &name, const std::vector<T> &default_value = {}) const {
MACE_CHECK(operator_def_, "operator_def was null!"); MACE_CHECK(operator_def_, "operator_def was null!");
return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>( return ProtoArgHelper::GetRepeatedArgs<OperatorDef, T>(
*operator_def_, name, default_value); *operator_def_, name, default_value);
} }
...@@ -93,7 +83,7 @@ class OperatorBase { ...@@ -93,7 +83,7 @@ class OperatorBase {
std::vector<const Tensor *> inputs_; std::vector<const Tensor *> inputs_;
std::vector<Tensor *> outputs_; std::vector<Tensor *> outputs_;
DISABLE_COPY_AND_ASSIGN(OperatorBase); MACE_DISABLE_COPY_AND_ASSIGN(OperatorBase);
}; };
template <DeviceType D, class T> template <DeviceType D, class T>
...@@ -188,7 +178,7 @@ class OperatorRegistry { ...@@ -188,7 +178,7 @@ class OperatorRegistry {
private: private:
RegistryType registry_; RegistryType registry_;
DISABLE_COPY_AND_ASSIGN(OperatorRegistry); MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistry);
}; };
MACE_DECLARE_REGISTRY(OpRegistry, MACE_DECLARE_REGISTRY(OpRegistry,
......
...@@ -51,7 +51,7 @@ class Registry { ...@@ -51,7 +51,7 @@ class Registry {
std::map<SrcType, Creator> registry_; std::map<SrcType, Creator> registry_;
std::mutex register_mutex_; std::mutex register_mutex_;
DISABLE_COPY_AND_ASSIGN(Registry); MACE_DISABLE_COPY_AND_ASSIGN(Registry);
}; };
template <class SrcType, class ObjectType, class... Args> template <class SrcType, class ObjectType, class... Args>
......
...@@ -61,7 +61,7 @@ class HexagonControlWrapper { ...@@ -61,7 +61,7 @@ class HexagonControlWrapper {
uint32_t num_inputs_; uint32_t num_inputs_;
uint32_t num_outputs_; uint32_t num_outputs_;
DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper); MACE_DISABLE_COPY_AND_ASSIGN(HexagonControlWrapper);
}; };
} // namespace mace } // namespace mace
......
...@@ -47,7 +47,7 @@ class Quantizer { ...@@ -47,7 +47,7 @@ class Quantizer {
float *stepsize, float *stepsize,
float *recip_stepsize); float *recip_stepsize);
DISABLE_COPY_AND_ASSIGN(Quantizer); MACE_DISABLE_COPY_AND_ASSIGN(Quantizer);
}; };
} // namespace mace } // namespace mace
......
...@@ -348,7 +348,7 @@ class Tensor { ...@@ -348,7 +348,7 @@ class Tensor {
const Tensor *tensor_; const Tensor *tensor_;
std::vector<size_t> mapped_image_pitch_; std::vector<size_t> mapped_image_pitch_;
DISABLE_COPY_AND_ASSIGN(MappingGuard); MACE_DISABLE_COPY_AND_ASSIGN(MappingGuard);
}; };
private: private:
...@@ -361,7 +361,7 @@ class Tensor { ...@@ -361,7 +361,7 @@ class Tensor {
bool is_buffer_owner_; bool is_buffer_owner_;
std::string name_; std::string name_;
DISABLE_COPY_AND_ASSIGN(Tensor); MACE_DISABLE_COPY_AND_ASSIGN(Tensor);
}; };
} // namespace mace } // namespace mace
......
...@@ -136,11 +136,11 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -136,11 +136,11 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
for (auto &op : net_def.op()) { for (auto &op : net_def.op()) {
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
const int op_device = const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type)); op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) { if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>( const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT))); op, "T", static_cast<int>(DT_FLOAT)));
if (op_dtype != DataType::DT_INVALID) { if (op_dtype != DataType::DT_INVALID) {
dtype = op_dtype; dtype = op_dtype;
...@@ -182,7 +182,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -182,7 +182,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
for (auto &op : net_def.op()) { for (auto &op : net_def.op()) {
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
const int op_device = const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type)); op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) { if (op_device == device_type && !op.mem_id().empty()) {
auto mem_ids = op.mem_id(); auto mem_ids = op.mem_id();
......
...@@ -65,7 +65,7 @@ class Workspace { ...@@ -65,7 +65,7 @@ class Workspace {
std::unique_ptr<ScratchBuffer> host_scratch_buffer_; std::unique_ptr<ScratchBuffer> host_scratch_buffer_;
DISABLE_COPY_AND_ASSIGN(Workspace); MACE_DISABLE_COPY_AND_ASSIGN(Workspace);
}; };
} // namespace mace } // namespace mace
......
...@@ -29,9 +29,9 @@ class ActivationOp : public Operator<D, T> { ...@@ -29,9 +29,9 @@ class ActivationOp : public Operator<D, T> {
ActivationOp(const OperatorDef &operator_def, Workspace *ws) ActivationOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(kernels::StringToActivationType( functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
static_cast<T>(OperatorBase::GetSingleArgument<float>( static_cast<T>(OperatorBase::GetOptionalArg<float>(
"max_limit", 0.0f))) {} "max_limit", 0.0f))) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -28,7 +28,7 @@ class BatchNormOp : public Operator<D, T> { ...@@ -28,7 +28,7 @@ class BatchNormOp : public Operator<D, T> {
BatchNormOp(const OperatorDef &operator_def, Workspace *ws) BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(false, kernels::ActivationType::NOOP, 0.0f) { functor_(false, kernels::ActivationType::NOOP, 0.0f) {
epsilon_ = OperatorBase::GetSingleArgument<float>("epsilon", epsilon_ = OperatorBase::GetOptionalArg<float>("epsilon",
static_cast<float>(1e-4)); static_cast<float>(1e-4));
} }
......
...@@ -29,8 +29,8 @@ class BatchToSpaceNDOp : public Operator<D, T> { ...@@ -29,8 +29,8 @@ class BatchToSpaceNDOp : public Operator<D, T> {
public: public:
BatchToSpaceNDOp(const OperatorDef &op_def, Workspace *ws) BatchToSpaceNDOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0}), functor_(OperatorBase::GetRepeatedArgs<int>("crops", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}), OperatorBase::GetRepeatedArgs<int>("block_shape", {1, 1}),
true) {} true) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -31,7 +31,7 @@ class BufferToImageOp : public Operator<D, T> { ...@@ -31,7 +31,7 @@ class BufferToImageOp : public Operator<D, T> {
const Tensor *input_tensor = this->Input(INPUT); const Tensor *input_tensor = this->Input(INPUT);
kernels::BufferType type = kernels::BufferType type =
static_cast<kernels::BufferType>(OperatorBase::GetSingleArgument<int>( static_cast<kernels::BufferType>(OperatorBase::GetOptionalArg<int>(
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER))); "buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
......
...@@ -28,7 +28,7 @@ class ChannelShuffleOp : public Operator<D, T> { ...@@ -28,7 +28,7 @@ class ChannelShuffleOp : public Operator<D, T> {
public: public:
ChannelShuffleOp(const OperatorDef &operator_def, Workspace *ws) ChannelShuffleOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
group_(OperatorBase::GetSingleArgument<int>("group", 1)), group_(OperatorBase::GetOptionalArg<int>("group", 1)),
functor_(this->group_) {} functor_(this->group_) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -28,13 +28,13 @@ class ConcatOp : public Operator<D, T> { ...@@ -28,13 +28,13 @@ class ConcatOp : public Operator<D, T> {
public: public:
ConcatOp(const OperatorDef &op_def, Workspace *ws) ConcatOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {} functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
MACE_CHECK(this->InputSize() >= 2) MACE_CHECK(this->InputSize() >= 2)
<< "There must be at least two inputs to concat"; << "There must be at least two inputs to concat";
const std::vector<const Tensor *> input_list = this->Inputs(); const std::vector<const Tensor *> input_list = this->Inputs();
const int32_t concat_axis = OperatorBase::GetSingleArgument<int>("axis", 3); const int32_t concat_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
const int32_t input_dims = input_list[0]->dim_size(); const int32_t input_dims = input_list[0]->dim_size();
const int32_t axis = const int32_t axis =
concat_axis < 0 ? concat_axis + input_dims : concat_axis; concat_axis < 0 ? concat_axis + input_dims : concat_axis;
......
...@@ -35,10 +35,10 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -35,10 +35,10 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_, this->paddings_,
this->dilations_.data(), this->dilations_.data(),
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetOptionalArg<float>("max_limit", 0.0f),
static_cast<bool>(OperatorBase::GetSingleArgument<int>( static_cast<bool>(OperatorBase::GetOptionalArg<int>(
"is_filter_transformed", false)), "is_filter_transformed", false)),
ws->GetScratchBuffer(D)) {} ws->GetScratchBuffer(D)) {}
......
...@@ -28,12 +28,12 @@ class ConvPool2dOpBase : public Operator<D, T> { ...@@ -28,12 +28,12 @@ class ConvPool2dOpBase : public Operator<D, T> {
public: public:
ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws) ConvPool2dOpBase(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
strides_(OperatorBase::GetRepeatedArgument<int>("strides")), strides_(OperatorBase::GetRepeatedArgs<int>("strides")),
padding_type_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>( padding_type_(static_cast<Padding>(OperatorBase::GetOptionalArg<int>(
"padding", static_cast<int>(SAME)))), "padding", static_cast<int>(SAME)))),
paddings_(OperatorBase::GetRepeatedArgument<int>("padding_values")), paddings_(OperatorBase::GetRepeatedArgs<int>("padding_values")),
dilations_( dilations_(
OperatorBase::GetRepeatedArgument<int>("dilations", {1, 1})) {} OperatorBase::GetRepeatedArgs<int>("dilations", {1, 1})) {}
protected: protected:
std::vector<int> strides_; std::vector<int> strides_;
......
...@@ -32,7 +32,7 @@ class Deconv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -32,7 +32,7 @@ class Deconv2dOp : public ConvPool2dOpBase<D, T> {
functor_(this->strides_.data(), functor_(this->strides_.data(),
this->padding_type_, this->padding_type_,
this->paddings_, this->paddings_,
OperatorBase::GetRepeatedArgument<index_t>("output_shape"), OperatorBase::GetRepeatedArgs<index_t>("output_shape"),
kernels::ActivationType::NOOP, kernels::ActivationType::NOOP,
0.0f) {} 0.0f) {}
......
...@@ -29,7 +29,7 @@ class DepthToSpaceOp : public Operator<D, T> { ...@@ -29,7 +29,7 @@ class DepthToSpaceOp : public Operator<D, T> {
public: public:
DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
block_size_(OperatorBase::GetSingleArgument<int>("block_size", 1)), block_size_(OperatorBase::GetOptionalArg<int>("block_size", 1)),
functor_(this->block_size_, true) {} functor_(this->block_size_, true) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -36,9 +36,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -36,9 +36,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_, this->paddings_,
this->dilations_.data(), this->dilations_.data(),
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {} OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -27,10 +27,10 @@ class EltwiseOp : public Operator<D, T> { ...@@ -27,10 +27,10 @@ class EltwiseOp : public Operator<D, T> {
EltwiseOp(const OperatorDef &op_def, Workspace *ws) EltwiseOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(static_cast<kernels::EltwiseType>( functor_(static_cast<kernels::EltwiseType>(
OperatorBase::GetSingleArgument<int>( OperatorBase::GetOptionalArg<int>(
"type", static_cast<int>(kernels::EltwiseType::NONE))), "type", static_cast<int>(kernels::EltwiseType::NONE))),
OperatorBase::GetRepeatedArgument<float>("coeff"), OperatorBase::GetRepeatedArgs<float>("coeff"),
OperatorBase::GetSingleArgument<float>("x", 1.0)) {} OperatorBase::GetOptionalArg<float>("x", 1.0)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor* input0 = this->Input(0); const Tensor* input0 = this->Input(0);
......
...@@ -30,9 +30,9 @@ class FoldedBatchNormOp : public Operator<D, T> { ...@@ -30,9 +30,9 @@ class FoldedBatchNormOp : public Operator<D, T> {
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(true, functor_(true,
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {} OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -29,9 +29,9 @@ class FullyConnectedOp : public Operator<D, T> { ...@@ -29,9 +29,9 @@ class FullyConnectedOp : public Operator<D, T> {
FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws) FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(kernels::StringToActivationType( functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {} OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -32,7 +32,7 @@ class ImageToBufferOp : public Operator<D, T> { ...@@ -32,7 +32,7 @@ class ImageToBufferOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
kernels::BufferType type = kernels::BufferType type =
static_cast<kernels::BufferType>(OperatorBase::GetSingleArgument<int>( static_cast<kernels::BufferType>(OperatorBase::GetOptionalArg<int>(
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER))); "buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
return functor_(input, type, output, future); return functor_(input, type, output, future);
} }
......
...@@ -27,10 +27,10 @@ class LocalResponseNormOp : public Operator<D, T> { ...@@ -27,10 +27,10 @@ class LocalResponseNormOp : public Operator<D, T> {
LocalResponseNormOp(const OperatorDef &operator_def, Workspace *ws) LocalResponseNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_() { functor_() {
depth_radius_ = OperatorBase::GetSingleArgument<int>("depth_radius", 5); depth_radius_ = OperatorBase::GetOptionalArg<int>("depth_radius", 5);
bias_ = OperatorBase::GetSingleArgument<float>("bias", 1.0f); bias_ = OperatorBase::GetOptionalArg<float>("bias", 1.0f);
alpha_ = OperatorBase::GetSingleArgument<float>("alpha", 1.0f); alpha_ = OperatorBase::GetOptionalArg<float>("alpha", 1.0f);
beta_ = OperatorBase::GetSingleArgument<float>("beta", 0.5f); beta_ = OperatorBase::GetOptionalArg<float>("beta", 0.5f);
} }
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -28,8 +28,8 @@ class PadOp : public Operator<D, T> { ...@@ -28,8 +28,8 @@ class PadOp : public Operator<D, T> {
public: public:
PadOp(const OperatorDef &operator_def, Workspace *ws) PadOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetRepeatedArgument<int>("paddings"), functor_(OperatorBase::GetRepeatedArgs<int>("paddings"),
OperatorBase::GetSingleArgument<float>("constant_value", 0.0)) OperatorBase::GetOptionalArg<float>("constant_value", 0.0))
{} {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -29,9 +29,9 @@ class PoolingOp : public ConvPool2dOpBase<D, T> { ...@@ -29,9 +29,9 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
public: public:
PoolingOp(const OperatorDef &op_def, Workspace *ws) PoolingOp(const OperatorDef &op_def, Workspace *ws)
: ConvPool2dOpBase<D, T>(op_def, ws), : ConvPool2dOpBase<D, T>(op_def, ws),
kernels_(OperatorBase::GetRepeatedArgument<int>("kernels")), kernels_(OperatorBase::GetRepeatedArgs<int>("kernels")),
pooling_type_( pooling_type_(
static_cast<PoolingType>(OperatorBase::GetSingleArgument<int>( static_cast<PoolingType>(OperatorBase::GetOptionalArg<int>(
"pooling_type", static_cast<int>(AVG)))), "pooling_type", static_cast<int>(AVG)))),
functor_(pooling_type_, functor_(pooling_type_,
kernels_.data(), kernels_.data(),
......
...@@ -26,14 +26,14 @@ class ProposalOp : public Operator<D, T> { ...@@ -26,14 +26,14 @@ class ProposalOp : public Operator<D, T> {
public: public:
ProposalOp(const OperatorDef &operator_def, Workspace *ws) ProposalOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("min_size", 16), functor_(OperatorBase::GetOptionalArg<int>("min_size", 16),
OperatorBase::GetSingleArgument<float>("nms_thresh", 0.7), OperatorBase::GetOptionalArg<float>("nms_thresh", 0.7),
OperatorBase::GetSingleArgument<int>("pre_nms_top_n", 6000), OperatorBase::GetOptionalArg<int>("pre_nms_top_n", 6000),
OperatorBase::GetSingleArgument<int>("post_nms_top_n", 300), OperatorBase::GetOptionalArg<int>("post_nms_top_n", 300),
OperatorBase::GetSingleArgument<int>("feat_stride", 0), OperatorBase::GetOptionalArg<int>("feat_stride", 0),
OperatorBase::GetSingleArgument<int>("base_size", 12), OperatorBase::GetOptionalArg<int>("base_size", 12),
OperatorBase::GetRepeatedArgument<int>("scales"), OperatorBase::GetRepeatedArgs<int>("scales"),
OperatorBase::GetRepeatedArgument<float>("ratios")) {} OperatorBase::GetRepeatedArgs<float>("ratios")) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *rpn_cls_prob = this->Input(RPN_CLS_PROB); const Tensor *rpn_cls_prob = this->Input(RPN_CLS_PROB);
......
...@@ -26,9 +26,9 @@ class PSROIAlignOp : public Operator<D, T> { ...@@ -26,9 +26,9 @@ class PSROIAlignOp : public Operator<D, T> {
public: public:
PSROIAlignOp(const OperatorDef &operator_def, Workspace *ws) PSROIAlignOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<T>("spatial_scale", 0), functor_(OperatorBase::GetOptionalArg<T>("spatial_scale", 0),
OperatorBase::GetSingleArgument<int>("output_dim", 0), OperatorBase::GetOptionalArg<int>("output_dim", 0),
OperatorBase::GetSingleArgument<int>("group_size", 0)) {} OperatorBase::GetOptionalArg<int>("group_size", 0)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -28,7 +28,7 @@ class ReshapeOp : public Operator<D, T> { ...@@ -28,7 +28,7 @@ class ReshapeOp : public Operator<D, T> {
public: public:
ReshapeOp(const OperatorDef &op_def, Workspace *ws) ReshapeOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {} shape_(OperatorBase::GetRepeatedArgs<int64_t>("shape")) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -27,8 +27,8 @@ class ResizeBilinearOp : public Operator<D, T> { ...@@ -27,8 +27,8 @@ class ResizeBilinearOp : public Operator<D, T> {
ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws) ResizeBilinearOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_( functor_(
OperatorBase::GetRepeatedArgument<index_t>("size", {-1, -1}), OperatorBase::GetRepeatedArgs<index_t>("size", {-1, -1}),
OperatorBase::GetSingleArgument<bool>("align_corners", false)) {} OperatorBase::GetOptionalArg<bool>("align_corners", false)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(0); const Tensor *input = this->Input(0);
......
...@@ -28,14 +28,14 @@ class SliceOp : public Operator<D, T> { ...@@ -28,14 +28,14 @@ class SliceOp : public Operator<D, T> {
public: public:
SliceOp(const OperatorDef &op_def, Workspace *ws) SliceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {} functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
MACE_CHECK(this->OutputSize() >= 2) MACE_CHECK(this->OutputSize() >= 2)
<< "There must be at least two outputs for slicing"; << "There must be at least two outputs for slicing";
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
const std::vector<Tensor *> output_list = this->Outputs(); const std::vector<Tensor *> output_list = this->Outputs();
const int32_t slice_axis = OperatorBase::GetSingleArgument<int>("axis", 3); const int32_t slice_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0) MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0)
<< "Outputs do not split input equally."; << "Outputs do not split input equally.";
......
...@@ -30,8 +30,8 @@ class SpaceToBatchNDOp : public Operator<D, T> { ...@@ -30,8 +30,8 @@ class SpaceToBatchNDOp : public Operator<D, T> {
SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws) SpaceToBatchNDOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_( functor_(
OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0}), OperatorBase::GetRepeatedArgs<int>("paddings", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}), OperatorBase::GetRepeatedArgs<int>("block_shape", {1, 1}),
false) {} false) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -29,7 +29,7 @@ class SpaceToDepthOp : public Operator<D, T> { ...@@ -29,7 +29,7 @@ class SpaceToDepthOp : public Operator<D, T> {
public: public:
SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws) SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), false) { functor_(OperatorBase::GetOptionalArg<int>("block_size", 1), false) {
} }
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
...@@ -37,7 +37,7 @@ class SpaceToDepthOp : public Operator<D, T> { ...@@ -37,7 +37,7 @@ class SpaceToDepthOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
const int block_size = const int block_size =
OperatorBase::GetSingleArgument<int>("block_size", 1); OperatorBase::GetOptionalArg<int>("block_size", 1);
index_t input_height; index_t input_height;
index_t input_width; index_t input_width;
index_t input_depth; index_t input_depth;
......
...@@ -28,7 +28,7 @@ class TransposeOp : public Operator<D, T> { ...@@ -28,7 +28,7 @@ class TransposeOp : public Operator<D, T> {
public: public:
TransposeOp(const OperatorDef &operator_def, Workspace *ws) TransposeOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
dims_(OperatorBase::GetRepeatedArgument<int>("dims")), dims_(OperatorBase::GetRepeatedArgs<int>("dims")),
functor_(dims_) {} functor_(dims_) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
......
...@@ -30,13 +30,13 @@ class WinogradInverseTransformOp : public Operator<D, T> { ...@@ -30,13 +30,13 @@ class WinogradInverseTransformOp : public Operator<D, T> {
public: public:
WinogradInverseTransformOp(const OperatorDef &op_def, Workspace *ws) WinogradInverseTransformOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("batch", 1), functor_(OperatorBase::GetOptionalArg<int>("batch", 1),
OperatorBase::GetSingleArgument<int>("height", 0), OperatorBase::GetOptionalArg<int>("height", 0),
OperatorBase::GetSingleArgument<int>("width", 0), OperatorBase::GetOptionalArg<int>("width", 0),
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {} OperatorBase::GetOptionalArg<float>("max_limit", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT); const Tensor *input_tensor = this->Input(INPUT);
......
...@@ -28,9 +28,9 @@ class WinogradTransformOp : public Operator<D, T> { ...@@ -28,9 +28,9 @@ class WinogradTransformOp : public Operator<D, T> {
public: public:
WinogradTransformOp(const OperatorDef &op_def, Workspace *ws) WinogradTransformOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(static_cast<Padding>(OperatorBase::GetSingleArgument<int>( functor_(static_cast<Padding>(OperatorBase::GetOptionalArg<int>(
"padding", static_cast<int>(VALID))), "padding", static_cast<int>(VALID))),
OperatorBase::GetRepeatedArgument<int>("padding_values")) {} OperatorBase::GetRepeatedArgs<int>("padding_values")) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT); const Tensor *input_tensor = this->Input(INPUT);
......
...@@ -4,6 +4,9 @@ package mace; ...@@ -4,6 +4,9 @@ package mace;
option optimize_for = LITE_RUNTIME; option optimize_for = LITE_RUNTIME;
// For better compatibility,
// the mace.proto is refered from tensorflow and caffe2.
enum NetMode { enum NetMode {
INIT = 0; INIT = 0;
NORMAL = 1; NORMAL = 1;
......
...@@ -117,7 +117,7 @@ class LatencyLogger { ...@@ -117,7 +117,7 @@ class LatencyLogger {
const std::string message_; const std::string message_;
int64_t start_micros_; int64_t start_micros_;
DISABLE_COPY_AND_ASSIGN(LatencyLogger); MACE_DISABLE_COPY_AND_ASSIGN(LatencyLogger);
}; };
#define MACE_LATENCY_LOGGER(vlog_level, ...) \ #define MACE_LATENCY_LOGGER(vlog_level, ...) \
......
...@@ -58,7 +58,7 @@ class WallClockTimer : public Timer { ...@@ -58,7 +58,7 @@ class WallClockTimer : public Timer {
double stop_micros_; double stop_micros_;
double accumulated_micros_; double accumulated_micros_;
DISABLE_COPY_AND_ASSIGN(WallClockTimer); MACE_DISABLE_COPY_AND_ASSIGN(WallClockTimer);
}; };
} // namespace mace } // namespace mace
......
...@@ -24,11 +24,11 @@ ...@@ -24,11 +24,11 @@
namespace mace { namespace mace {
// Disable the copy and assignment operator for a class. // Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN #ifndef MACE_DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \ #define MACE_DISABLE_COPY_AND_ASSIGN(CLASSNAME) \
private: \ private: \
classname(const classname &) = delete; \ CLASSNAME(const CLASSNAME &) = delete; \
classname &operator=(const classname &) = delete CLASSNAME &operator=(const CLASSNAME &) = delete
#endif #endif
template <typename Integer> template <typename Integer>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册