提交 336bdbed 编写于 作者: L liuqi

Add OpMode to support some op to run once at the beginning.

上级 5c1264b3
...@@ -17,7 +17,8 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def, ...@@ -17,7 +17,8 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def, SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type) DeviceType type,
const OpMode mode)
: NetBase(net_def, ws, type), device_type_(type){ : NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name(); VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) { for (int idx = 0; idx < net_def->op_size(); ++idx) {
...@@ -26,7 +27,7 @@ SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def, ...@@ -26,7 +27,7 @@ SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
<< operator_def.type(); << operator_def.type();
std::unique_ptr<OperatorBase> op{nullptr}; std::unique_ptr<OperatorBase> op{nullptr};
OperatorDef temp_def(operator_def); OperatorDef temp_def(operator_def);
op = CreateOperator(temp_def, ws, type); op = CreateOperator(temp_def, ws, type, mode);
if (op) { if (op) {
operators_.emplace_back(std::move(op)); operators_.emplace_back(std::move(op));
} }
...@@ -91,15 +92,17 @@ bool SimpleNet::Run(RunMetadata *run_metadata) { ...@@ -91,15 +92,17 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
unique_ptr<NetBase> CreateNet(const NetDef &net_def, unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws, Workspace *ws,
DeviceType type) { DeviceType type,
const OpMode mode) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def)); std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws, type); return CreateNet(tmp_net_def, ws, type, mode);
} }
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def, unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type) { DeviceType type,
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type)); const OpMode mode) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type, mode));
return net; return net;
} }
......
...@@ -34,7 +34,8 @@ class SimpleNet : public NetBase { ...@@ -34,7 +34,8 @@ class SimpleNet : public NetBase {
public: public:
SimpleNet(const std::shared_ptr<const NetDef> &net_def, SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type); DeviceType type,
const OpMode mode = OpMode::NORMAL);
bool Run(RunMetadata *run_metadata = nullptr) override; bool Run(RunMetadata *run_metadata = nullptr) override;
...@@ -47,10 +48,12 @@ class SimpleNet : public NetBase { ...@@ -47,10 +48,12 @@ class SimpleNet : public NetBase {
unique_ptr<NetBase> CreateNet(const NetDef &net_def, unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws, Workspace *ws,
DeviceType type); DeviceType type,
const OpMode mode = OpMode::NORMAL);
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def, unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws, Workspace *ws,
DeviceType type); DeviceType type,
const OpMode mode = OpMode::NORMAL);
} // namespace mace } // namespace mace
......
...@@ -49,16 +49,25 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry); ...@@ -49,16 +49,25 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def, unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws, Workspace *ws,
DeviceType type) { DeviceType type,
const OpMode mode) {
OperatorRegistry *registry = gDeviceTypeRegistry()->at(type); OperatorRegistry *registry = gDeviceTypeRegistry()->at(type);
const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def, const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"T", "T",
static_cast<int>(DT_FLOAT)); static_cast<int>(DT_FLOAT));
return registry->Create(OpKeyBuilder(operator_def.type().data()) const int op_mode_i= ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
.TypeConstraint("T", static_cast<DataType>(dtype)) "mode",
.Build(), static_cast<int>(OpMode::NORMAL));
operator_def, const OpMode op_mode = static_cast<OpMode>(op_mode_i);
ws); if (op_mode == mode) {
return registry->Create(OpKeyBuilder(operator_def.type().data())
.TypeConstraint("T", static_cast<DataType>(dtype))
.Build(),
operator_def,
ws);
} else {
return nullptr;
}
} }
OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws) OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws)
......
...@@ -195,7 +195,8 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry, ...@@ -195,7 +195,8 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry,
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def, unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws, Workspace *ws,
DeviceType type); DeviceType type,
const OpMode mode);
} // namespace mace } // namespace mace
......
...@@ -100,9 +100,12 @@ int main(int argc, char **argv) { ...@@ -100,9 +100,12 @@ int main(int argc, char **argv) {
in_file.close(); in_file.close();
} }
// Init model
auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT);
net->Run();
// run model // run model
auto net = CreateNet(net_def, &ws, device_type); net = CreateNet(net_def, &ws, device_type);
VLOG(0) << "warm up"; VLOG(0) << "warm up";
// warm up // warm up
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/ops_test_util.h"
namespace mace {
TEST(CoreTest, INIT_MODE) {
std::vector<OperatorDef> op_defs;
Workspace ws;
op_defs.emplace_back(OperatorDef());
OpDefBuilder("BufferToImage", "BufferToImageTest")
.Input("Input")
.Output("B2IOutput")
.AddIntArg("buffer_type", kernels::BufferType::FILTER)
.AddIntArg("mode", static_cast<int>(OpMode::INIT))
.Finalize(&op_defs[op_defs.size()-1]);
Tensor *input =
ws.CreateTensor("Input", GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum<float>::v());
input->Resize({1, 3, 3, 3});
{
Tensor::MappingGuard input_mapper(input);
float *input_data = input->mutable_data<float>();
std::fill(input_data, input_data + input->size(), 1);
}
op_defs.emplace_back(OperatorDef());
OpDefBuilder("ImageToBuffer", "ImageToBufferTest")
.Input("B2IOutput")
.Output("Output")
.AddIntArg("buffer_type", kernels::BufferType::FILTER)
.Finalize(&op_defs[op_defs.size()-1]);
NetDef net_def;
for (auto &op_def : op_defs) {
net_def.add_op()->CopyFrom(op_def);
}
auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT);
net->Run();
EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr);
EXPECT_TRUE(ws.GetTensor("Output") == nullptr);
net = CreateNet(net_def, &ws, DeviceType::OPENCL);
net->Run();
EXPECT_TRUE(ws.GetTensor("Output") != nullptr);
ExpectTensorNear<float>(*ws.GetTensor("Input"), *ws.GetTensor("Output"), 1e-5);
}
} // namespace mace
...@@ -2,6 +2,11 @@ syntax = "proto2"; ...@@ -2,6 +2,11 @@ syntax = "proto2";
package mace; package mace;
enum OpMode {
INIT = 0;
NORMAL = 1;
}
enum DeviceType { enum DeviceType {
CPU = 0; // In default, we will use CPU. CPU = 0; // In default, we will use CPU.
NEON = 1; NEON = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册