diff --git a/mace/core/mace.cc b/mace/core/mace.cc index a0e685694ea376a31eedeab7954672bff2cb4310..de6c93534e08bbdbcb71429a5dcddf69c2d50f11 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -494,9 +494,11 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type): net_ = std::move(CreateNet(*net_def, ws_.get(), device_type)); } MaceEngine::~MaceEngine(){} -const float *MaceEngine::Run(const float *input, - const std::vector &input_shape, - std::vector &output_shape) { +bool MaceEngine::Run(const float *input, + const std::vector &input_shape, + float *output) { + MACE_CHECK(output != nullptr, "output ptr cannot be NULL"); + Tensor *input_tensor = ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT); input_tensor->Resize(input_shape); @@ -509,16 +511,18 @@ const float *MaceEngine::Run(const float *input, LOG(FATAL) << "Net run failed"; } // save output - const Tensor *output = ws_->GetTensor("mace_output_node:0"); - - if (output != nullptr) { - Tensor::MappingGuard output_guard(output); - auto shape = output->shape(); - output_shape.resize(shape.size()); - std::copy(shape.begin(), shape.end(), output_shape.begin()); - return output->data(); + const Tensor *output_tensor = ws_->GetTensor("mace_output_node:0"); + + if (output_tensor != nullptr) { + Tensor::MappingGuard output_guard(output_tensor); + auto shape = output_tensor->shape(); + int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + std::memcpy(output, output_tensor->data(), + output_size * sizeof(float)); + return true; } else { - return nullptr; + return false; } } diff --git a/mace/core/public/mace.h b/mace/core/public/mace.h index caf5b311b0dbbedd40797000362e7d0bd46776be..088d7f9cacde4e800e73f5ffd5cce91d18304e53 100644 --- a/mace/core/public/mace.h +++ b/mace/core/public/mace.h @@ -307,9 +307,9 @@ class MaceEngine { public: explicit MaceEngine(const NetDef *net_def, DeviceType device_type); ~MaceEngine(); - const float *Run(const float *input, - const std::vector &input_shape, - std::vector &output_shape); + bool Run(const float *input, + const std::vector &input_shape, + float *output); private: DeviceType device_type_; std::unique_ptr ws_; diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index 3de72de9289a419b98f9122c0f196f2fdb101b75..cb8a97257836042a9cfea81f640971215d68d5ec 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -7,7 +7,8 @@ * mace_run --model=mobi_mace.pb \ * --input=input_node \ * --output=MobilenetV1/Logits/conv2d/convolution \ - * --input_shape=1,3,224,224 \ + * --input_shape=1,224,224,3 \ + * --output_shape=1,224,224,2 \ * --input_file=input_data \ * --output_file=mace.out \ * --device=NEON @@ -108,6 +109,7 @@ int main(int argc, char **argv) { string input_node; string output_node; string input_shape; + string output_shape; string input_file; string output_file; string device; @@ -119,6 +121,7 @@ int main(int argc, char **argv) { Flag("input", &input_node, "input node"), Flag("output", &output_node, "output node"), Flag("input_shape", &input_shape, "input shape, separated by comma"), + Flag("output_shape", &output_shape, "output shape, separated by comma"), Flag("input_file", &input_file, "input file name"), Flag("output_file", &output_file, "output file name"), Flag("device", &device, "CPU/NEON"), @@ -141,13 +144,16 @@ int main(int argc, char **argv) { << "input: " << input_node << std::endl << "output: " << output_node << std::endl << "input_shape: " << input_shape << std::endl + << "output_shape: " << output_shape << std::endl << "input_file: " << input_file << std::endl << "output_file: " << output_file << std::endl << "device: " << device << std::endl << "round: " << round << std::endl; - vector shape; - ParseShape(input_shape, &shape); + vector input_shape_vec; + vector output_shape_vec; + ParseShape(input_shape, &input_shape_vec); + ParseShape(output_shape, &output_shape_vec); // load model int64_t t0 = utils::NowMicros(); @@ -158,9 +164,12 @@ int main(int argc, char **argv) { DeviceType device_type = ParseDeviceType(device); VLOG(1) << "Device Type" << device_type; - int64_t input_size = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies()); + int64_t input_size = std::accumulate(input_shape_vec.begin(), + input_shape_vec.end(), 1, std::multiplies()); + int64_t output_size = std::accumulate(output_shape_vec.begin(), + output_shape_vec.end(), 1, std::multiplies()); std::unique_ptr input_data(new float[input_size]); + std::unique_ptr output_data(new float[output_size]); // load input ifstream in_file(input_file, ios::in | ios::binary); @@ -178,10 +187,9 @@ int main(int argc, char **argv) { LOG(INFO) << "Total init duration: " << init_micros << " us"; - std::vector output_shape; VLOG(0) << "Warm up"; t0 = utils::NowMicros(); - engine.Run(input_data.get(), shape, output_shape); + engine.Run(input_data.get(), input_shape_vec, output_data.get()); t1 = utils::NowMicros(); LOG(INFO) << "1st warm up run duration: " << t1 - t0 << " us"; @@ -190,7 +198,7 @@ int main(int argc, char **argv) { t0 = utils::NowMicros(); struct mallinfo prev = mallinfo(); for (int i = 0; i < round; ++i) { - engine.Run(input_data.get(), shape, output_shape); + engine.Run(input_data.get(), input_shape_vec, output_data.get()); if (malloc_check_cycle >= 1 && i % malloc_check_cycle == 0) { LOG(INFO) << "=== check malloc info change #" << i << " ==="; prev = LogMallinfoChange(prev); @@ -200,21 +208,12 @@ int main(int argc, char **argv) { LOG(INFO) << "Avg duration: " << (t1 - t0) / round << " us"; } - const float *output = engine.Run(input_data.get(), shape, output_shape); - if (output != nullptr) { + MACE_CHECK(engine.Run(input_data.get(), input_shape_vec, output_data.get())); + if (output_data != nullptr) { ofstream out_file(output_file, ios::binary); - int64_t output_size = - std::accumulate(output_shape.begin(), output_shape.end(), 1, - std::multiplies()); - out_file.write((const char *)(output), output_size * sizeof(float)); + out_file.write((const char *) (output_data.get()), + output_size * sizeof(float)); out_file.flush(); out_file.close(); - stringstream ss; - ss << "Output shape: ["; - for (auto i : output_shape) { - ss << i << ", "; - } - ss << "]"; - VLOG(0) << ss.str(); } } diff --git a/tools/validate_gcn.sh b/tools/validate_gcn.sh index 645df8b586e255394900dbb6f6b5fa2b9e6f7675..a0f8e580ba174664e53dada6f9f0b83a3466c6d6 100644 --- a/tools/validate_gcn.sh +++ b/tools/validate_gcn.sh @@ -69,6 +69,7 @@ build_and_run() --input=mace_input_node \ --output=mace_output_node \ --input_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},3"\ + --output_shape="1,${IMAGE_SIZE},${IMAGE_SIZE},2"\ --input_file=${PHONE_DATA_DIR}/${INPUT_FILE_NAME} \ --output_file=${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} \ --device=OPENCL \