提交 062fa1e7 编写于 作者: L liuqi

Rename opmode to netmode.

上级 336bdbed
...@@ -18,7 +18,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def, ...@@ -18,7 +18,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def, SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode) const NetMode mode)
: NetBase(net_def, ws, type), device_type_(type){ : NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name(); VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) { for (int idx = 0; idx < net_def->op_size(); ++idx) {
...@@ -93,7 +93,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { ...@@ -93,7 +93,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
unique_ptr<NetBase> CreateNet(const NetDef &net_def, unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode) { const NetMode mode) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def)); std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws, type, mode); return CreateNet(tmp_net_def, ws, type, mode);
} }
...@@ -101,7 +101,7 @@ unique_ptr<NetBase> CreateNet(const NetDef &net_def, ...@@ -101,7 +101,7 @@ unique_ptr<NetBase> CreateNet(const NetDef &net_def,
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def, unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode) { const NetMode mode) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type, mode)); unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type, mode));
return net; return net;
} }
......
...@@ -35,7 +35,7 @@ class SimpleNet : public NetBase { ...@@ -35,7 +35,7 @@ class SimpleNet : public NetBase {
SimpleNet(const std::shared_ptr<const NetDef> &net_def, SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode = OpMode::NORMAL); const NetMode mode = NetMode::NORMAL);
bool Run(RunMetadata *run_metadata = nullptr) override; bool Run(RunMetadata *run_metadata = nullptr) override;
...@@ -49,11 +49,11 @@ class SimpleNet : public NetBase { ...@@ -49,11 +49,11 @@ class SimpleNet : public NetBase {
unique_ptr<NetBase> CreateNet(const NetDef &net_def, unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode = OpMode::NORMAL); const NetMode mode = NetMode::NORMAL);
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def, unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode = OpMode::NORMAL); const NetMode mode = NetMode::NORMAL);
} // namespace mace } // namespace mace
......
...@@ -50,15 +50,15 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry); ...@@ -50,15 +50,15 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def, unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode) { const NetMode mode) {
OperatorRegistry *registry = gDeviceTypeRegistry()->at(type); OperatorRegistry *registry = gDeviceTypeRegistry()->at(type);
const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def, const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"T", "T",
static_cast<int>(DT_FLOAT)); static_cast<int>(DT_FLOAT));
const int op_mode_i= ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def, const int op_mode_i= ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"mode", "mode",
static_cast<int>(OpMode::NORMAL)); static_cast<int>(NetMode::NORMAL));
const OpMode op_mode = static_cast<OpMode>(op_mode_i); const NetMode op_mode = static_cast<NetMode>(op_mode_i);
if (op_mode == mode) { if (op_mode == mode) {
return registry->Create(OpKeyBuilder(operator_def.type().data()) return registry->Create(OpKeyBuilder(operator_def.type().data())
.TypeConstraint("T", static_cast<DataType>(dtype)) .TypeConstraint("T", static_cast<DataType>(dtype))
......
...@@ -196,7 +196,7 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry, ...@@ -196,7 +196,7 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry,
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def, unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws, Workspace *ws,
DeviceType type, DeviceType type,
const OpMode mode); const NetMode mode);
} // namespace mace } // namespace mace
......
...@@ -101,7 +101,7 @@ int main(int argc, char **argv) { ...@@ -101,7 +101,7 @@ int main(int argc, char **argv) {
} }
// Init model // Init model
auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT); auto net = CreateNet(net_def, &ws, device_type, NetMode::INIT);
net->Run(); net->Run();
// run model // run model
......
...@@ -17,7 +17,7 @@ TEST(CoreTest, INIT_MODE) { ...@@ -17,7 +17,7 @@ TEST(CoreTest, INIT_MODE) {
.Input("Input") .Input("Input")
.Output("B2IOutput") .Output("B2IOutput")
.AddIntArg("buffer_type", kernels::BufferType::FILTER) .AddIntArg("buffer_type", kernels::BufferType::FILTER)
.AddIntArg("mode", static_cast<int>(OpMode::INIT)) .AddIntArg("mode", static_cast<int>(NetMode::INIT))
.Finalize(&op_defs[op_defs.size()-1]); .Finalize(&op_defs[op_defs.size()-1]);
Tensor *input = Tensor *input =
...@@ -40,7 +40,7 @@ TEST(CoreTest, INIT_MODE) { ...@@ -40,7 +40,7 @@ TEST(CoreTest, INIT_MODE) {
for (auto &op_def : op_defs) { for (auto &op_def : op_defs) {
net_def.add_op()->CopyFrom(op_def); net_def.add_op()->CopyFrom(op_def);
} }
auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT); auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, NetMode::INIT);
net->Run(); net->Run();
EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr); EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr);
......
...@@ -2,7 +2,7 @@ syntax = "proto2"; ...@@ -2,7 +2,7 @@ syntax = "proto2";
package mace; package mace;
enum OpMode { enum NetMode {
INIT = 0; INIT = 0;
NORMAL = 1; NORMAL = 1;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册