From 336bdbed77fdc8806e2d373fa8aae6a8a5346ad8 Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 4 Dec 2017 19:56:31 +0800 Subject: [PATCH] Add OpMode to support some op to run once at the beginning. --- mace/core/net.cc | 15 ++++++----- mace/core/net.h | 9 ++++--- mace/core/operator.cc | 21 ++++++++++----- mace/core/operator.h | 3 ++- mace/examples/mace_run.cc | 5 +++- mace/ops/core_test.cc | 56 +++++++++++++++++++++++++++++++++++++++ mace/proto/mace.proto | 5 ++++ 7 files changed, 97 insertions(+), 17 deletions(-) create mode 100644 mace/ops/core_test.cc diff --git a/mace/core/net.cc b/mace/core/net.cc index c0c53699..93922891 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -17,7 +17,8 @@ NetBase::NetBase(const std::shared_ptr &net_def, SimpleNet::SimpleNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type) + DeviceType type, + const OpMode mode) : NetBase(net_def, ws, type), device_type_(type){ VLOG(1) << "Constructing SimpleNet " << net_def->name(); for (int idx = 0; idx < net_def->op_size(); ++idx) { @@ -26,7 +27,7 @@ SimpleNet::SimpleNet(const std::shared_ptr &net_def, << operator_def.type(); std::unique_ptr op{nullptr}; OperatorDef temp_def(operator_def); - op = CreateOperator(temp_def, ws, type); + op = CreateOperator(temp_def, ws, type, mode); if (op) { operators_.emplace_back(std::move(op)); } @@ -91,15 +92,17 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { unique_ptr CreateNet(const NetDef &net_def, Workspace *ws, - DeviceType type) { + DeviceType type, + const OpMode mode) { std::shared_ptr tmp_net_def(new NetDef(net_def)); - return CreateNet(tmp_net_def, ws, type); + return CreateNet(tmp_net_def, ws, type, mode); } unique_ptr CreateNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type) { - unique_ptr net(new SimpleNet(net_def, ws, type)); + DeviceType type, + const OpMode mode) { + unique_ptr net(new SimpleNet(net_def, ws, type, mode)); return net; } diff --git a/mace/core/net.h b/mace/core/net.h index 013ca715..73aad694 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -34,7 +34,8 @@ class SimpleNet : public NetBase { public: SimpleNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode = OpMode::NORMAL); bool Run(RunMetadata *run_metadata = nullptr) override; @@ -47,10 +48,12 @@ class SimpleNet : public NetBase { unique_ptr CreateNet(const NetDef &net_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode = OpMode::NORMAL); unique_ptr CreateNet(const std::shared_ptr &net_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode = OpMode::NORMAL); } // namespace mace diff --git a/mace/core/operator.cc b/mace/core/operator.cc index e2e8936b..b29efc59 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -49,16 +49,25 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry); unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, - DeviceType type) { + DeviceType type, + const OpMode mode) { OperatorRegistry *registry = gDeviceTypeRegistry()->at(type); const int dtype = ArgumentHelper::GetSingleArgument(operator_def, "T", static_cast(DT_FLOAT)); - return registry->Create(OpKeyBuilder(operator_def.type().data()) - .TypeConstraint("T", static_cast(dtype)) - .Build(), - operator_def, - ws); + const int op_mode_i= ArgumentHelper::GetSingleArgument(operator_def, + "mode", + static_cast(OpMode::NORMAL)); + const OpMode op_mode = static_cast(op_mode_i); + if (op_mode == mode) { + return registry->Create(OpKeyBuilder(operator_def.type().data()) + .TypeConstraint("T", static_cast(dtype)) + .Build(), + operator_def, + ws); + } else { + return nullptr; + } } OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws) diff --git a/mace/core/operator.h b/mace/core/operator.h index 6ee4a9c4..11754f65 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -195,7 +195,8 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry, unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, - DeviceType type); + DeviceType type, + const OpMode mode); } // namespace mace diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index c13d6a95..61dd160d 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -100,9 +100,12 @@ int main(int argc, char **argv) { in_file.close(); } + // Init model + auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT); + net->Run(); // run model - auto net = CreateNet(net_def, &ws, device_type); + net = CreateNet(net_def, &ws, device_type); VLOG(0) << "warm up"; // warm up diff --git a/mace/ops/core_test.cc b/mace/ops/core_test.cc new file mode 100644 index 00000000..27efc095 --- /dev/null +++ b/mace/ops/core_test.cc @@ -0,0 +1,56 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/ops_test_util.h" + +namespace mace { + +TEST(CoreTest, INIT_MODE) { + + std::vector op_defs; + + Workspace ws; + + op_defs.emplace_back(OperatorDef()); + OpDefBuilder("BufferToImage", "BufferToImageTest") + .Input("Input") + .Output("B2IOutput") + .AddIntArg("buffer_type", kernels::BufferType::FILTER) + .AddIntArg("mode", static_cast(OpMode::INIT)) + .Finalize(&op_defs[op_defs.size()-1]); + + Tensor *input = + ws.CreateTensor("Input", GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum::v()); + input->Resize({1, 3, 3, 3}); + { + Tensor::MappingGuard input_mapper(input); + float *input_data = input->mutable_data(); + std::fill(input_data, input_data + input->size(), 1); + } + + op_defs.emplace_back(OperatorDef()); + OpDefBuilder("ImageToBuffer", "ImageToBufferTest") + .Input("B2IOutput") + .Output("Output") + .AddIntArg("buffer_type", kernels::BufferType::FILTER) + .Finalize(&op_defs[op_defs.size()-1]); + + NetDef net_def; + for (auto &op_def : op_defs) { + net_def.add_op()->CopyFrom(op_def); + } + auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT); + net->Run(); + + EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr); + EXPECT_TRUE(ws.GetTensor("Output") == nullptr); + + net = CreateNet(net_def, &ws, DeviceType::OPENCL); + net->Run(); + EXPECT_TRUE(ws.GetTensor("Output") != nullptr); + + ExpectTensorNear(*ws.GetTensor("Input"), *ws.GetTensor("Output"), 1e-5); +} + +} // namespace mace diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto index 119e1fed..ab6be22b 100644 --- a/mace/proto/mace.proto +++ b/mace/proto/mace.proto @@ -2,6 +2,11 @@ syntax = "proto2"; package mace; +enum OpMode { + INIT = 0; + NORMAL = 1; +} + enum DeviceType { CPU = 0; // In default, we will use CPU. NEON = 1; -- GitLab