提交 336bdbed 编写于 作者: L liuqi

Add OpMode to support some op to run once at the beginning.

上级 5c1264b3
......@@ -17,7 +17,8 @@ NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type)
DeviceType type,
const OpMode mode)
: NetBase(net_def, ws, type), device_type_(type){
VLOG(1) << "Constructing SimpleNet " << net_def->name();
for (int idx = 0; idx < net_def->op_size(); ++idx) {
......@@ -26,7 +27,7 @@ SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
<< operator_def.type();
std::unique_ptr<OperatorBase> op{nullptr};
OperatorDef temp_def(operator_def);
op = CreateOperator(temp_def, ws, type);
op = CreateOperator(temp_def, ws, type, mode);
if (op) {
operators_.emplace_back(std::move(op));
}
......@@ -91,15 +92,17 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws,
DeviceType type) {
DeviceType type,
const OpMode mode) {
std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
return CreateNet(tmp_net_def, ws, type);
return CreateNet(tmp_net_def, ws, type, mode);
}
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
DeviceType type,
const OpMode mode) {
unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type, mode));
return net;
}
......
......@@ -34,7 +34,8 @@ class SimpleNet : public NetBase {
public:
SimpleNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type);
DeviceType type,
const OpMode mode = OpMode::NORMAL);
bool Run(RunMetadata *run_metadata = nullptr) override;
......@@ -47,10 +48,12 @@ class SimpleNet : public NetBase {
unique_ptr<NetBase> CreateNet(const NetDef &net_def,
Workspace *ws,
DeviceType type);
DeviceType type,
const OpMode mode = OpMode::NORMAL);
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
Workspace *ws,
DeviceType type);
DeviceType type,
const OpMode mode = OpMode::NORMAL);
} // namespace mace
......
......@@ -49,16 +49,25 @@ MACE_REGISTER_DEVICE_TYPE(DeviceType::OPENCL, OPENCLOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
DeviceType type) {
DeviceType type,
const OpMode mode) {
OperatorRegistry *registry = gDeviceTypeRegistry()->at(type);
const int dtype = ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"T",
static_cast<int>(DT_FLOAT));
return registry->Create(OpKeyBuilder(operator_def.type().data())
.TypeConstraint("T", static_cast<DataType>(dtype))
.Build(),
operator_def,
ws);
const int op_mode_i= ArgumentHelper::GetSingleArgument<OperatorDef, int>(operator_def,
"mode",
static_cast<int>(OpMode::NORMAL));
const OpMode op_mode = static_cast<OpMode>(op_mode_i);
if (op_mode == mode) {
return registry->Create(OpKeyBuilder(operator_def.type().data())
.TypeConstraint("T", static_cast<DataType>(dtype))
.Build(),
operator_def,
ws);
} else {
return nullptr;
}
}
OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws)
......
......@@ -195,7 +195,8 @@ MACE_DECLARE_REGISTRY(OPENCLOperatorRegistry,
unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
DeviceType type);
DeviceType type,
const OpMode mode);
} // namespace mace
......
......@@ -100,9 +100,12 @@ int main(int argc, char **argv) {
in_file.close();
}
// Init model
auto net = CreateNet(net_def, &ws, device_type, OpMode::INIT);
net->Run();
// run model
auto net = CreateNet(net_def, &ws, device_type);
net = CreateNet(net_def, &ws, device_type);
VLOG(0) << "warm up";
// warm up
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/ops_test_util.h"
namespace mace {
TEST(CoreTest, INIT_MODE) {
std::vector<OperatorDef> op_defs;
Workspace ws;
op_defs.emplace_back(OperatorDef());
OpDefBuilder("BufferToImage", "BufferToImageTest")
.Input("Input")
.Output("B2IOutput")
.AddIntArg("buffer_type", kernels::BufferType::FILTER)
.AddIntArg("mode", static_cast<int>(OpMode::INIT))
.Finalize(&op_defs[op_defs.size()-1]);
Tensor *input =
ws.CreateTensor("Input", GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum<float>::v());
input->Resize({1, 3, 3, 3});
{
Tensor::MappingGuard input_mapper(input);
float *input_data = input->mutable_data<float>();
std::fill(input_data, input_data + input->size(), 1);
}
op_defs.emplace_back(OperatorDef());
OpDefBuilder("ImageToBuffer", "ImageToBufferTest")
.Input("B2IOutput")
.Output("Output")
.AddIntArg("buffer_type", kernels::BufferType::FILTER)
.Finalize(&op_defs[op_defs.size()-1]);
NetDef net_def;
for (auto &op_def : op_defs) {
net_def.add_op()->CopyFrom(op_def);
}
auto net = CreateNet(net_def, &ws, DeviceType::OPENCL, OpMode::INIT);
net->Run();
EXPECT_TRUE(ws.GetTensor("B2IOutput") != nullptr);
EXPECT_TRUE(ws.GetTensor("Output") == nullptr);
net = CreateNet(net_def, &ws, DeviceType::OPENCL);
net->Run();
EXPECT_TRUE(ws.GetTensor("Output") != nullptr);
ExpectTensorNear<float>(*ws.GetTensor("Input"), *ws.GetTensor("Output"), 1e-5);
}
} // namespace mace
......@@ -2,6 +2,11 @@ syntax = "proto2";
package mace;
enum OpMode {
INIT = 0;
NORMAL = 1;
}
enum DeviceType {
CPU = 0; // In default, we will use CPU.
NEON = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册