提交 1cb7cde5 编写于 作者: Y yejianwu

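Update the MaceEngine::Run API in mace.h: Run now copies the result into a caller-provided output buffer and returns a bool success flag instead of returning a pointer to the output data.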

上级 ee725558
...@@ -496,9 +496,11 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type): ...@@ -496,9 +496,11 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type):
net_ = std::move(CreateNet(*net_def, ws_.get(), device_type)); net_ = std::move(CreateNet(*net_def, ws_.get(), device_type));
} }
MaceEngine::~MaceEngine(){} MaceEngine::~MaceEngine(){}
const float *MaceEngine::Run(const float *input, bool MaceEngine::Run(const float *input,
const std::vector<index_t> &input_shape, const std::vector<index_t> &input_shape,
std::vector<int64_t> &output_shape) { float *output) {
MACE_CHECK(output != nullptr, "output ptr cannot be NULL");
Tensor *input_tensor = Tensor *input_tensor =
ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT); ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT);
input_tensor->Resize(input_shape); input_tensor->Resize(input_shape);
...@@ -511,16 +513,18 @@ const float *MaceEngine::Run(const float *input, ...@@ -511,16 +513,18 @@ const float *MaceEngine::Run(const float *input,
LOG(FATAL) << "Net run failed"; LOG(FATAL) << "Net run failed";
} }
// save output // save output
const Tensor *output = ws_->GetTensor("mace_output_node:0"); const Tensor *output_tensor = ws_->GetTensor("mace_output_node:0");
if (output != nullptr) { if (output_tensor != nullptr) {
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output_tensor);
auto shape = output->shape(); auto shape = output_tensor->shape();
output_shape.resize(shape.size()); int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
std::copy(shape.begin(), shape.end(), output_shape.begin()); std::multiplies<int64_t>());
return output->data<float>(); std::memcpy(output, output_tensor->data<float>(),
output_size * sizeof(float));
return true;
} else { } else {
return nullptr; return false;
} }
} }
......
...@@ -307,9 +307,9 @@ class MaceEngine { ...@@ -307,9 +307,9 @@ class MaceEngine {
public: public:
explicit MaceEngine(const NetDef *net_def, DeviceType device_type); explicit MaceEngine(const NetDef *net_def, DeviceType device_type);
~MaceEngine(); ~MaceEngine();
const float *Run(const float *input, bool Run(const float *input,
const std::vector<int64_t> &input_shape, const std::vector<int64_t> &input_shape,
std::vector<int64_t> &output_shape); float *output);
private: private:
DeviceType device_type_; DeviceType device_type_;
std::unique_ptr<Workspace> ws_; std::unique_ptr<Workspace> ws_;
......
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
* mace_run --model=mobi_mace.pb \ * mace_run --model=mobi_mace.pb \
* --input=input_node \ * --input=input_node \
* --output=MobilenetV1/Logits/conv2d/convolution \ * --output=MobilenetV1/Logits/conv2d/convolution \
* --input_shape=1,3,224,224 \ * --input_shape=1,224,224,3 \
* --output_shape=1,224,224,2 \
* --input_file=input_data \ * --input_file=input_data \
* --output_file=mace.out \ * --output_file=mace.out \
* --device=NEON * --device=NEON
...@@ -60,6 +61,7 @@ int main(int argc, char **argv) { ...@@ -60,6 +61,7 @@ int main(int argc, char **argv) {
string input_node; string input_node;
string output_node; string output_node;
string input_shape; string input_shape;
string output_shape;
string input_file; string input_file;
string output_file; string output_file;
string device; string device;
...@@ -70,6 +72,7 @@ int main(int argc, char **argv) { ...@@ -70,6 +72,7 @@ int main(int argc, char **argv) {
Flag("input", &input_node, "input node"), Flag("input", &input_node, "input node"),
Flag("output", &output_node, "output node"), Flag("output", &output_node, "output node"),
Flag("input_shape", &input_shape, "input shape, separated by comma"), Flag("input_shape", &input_shape, "input shape, separated by comma"),
Flag("output_shape", &output_shape, "output shape, separated by comma"),
Flag("input_file", &input_file, "input file name"), Flag("input_file", &input_file, "input file name"),
Flag("output_file", &output_file, "output file name"), Flag("output_file", &output_file, "output file name"),
Flag("device", &device, "CPU/NEON"), Flag("device", &device, "CPU/NEON"),
...@@ -90,13 +93,16 @@ int main(int argc, char **argv) { ...@@ -90,13 +93,16 @@ int main(int argc, char **argv) {
<< "input: " << input_node << std::endl << "input: " << input_node << std::endl
<< "output: " << output_node << std::endl << "output: " << output_node << std::endl
<< "input_shape: " << input_shape << std::endl << "input_shape: " << input_shape << std::endl
<< "output_shape: " << output_shape << std::endl
<< "input_file: " << input_file << std::endl << "input_file: " << input_file << std::endl
<< "output_file: " << output_file << std::endl << "output_file: " << output_file << std::endl
<< "device: " << device << std::endl << "device: " << device << std::endl
<< "round: " << round << std::endl; << "round: " << round << std::endl;
vector<int64_t> shape; vector<int64_t> input_shape_vec;
ParseShape(input_shape, &shape); vector<int64_t> output_shape_vec;
ParseShape(input_shape, &input_shape_vec);
ParseShape(output_shape, &output_shape_vec);
// load model // load model
int64_t t0 = utils::NowMicros(); int64_t t0 = utils::NowMicros();
...@@ -107,8 +113,12 @@ int main(int argc, char **argv) { ...@@ -107,8 +113,12 @@ int main(int argc, char **argv) {
DeviceType device_type = ParseDeviceType(device); DeviceType device_type = ParseDeviceType(device);
VLOG(1) << "Device Type" << device_type; VLOG(1) << "Device Type" << device_type;
int64_t input_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>()); int64_t input_size = std::accumulate(input_shape_vec.begin(),
input_shape_vec.end(), 1, std::multiplies<int64_t>());
int64_t output_size = std::accumulate(output_shape_vec.begin(),
output_shape_vec.end(), 1, std::multiplies<int64_t>());
std::unique_ptr<float[]> input_data(new float[input_size]); std::unique_ptr<float[]> input_data(new float[input_size]);
std::unique_ptr<float[]> output_data(new float[output_size]);
// load input // load input
ifstream in_file(input_file, ios::in | ios::binary); ifstream in_file(input_file, ios::in | ios::binary);
...@@ -126,10 +136,9 @@ int main(int argc, char **argv) { ...@@ -126,10 +136,9 @@ int main(int argc, char **argv) {
LOG(INFO) << "Total init duration: " << init_micros << " us"; LOG(INFO) << "Total init duration: " << init_micros << " us";
std::vector<int64_t> output_shape;
VLOG(0) << "Warm up"; VLOG(0) << "Warm up";
t0 = utils::NowMicros(); t0 = utils::NowMicros();
engine.Run(input_data.get(), shape, output_shape); engine.Run(input_data.get(), input_shape_vec, output_data.get());
t1 = utils::NowMicros(); t1 = utils::NowMicros();
LOG(INFO) << "1st warm up run duration: " << t1 - t0 << " us"; LOG(INFO) << "1st warm up run duration: " << t1 - t0 << " us";
...@@ -137,26 +146,18 @@ int main(int argc, char **argv) { ...@@ -137,26 +146,18 @@ int main(int argc, char **argv) {
VLOG(0) << "Run model"; VLOG(0) << "Run model";
t0 = utils::NowMicros(); t0 = utils::NowMicros();
for (int i = 0; i < round; ++i) { for (int i = 0; i < round; ++i) {
engine.Run(input_data.get(), shape, output_shape); engine.Run(input_data.get(), input_shape_vec, output_data.get());
} }
t1 = utils::NowMicros(); t1 = utils::NowMicros();
LOG(INFO) << "Avg duration: " << (t1 - t0) / round << " us"; LOG(INFO) << "Avg duration: " << (t1 - t0) / round << " us";
} }
const float *output = engine.Run(input_data.get(), shape, output_shape); MACE_CHECK(engine.Run(input_data.get(), input_shape_vec, output_data.get()));
if (output != nullptr) { if (output_data != nullptr) {
ofstream out_file(output_file, ios::binary); ofstream out_file(output_file, ios::binary);
int64_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>()); out_file.write((const char *) (output_data.get()),
out_file.write((const char *) (output),
output_size * sizeof(float)); output_size * sizeof(float));
out_file.flush(); out_file.flush();
out_file.close(); out_file.close();
stringstream ss;
ss << "Output shape: [";
for (auto i : output_shape) {
ss << i << ", ";
}
ss << "]";
VLOG(0) << ss.str();
} }
} }
...@@ -68,6 +68,7 @@ build_and_run() ...@@ -68,6 +68,7 @@ build_and_run()
--input=mace_input_node \ --input=mace_input_node \
--output=mace_output_node \ --output=mace_output_node \
--input_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},3"\ --input_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},3"\
--output_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},2"\
--input_file=${PHONE_DATA_DIR}/${INPUT_FILE_NAME} \ --input_file=${PHONE_DATA_DIR}/${INPUT_FILE_NAME} \
--output_file=${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} \ --output_file=${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} \
--device=OPENCL \ --device=OPENCL \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册