Commit 0eecab94 authored by 刘琦

Merge branch 'update_engine_run' into 'master'

Update engine run

See merge request !185
......@@ -494,9 +494,11 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type):
net_ = std::move(CreateNet(*net_def, ws_.get(), device_type));
}
MaceEngine::~MaceEngine(){}
const float *MaceEngine::Run(const float *input,
const std::vector<index_t> &input_shape,
std::vector<int64_t> &output_shape) {
bool MaceEngine::Run(const float *input,
const std::vector<index_t> &input_shape,
float *output) {
MACE_CHECK(output != nullptr, "output ptr cannot be NULL");
Tensor *input_tensor =
ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT);
input_tensor->Resize(input_shape);
......@@ -509,16 +511,18 @@ const float *MaceEngine::Run(const float *input,
LOG(FATAL) << "Net run failed";
}
// save output
const Tensor *output = ws_->GetTensor("mace_output_node:0");
if (output != nullptr) {
Tensor::MappingGuard output_guard(output);
auto shape = output->shape();
output_shape.resize(shape.size());
std::copy(shape.begin(), shape.end(), output_shape.begin());
return output->data<float>();
const Tensor *output_tensor = ws_->GetTensor("mace_output_node:0");
if (output_tensor != nullptr) {
Tensor::MappingGuard output_guard(output_tensor);
auto shape = output_tensor->shape();
int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
std::memcpy(output, output_tensor->data<float>(),
output_size * sizeof(float));
return true;
} else {
return nullptr;
return false;
}
}
......
......@@ -307,9 +307,9 @@ class MaceEngine {
public:
explicit MaceEngine(const NetDef *net_def, DeviceType device_type);
~MaceEngine();
const float *Run(const float *input,
const std::vector<int64_t> &input_shape,
std::vector<int64_t> &output_shape);
bool Run(const float *input,
const std::vector<int64_t> &input_shape,
float *output);
private:
DeviceType device_type_;
std::unique_ptr<Workspace> ws_;
......
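With this change the caller owns the output buffer: Run() no longer hands back a pointer into the workspace tensor plus an output shape, it copies the result into a caller-provided buffer and reports success as a bool. A minimal caller-side sketch under that contract (RunOnce and its parameters are illustrative, not part of this commit; the MACE header that declares MaceEngine is assumed to be included):

#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <vector>

// Sketch only: the caller sizes and owns the output buffer, Run() copies
// the result into it and returns true on success.
bool RunOnce(MaceEngine &engine,
             const float *input,
             const std::vector<int64_t> &input_shape,
             const std::vector<int64_t> &output_shape,
             std::unique_ptr<float[]> *output) {
  const int64_t output_size =
      std::accumulate(output_shape.begin(), output_shape.end(),
                      static_cast<int64_t>(1), std::multiplies<int64_t>());
  output->reset(new float[output_size]);
  return engine.Run(input, input_shape, output->get());
}

Because the output size must now be known up front, the mace_run tool below gains an --output_shape flag.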
......@@ -7,7 +7,8 @@
* mace_run --model=mobi_mace.pb \
* --input=input_node \
* --output=MobilenetV1/Logits/conv2d/convolution \
* --input_shape=1,3,224,224 \
* --input_shape=1,224,224,3 \
* --output_shape=1,224,224,2 \
* --input_file=input_data \
* --output_file=mace.out \
* --device=NEON
......@@ -108,6 +109,7 @@ int main(int argc, char **argv) {
string input_node;
string output_node;
string input_shape;
string output_shape;
string input_file;
string output_file;
string device;
......@@ -119,6 +121,7 @@ int main(int argc, char **argv) {
Flag("input", &input_node, "input node"),
Flag("output", &output_node, "output node"),
Flag("input_shape", &input_shape, "input shape, separated by comma"),
Flag("output_shape", &output_shape, "output shape, separated by comma"),
Flag("input_file", &input_file, "input file name"),
Flag("output_file", &output_file, "output file name"),
Flag("device", &device, "CPU/NEON"),
......@@ -141,13 +144,16 @@ int main(int argc, char **argv) {
<< "input: " << input_node << std::endl
<< "output: " << output_node << std::endl
<< "input_shape: " << input_shape << std::endl
<< "output_shape: " << output_shape << std::endl
<< "input_file: " << input_file << std::endl
<< "output_file: " << output_file << std::endl
<< "device: " << device << std::endl
<< "round: " << round << std::endl;
vector<int64_t> shape;
ParseShape(input_shape, &shape);
vector<int64_t> input_shape_vec;
vector<int64_t> output_shape_vec;
ParseShape(input_shape, &input_shape_vec);
ParseShape(output_shape, &output_shape_vec);
// load model
int64_t t0 = utils::NowMicros();
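ParseShape is called here but not defined in this diff; a plausible sketch of a comma-separated shape parser with that signature (the real implementation may differ):

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Illustrative parser: splits "1,224,224,3" into {1, 224, 224, 3}.
void ParseShape(const std::string &str, std::vector<int64_t> *shape) {
  shape->clear();
  std::stringstream ss(str);
  std::string dim;
  while (std::getline(ss, dim, ',')) {
    shape->push_back(std::stoll(dim));
  }
}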
......@@ -158,9 +164,12 @@ int main(int argc, char **argv) {
DeviceType device_type = ParseDeviceType(device);
VLOG(1) << "Device Type" << device_type;
int64_t input_size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>());
int64_t input_size = std::accumulate(input_shape_vec.begin(),
input_shape_vec.end(), 1, std::multiplies<int64_t>());
int64_t output_size = std::accumulate(output_shape_vec.begin(),
output_shape_vec.end(), 1, std::multiplies<int64_t>());
std::unique_ptr<float[]> input_data(new float[input_size]);
std::unique_ptr<float[]> output_data(new float[output_size]);
// load input
ifstream in_file(input_file, ios::in | ios::binary);
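The diff cuts off before the bytes are actually read into input_data; assuming the input file stores input_size raw floats, the load amounts to something like the following self-contained helper (LoadRawFloats is an illustrative name, not part of the tool):

#include <cstdint>
#include <fstream>
#include <string>

// Illustrative helper: reads `count` raw floats from `path` into `dst`.
// Returns false if the file cannot be opened or is shorter than expected.
bool LoadRawFloats(const std::string &path, float *dst, int64_t count) {
  std::ifstream in(path, std::ios::in | std::ios::binary);
  if (!in.is_open()) return false;
  const std::streamsize bytes = count * static_cast<int64_t>(sizeof(float));
  in.read(reinterpret_cast<char *>(dst), bytes);
  return in.gcount() == bytes;
}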
......@@ -178,10 +187,9 @@ int main(int argc, char **argv) {
LOG(INFO) << "Total init duration: " << init_micros << " us";
std::vector<int64_t> output_shape;
VLOG(0) << "Warm up";
t0 = utils::NowMicros();
engine.Run(input_data.get(), shape, output_shape);
engine.Run(input_data.get(), input_shape_vec, output_data.get());
t1 = utils::NowMicros();
LOG(INFO) << "1st warm up run duration: " << t1 - t0 << " us";
......@@ -190,7 +198,7 @@ int main(int argc, char **argv) {
t0 = utils::NowMicros();
struct mallinfo prev = mallinfo();
for (int i = 0; i < round; ++i) {
engine.Run(input_data.get(), shape, output_shape);
engine.Run(input_data.get(), input_shape_vec, output_data.get());
if (malloc_check_cycle >= 1 && i % malloc_check_cycle == 0) {
LOG(INFO) << "=== check malloc info change #" << i << " ===";
prev = LogMallinfoChange(prev);
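LogMallinfoChange is used above but not shown in this diff; a rough sketch of such a helper, assuming it snapshots struct mallinfo, logs the fields that changed, and returns the new snapshot (LOG is the project's logging macro as used elsewhere in this file):

#include <malloc.h>

// Sketch only: compare two mallinfo snapshots and log the deltas that
// matter for leak checking, then return the current snapshot.
struct mallinfo LogMallinfoChange(const struct mallinfo &prev) {
  struct mallinfo curr = mallinfo();
  if (curr.uordblks != prev.uordblks) {
    LOG(INFO) << "uordblks (in-use bytes): " << prev.uordblks
              << " -> " << curr.uordblks;
  }
  if (curr.arena != prev.arena) {
    LOG(INFO) << "arena (heap size): " << prev.arena << " -> " << curr.arena;
  }
  return curr;
}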
......@@ -200,21 +208,12 @@ int main(int argc, char **argv) {
LOG(INFO) << "Avg duration: " << (t1 - t0) / round << " us";
}
const float *output = engine.Run(input_data.get(), shape, output_shape);
if (output != nullptr) {
MACE_CHECK(engine.Run(input_data.get(), input_shape_vec, output_data.get()));
if (output_data != nullptr) {
ofstream out_file(output_file, ios::binary);
int64_t output_size =
std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int64_t>());
out_file.write((const char *)(output), output_size * sizeof(float));
out_file.write((const char *) (output_data.get()),
output_size * sizeof(float));
out_file.flush();
out_file.close();
stringstream ss;
ss << "Output shape: [";
for (auto i : output_shape) {
ss << i << ", ";
}
ss << "]";
VLOG(0) << ss.str();
}
}
......@@ -69,6 +69,7 @@ build_and_run()
--input=mace_input_node \
--output=mace_output_node \
--input_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},3"\
--output_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},2"\
--input_file=${PHONE_DATA_DIR}/${INPUT_FILE_NAME} \
--output_file=${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} \
--device=OPENCL \
......