提交 062fa1e7 编写于 作者: L liuqi

Rename opmode to netmode.

上级 336bdbed
......@@ -18,7 +18,7 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type,
const OpMode mode)
const NetMode mode)
: NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) {
......@@ -93,7 +93,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws,
DeviceType type,
const OpMode mode) {
const NetMode mode) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws, type, mode);
}
......@@ -101,7 +101,7 @@ unique_ptr<NetBase> CreateNet(const NetDef &net_def,
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type,
const OpMode mode) {
const NetMode mode) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type, mode));
return net;
}
......
......@@ -35,7 +35,7 @@ class SimpleNet : public NetBase {
SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type,
const OpMode mode = OpMode::NORMAL);
const NetMode mode = NetMode::NORMAL);
bool Run(RunMetadata *run_metadata = nullptr) override;
......@@ -49,11 +49,11 @@ class SimpleNet : public NetBase {
unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws,
DeviceType type,
const OpMode mode = OpMode::NORMAL);
const NetMode mode = NetMode::NORMAL);
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type,
const OpMode mode = OpMode::NORMAL);
const NetMode mode = NetMode::NORMAL);
} // namespace mace
......
......@@ -50,15 +50,15 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
DeviceType type,
const OpMode mode) {
const NetMode mode) {
OperatorRegistry *registry = gDeviceTypeRegistry()->at(type);
const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"T",
static_cast<int>(DT_FLOAT));
const int op_mode_i= ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"mode",
static_cast<int>(OpMode::NORMAL));
const OpMode op_mode = static_cast<OpMode>(op_mode_i);
static_cast<int>(NetMode::NORMAL));
const NetMode op_mode = static_cast<NetMode>(op_mode_i);
if (op_mode == mode) {
return registry->Create(OpKeyBuilder(operator_def.type().data())
.TypeConstraint("T", static_cast<DataType>(dtype))
......
......@@ -196,7 +196,7 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry,
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
DeviceType type,
const OpMode mode);
const NetMode mode);
} // namespace mace
......
......@@ -101,7 +101,7 @@ int main(int argc, char **argv) {
}
// Init model
auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT);
auto net = CreateNet(net_def, &ws, device_type, NetMode::INIT);
net->Run();
// run model
......
......@@ -17,7 +17,7 @@ TEST(CoreTest, INIT_MODE) {
.Input("Input")
.Output("B2IOutput")
.AddIntArg("buffer_type", kernels::BufferType::FILTER)
.AddIntArg("mode", static_cast<int>(OpMode::INIT))
.AddIntArg("mode", static_cast<int>(NetMode::INIT))
.Finalize(&op_defs[op_defs.size()-1]);
Tensor *input =
......@@ -40,7 +40,7 @@ TEST(CoreTest, INIT_MODE) {
for (auto &op_def : op_defs) {
net_def.add_op()->CopyFrom(op_def);
}
auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT);
auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, NetMode::INIT);
net->Run();
EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr);
......
......@@ -2,7 +2,7 @@ syntax = "proto2";
package mace;
enum OpMode {
enum NetMode {
INIT = 0;
NORMAL = 1;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册