diff --git a/mace/core/net.cc b/mace/core/net.cc index 93922891629108a74f3652e890e514ed4155e44a..40312f3d66fdd4b989cc9f3c5d7b42f7b8bb86bf 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -18,7 +18,7 @@ NetBase::NetBase(const std::shared_ptr &net_def, SimpleNet::SimpleNet(const std::shared_ptr &net_def, Workspace *ws, DeviceType type, - const OpMode mode) + const NetMode mode) : NetBase(net_def, ws, type), device_type_(type){ VLOG(1) << "Constructing SimpleNet " << net_def->name(); for (int idx = 0; idx < net_def->op_size(); ++idx) { @@ -93,7 +93,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { unique_ptr CreateNet(const NetDef &net_def, Workspace *ws, DeviceType type, - const OpMode mode) { + const NetMode mode) { std::shared_ptr tmp_net_def(new NetDef(net_def)); return CreateNet(tmp_net_def, ws, type, mode); } @@ -101,7 +101,7 @@ unique_ptr CreateNet(const NetDef &net_def, unique_ptr CreateNet(const std::shared_ptr &net_def, Workspace *ws, DeviceType type, - const OpMode mode) { + const NetMode mode) { unique_ptr net(new SimpleNet(net_def, ws, type, mode)); return net; } diff --git a/mace/core/net.h b/mace/core/net.h index 73aad694ef4149aad97f4d1f943e43ecf33d66e3..67c954f3e59c9bc1f8c8c46a6ce23858f94c1675 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -35,7 +35,7 @@ class SimpleNet : public NetBase { SimpleNet(const std::shared_ptr &net_def, Workspace *ws, DeviceType type, - const OpMode mode = OpMode::NORMAL); + const NetMode mode = NetMode::NORMAL); bool Run(RunMetadata *run_metadata = nullptr) override; @@ -49,11 +49,11 @@ class SimpleNet : public NetBase { unique_ptr CreateNet(const NetDef &net_def, Workspace *ws, DeviceType type, - const OpMode mode = OpMode::NORMAL); + const NetMode mode = NetMode::NORMAL); unique_ptr CreateNet(const std::shared_ptr &net_def, Workspace *ws, DeviceType type, - const OpMode mode = OpMode::NORMAL); + const NetMode mode = NetMode::NORMAL); } // namespace mace diff --git a/mace/core/operator.cc b/mace/core/operator.cc index b29efc59a484abe7f75d2ee7dac009156483c47a..2026105439b96b27bc90bb2b867f7914a66e31dc 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -50,15 +50,15 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry); unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, DeviceType type, - const OpMode mode) { + const NetMode mode) { OperatorRegistry *registry = gDeviceTypeRegistry()->at(type); const int dtype = ArgumentHelper::GetSingleArgument(operator_def, "T", static_cast(DT_FLOAT)); const int op_mode_i= ArgumentHelper::GetSingleArgument(operator_def, "mode", - static_cast(OpMode::NORMAL)); - const OpMode op_mode = static_cast(op_mode_i); + static_cast(NetMode::NORMAL)); + const NetMode op_mode = static_cast(op_mode_i); if (op_mode == mode) { return registry->Create(OpKeyBuilder(operator_def.type().data()) .TypeConstraint("T", static_cast(dtype)) diff --git a/mace/core/operator.h b/mace/core/operator.h index 11754f65dcb54876c4cce27155c5244bdb1c5da3..137a2e1a002be96673adb6cbbf5e30df13a09b01 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -196,7 +196,7 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry, unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, DeviceType type, - const OpMode mode); + const NetMode mode); } // namespace mace diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index 61dd160d14519d59510865971c190ba6ffd2160a..21eade9c8deba8ddd441b193e990e68793ac8f7a 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -101,7 +101,7 @@ int main(int argc, char **argv) { } // Init model - auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT); + auto net = CreateNet(net_def, &ws, device_type, NetMode::INIT); net->Run(); // run model diff --git a/mace/ops/core_test.cc b/mace/ops/core_test.cc index 27efc095fa02df6e54ce073f598f41cf9761500f..8761d04fba494c3d39cdd196dfd05c01c692f687 100644 --- a/mace/ops/core_test.cc +++ b/mace/ops/core_test.cc @@ -17,7 +17,7 @@ TEST(CoreTest, INIT_MODE) { .Input("Input") .Output("B2IOutput") .AddIntArg("buffer_type", kernels::BufferType::FILTER) - .AddIntArg("mode", static_cast(OpMode::INIT)) + .AddIntArg("mode", static_cast(NetMode::INIT)) .Finalize(&op_defs[op_defs.size()-1]); Tensor *input = @@ -40,7 +40,7 @@ TEST(CoreTest, INIT_MODE) { for (auto &op_def : op_defs) { net_def.add_op()->CopyFrom(op_def); } - auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT); + auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, NetMode::INIT); net->Run(); EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr); diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto index ab6be22b90cc563472dd3924ddc637a8909a3f2a..37a349433e75cc277186da49a5942ed4d78503bf 100644 --- a/mace/proto/mace.proto +++ b/mace/proto/mace.proto @@ -2,7 +2,7 @@ syntax = "proto2"; package mace; -enum OpMode { +enum NetMode { INIT = 0; NORMAL = 1; }