diff --git a/mace/core/net.cc b/mace/core/net.cc index c0c536995bcd9d3cf634422b2099292a41aa186c..93922891629108a74f3652e890e514ed4155e44a 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -17,7 +17,8 @@ NetBase::NetBase(const std::shared_ptr &net_def, SimpleNet::SimpleNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type) + DeviceType type, + const OpMode mode) : NetBase(net_def, ws, type), device_type_(type){ VLOG(1) << "Constructing SimpleNet " << net_def->name(); for (int idx = 0; idx < net_def->op_size(); ++idx) { @@ -26,7 +27,7 @@ SimpleNet::SimpleNet(const std::shared_ptr &net_def, << operator_def.type(); std::unique_ptr op{nullptr}; OperatorDef temp_def(operator_def); - op = CreateOperator(temp_def, ws, type); + op = CreateOperator(temp_def, ws, type, mode); if (op) { operators_.emplace_back(std::move(op)); } @@ -91,15 +92,17 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { unique_ptr CreateNet(const NetDef &net_def, Workspace *ws, - DeviceType type) { + DeviceType type, + const OpMode mode) { std::shared_ptr tmp_net_def(new NetDef(net_def)); - return CreateNet(tmp_net_def, ws, type); + return CreateNet(tmp_net_def, ws, type, mode); } unique_ptr CreateNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type) { - unique_ptr net(new SimpleNet(net_def, ws, type)); + DeviceType type, + const OpMode mode) { + unique_ptr net(new SimpleNet(net_def, ws, type, mode)); return net; } diff --git a/mace/core/net.h b/mace/core/net.h index 013ca715cafce82242ad148d8ff12c7df8fd9fb4..73aad694ef4149aad97f4d1f943e43ecf33d66e3 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -34,7 +34,8 @@ class SimpleNet : public NetBase { public: SimpleNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode = OpMode::NORMAL); bool Run(RunMetadata *run_metadata = nullptr) override; @@ -47,10 +48,12 @@ class SimpleNet : public NetBase { unique_ptr CreateNet(const NetDef &net_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode = OpMode::NORMAL); unique_ptr CreateNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode = OpMode::NORMAL); } // namespace mace diff --git a/mace/core/operator.cc b/mace/core/operator.cc index e2e8936b62b46e164e1508ae08e2f998f8e12b32..b29efc59a484abe7f75d2ee7dac009156483c47a 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -49,16 +49,25 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry); unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, - DeviceType type) { + DeviceType type, + const OpMode mode) { OperatorRegistry *registry = gDeviceTypeRegistry()->at(type); const int dtype = ArgumentHelper::GetSingleArgument(operator_def, "T", static_cast(DT_FLOAT)); - return registry->Create(OpKeyBuilder(operator_def.type().data()) - .TypeConstraint("T", static_cast(dtype)) - .Build(), - operator_def, - ws); + const int op_mode_i= ArgumentHelper::GetSingleArgument(operator_def, + "mode", + static_cast(OpMode::NORMAL)); + const OpMode op_mode = static_cast(op_mode_i); + if (op_mode == mode) { + return registry->Create(OpKeyBuilder(operator_def.type().data()) + .TypeConstraint("T", static_cast(dtype)) + .Build(), + operator_def, + ws); + } else { + return nullptr; + } } OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws) diff --git a/mace/core/operator.h b/mace/core/operator.h index 6ee4a9c4d2c637fd7b60c070355c02e155db7a01..11754f65dcb54876c4cce27155c5244bdb1c5da3 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -195,7 +195,8 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry, unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode); } // namespace mace diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index c13d6a95529afda817e34c92aa43799c8e55a959..61dd160d14519d59510865971c190ba6ffd2160a 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -100,9 +100,12 @@ int main(int argc, char **argv) { in_file.close(); } + // Init model + auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT); + net->Run(); // run model - auto net = CreateNet(net_def, &ws, device_type); + net = CreateNet(net_def, &ws, device_type); VLOG(0) << "warm up"; // warm up diff --git a/mace/ops/core_test.cc b/mace/ops/core_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27efc095fa02df6e54ce073f598f41cf9761500f --- /dev/null +++ b/mace/ops/core_test.cc @@ -0,0 +1,56 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/ops_test_util.h" + +namespace mace { + +TEST(CoreTest, INIT_MODE) { + + std::vector op_defs; + + Workspace ws; + + op_defs.emplace_back(OperatorDef()); + OpDefBuilder("BufferToImage", "BufferToImageTest") + .Input("Input") + .Output("B2IOutput") + .AddIntArg("buffer_type", kernels::BufferType::FILTER) + .AddIntArg("mode", static_cast(OpMode::INIT)) + .Finalize(&op_defs[op_defs.size()-1]); + + Tensor *input = + ws.CreateTensor("Input", GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum::v()); + input->Resize({1, 3, 3, 3}); + { + Tensor::MappingGuard input_mapper(input); + float *input_data = input->mutable_data(); + std::fill(input_data, input_data + input->size(), 1); + } + + op_defs.emplace_back(OperatorDef()); + OpDefBuilder("ImageToBuffer", "ImageToBufferTest") + .Input("B2IOutput") + .Output("Output") + .AddIntArg("buffer_type", kernels::BufferType::FILTER) + .Finalize(&op_defs[op_defs.size()-1]); + + NetDef net_def; + for (auto &op_def : op_defs) { + net_def.add_op()->CopyFrom(op_def); + } + auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT); + net->Run(); + + EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr); + EXPECT_TRUE(ws.GetTensor("Output") == nullptr); + + net = CreateNet(net_def, &ws, DeviceType::OPENCL); + net->Run(); + EXPECT_TRUE(ws.GetTensor("Output") != nullptr); + + ExpectTensorNear(*ws.GetTensor("Input"), *ws.GetTensor("Output"), 1e-5); +} + +} // namespace mace diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto index 119e1fed79a7cad1374cdb3891745ec2c83716bb..ab6be22b90cc563472dd3924ddc637a8909a3f2a 100644 --- a/mace/proto/mace.proto +++ b/mace/proto/mace.proto @@ -2,6 +2,11 @@ syntax = "proto2"; package mace; +enum OpMode { + INIT = 0; + NORMAL = 1; +} + enum DeviceType { CPU = 0; // In default, we will use CPU. NEON = 1;