未验证 提交 9bea65cb 编写于 作者: Y ysh329 提交者: GitHub

[cherry-pick][BugFix][OPENCL] BugFix for OpenCL: image memory malloc; dropout...

[cherry-pick][BugFix][OPENCL] BugFix for OpenCL: image memory malloc; dropout kernel register; precision profiler enhance; layout pass bugfix for opencl (#4426)
上级 2f87a652
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include "lite/api/paddle_place.h"
#include "lite/core/target_wrapper.h"
......@@ -140,20 +141,21 @@ class Buffer {
#ifdef LITE_WITH_OPENCL
template <typename T>
void ResetLazyImage2D(TargetType target,
const size_t img_w,
const size_t img_h,
const size_t img_w_req,
const size_t img_h_req,
void* host_ptr = nullptr) {
if (target != target_ || cl_image2d_width_ < img_w ||
cl_image2d_height_ < img_h || host_ptr != nullptr) {
if (target != target_ || cl_image2d_width_ < img_w_req ||
cl_image2d_height_ < img_h_req || host_ptr != nullptr) {
CHECK_EQ(own_data_, true) << "Can not reset unowned buffer.";
cl_image2d_width_ = std::max(cl_image2d_width_, img_w_req);
cl_image2d_height_ = std::max(cl_image2d_height_, img_h_req);
Free();
data_ = TargetWrapperCL::MallocImage<T>(img_w, img_h, host_ptr);
data_ = TargetWrapperCL::MallocImage<T>(
cl_image2d_width_, cl_image2d_height_, host_ptr);
target_ = target;
space_ = sizeof(T) * img_w * img_h *
space_ = sizeof(T) * cl_image2d_width_ * cl_image2d_height_ *
4; // un-used for opencl Image2D, 4 for RGBA,
cl_use_image2d_ = true;
cl_image2d_width_ = img_w;
cl_image2d_height_ = img_h;
}
}
#endif
......
......@@ -28,6 +28,12 @@ TEST(memory, test) {
ASSERT_TRUE(buf_cuda);
TargetFree(TARGET(kCUDA), buf_cuda);
#endif
#ifdef LITE_WITH_OPENCL
auto* buf_cl = TargetMalloc(TARGET(kOpenCL), 10);
ASSERT_TRUE(buf_cl);
TargetFree(TARGET(kOpenCL), buf_cl);
#endif
}
} // namespace lite
......
......@@ -82,8 +82,11 @@ void TypeLayoutTransformPass::ComplementInputs(SSAGraph* graph,
// not a good judge, but don't find the source of this issue from
// static_pick_kernel_pass
// to this pass.
auto is_host = [](TargetType x) -> bool {
return x == TARGET(kHost) || x == TARGET(kX86) || x == TARGET(kARM);
};
auto* in_arg_type = const_cast<Type*>(in->AsArg().type);
if (in_arg_type->target() == TARGET(kARM) &&
if (is_host(in_arg_type->target()) &&
in_arg_type->layout() == DATALAYOUT(kImageDefault)) {
return;
}
......
......@@ -80,97 +80,98 @@ class Optimizer {
InitControlFlowOpUnusedInputsAndOutputsEliminatePass();
if (passes.empty() || passes.size() == 1) {
std::vector<std::string> passes_local{
{"lite_quant_dequant_fuse_pass", //
"weight_quantization_preprocess_pass", //
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
"lite_var_conv_2d_activation_fuse_pass", //
"lite_match_matrix_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
"lite_scales_fuse_pass", //
"lite_sequence_reverse_embedding_fuse_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", //
std::vector<std::string> passes_local{{
"lite_quant_dequant_fuse_pass", //
"weight_quantization_preprocess_pass", //
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
"lite_var_conv_2d_activation_fuse_pass", //
"lite_match_matrix_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
"lite_scales_fuse_pass", //
"lite_sequence_reverse_embedding_fuse_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_activation_fuse_pass", //
"lite_elementwise_activation_fuse_pass", //
#endif
"identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass",
"__xpu__resnet_cbam_fuse_pass",
"__xpu__conv2d_fuse_pass",
"__xpu__conv2d_link_previous_out_max_pass",
"__xpu__sfa_head_meanstd_fuse_pass",
"__xpu__sfa_head_moment_fuse_pass",
"__xpu__mmdnn_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
"quantized_op_attributes_inference_pass", // Only for fully
// quantized model, infer
// the output scale and
// fix the attribute
// 'enable_int8' for all
// of the quantized ops.
"npu_subgraph_pass",
"huawei_ascend_npu_subgraph_pass",
"xpu_subgraph_pass",
"bm_subgraph_pass",
"apu_subgraph_pass",
"rknpu_subgraph_pass",
"mlu_subgraph_pass",
"control_flow_op_unused_inputs_and_outputs_eliminate_pass",
"static_kernel_pick_pass", // pick original kernel from graph
"remove_tf_redundant_ops_pass",
"variable_place_inference_pass", // inference arg/var's
"mlu_postprocess_pass",
// info(target/precision/layout/device)
// using kernel info
"argument_type_display_pass", // debug pass: show arg-type-node's
// info
// (target/precision/layout/device)
"type_target_cast_pass", // add io_copy/io_copy_once if meet
// different targets when last and next
// node
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_precision_cast_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_layout_cast_pass", // add layout/layout_once op if meet
// different layout when last and next node
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass",
"runtime_context_assign_pass",
"argument_type_display_pass",
"lite_reshape_fuse_pass",
"memory_optimize_pass"}};
"identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass",
"__xpu__resnet_cbam_fuse_pass",
"__xpu__conv2d_fuse_pass",
"__xpu__conv2d_link_previous_out_max_pass",
"__xpu__sfa_head_meanstd_fuse_pass",
"__xpu__sfa_head_moment_fuse_pass",
"__xpu__mmdnn_fuse_pass",
"__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass",
"quantized_op_attributes_inference_pass", // Only for fully
// quantized model, infer
// the output scale and
// fix the attribute
// 'enable_int8' for all
// of the quantized ops.
"npu_subgraph_pass",
"huawei_ascend_npu_subgraph_pass",
"xpu_subgraph_pass",
"bm_subgraph_pass",
"apu_subgraph_pass",
"rknpu_subgraph_pass",
"mlu_subgraph_pass",
"control_flow_op_unused_inputs_and_outputs_eliminate_pass",
"static_kernel_pick_pass", // pick original kernel from graph
"remove_tf_redundant_ops_pass",
"variable_place_inference_pass", // inference arg/var's
"mlu_postprocess_pass",
// info(target/precision/layout/device)
// using kernel info
"argument_type_display_pass", // debug pass: show arg-type-node's
// info
// (target/precision/layout/device)
"type_target_cast_pass", // add io_copy/io_copy_once if meet
// different targets when last and next
// node
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_precision_cast_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_layout_cast_pass", // add layout/layout_once op if meet
// different layout when last and next node
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass",
"runtime_context_assign_pass",
"argument_type_display_pass",
"lite_reshape_fuse_pass",
"memory_optimize_pass" // you can comment this line when enable
// PRECISION_PROFILE
}};
if (passes.size() == 1) {
// multi_stream_analysis_pass must be in the front of
......
......@@ -18,10 +18,18 @@
* of each kernel.
*/
#pragma once
#include <sys/time.h>
#include <time.h>
#include <cmath>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/program.h"
#include "lite/utils/io.h"
#ifdef LITE_WITH_X86
#include "lite/fluid/float16.h"
#endif
......@@ -40,14 +48,50 @@ namespace paddle {
namespace lite {
namespace profile {
static const std::string get_date_str() {
struct tm tm_time;
time_t timestamp = time(NULL);
localtime_r(&timestamp, &tm_time);
struct timeval tv;
gettimeofday(&tv, NULL);
// print date / time
std::string date_str =
std::to_string(1900 + tm_time.tm_year) +
std::to_string(1 + tm_time.tm_mon) + std::to_string(tm_time.tm_mday) +
'_' + std::to_string(tm_time.tm_hour) + std::to_string(tm_time.tm_min) +
std::to_string(tm_time.tm_sec) + '_' + std::to_string(tv.tv_usec / 1000);
return date_str;
}
inline std::string generate_valid_tensor_name(const std::string& name) {
std::string new_name("");
for (size_t i = 0; i < name.length(); ++i) {
if (name[i] != '/') {
new_name += name[i];
} else {
new_name += "_";
}
}
return new_name;
}
template <typename dtype>
static bool write_tensorfile(const Tensor* tensor, const std::string& locate) {
if (locate.find('/') != std::string::npos) {
return false;
static bool write_tensorfile(
const Tensor* tensor,
const std::string& tensor_name,
const std::string prefix_path = "/storage/emulated/0/") {
std::string new_tensor_name = generate_valid_tensor_name(tensor_name);
if (tensor_name.find('/') != std::string::npos) {
LOG(ERROR) << "--> tensor name is abnormal with '\\':" << tensor_name
<< " !!!, replace with '_'," << new_tensor_name
<< new_tensor_name;
}
FILE* fp = fopen(locate.c_str(), "w");
std::string tensor_save_path = prefix_path + new_tensor_name + ".txt";
FILE* fp = fopen(tensor_save_path.c_str(), "w");
if (fp == nullptr) {
LOG(ERROR) << "file open field " << locate;
LOG(ERROR) << "failed open file " << tensor_save_path;
return false;
} else {
const dtype* data = tensor->data<dtype>();
......@@ -56,19 +100,23 @@ static bool write_tensorfile(const Tensor* tensor, const std::string& locate) {
}
}
fclose(fp);
LOG(INFO) << "write tensor " << tensor_name
<< " to file:" << tensor_save_path;
return true;
}
static bool write_precision_summary_tofile(const std::string& string,
const std::string& log_dir = "") {
if (log_dir == "") {
LOG(INFO) << "The `log_dir` of precision summary file is not set. log_dir:"
<< log_dir;
static bool write_precision_summary_tofile(
const std::string& string, const std::string& summary_log_dir = "") {
if (summary_log_dir == "") {
LOG(INFO) << "The `summary_log_dir` of precision summary file is not set. "
"summary_log_dir:"
<< summary_log_dir;
return false;
}
FILE* fp = fopen(log_dir.c_str(), "a");
FILE* fp = fopen(summary_log_dir.c_str(), "a");
if (fp == nullptr) {
LOG(INFO) << "Open precision summary file:" << log_dir << "failed.";
LOG(INFO) << "Open precision summary file:" << summary_log_dir << "failed.";
return false;
} else {
fprintf(fp, "%s\n", string.c_str());
......@@ -85,7 +133,14 @@ class PrecisionProfiler {
std::string inst_precison_str = GetInstPrecision(inst);
}
PrecisionProfiler() {}
PrecisionProfiler() {
MkDirRecur(log_dir_);
const char* write_to_file_raw =
std::getenv("PADDLELITE_PRECISION_WRITE_TO_FILE");
write_result_to_file_ = (write_to_file_raw && atoi(write_to_file_raw) > 0)
? atoi(write_to_file_raw) > 0
: false;
}
std::string GetSummaryHeader() {
using std::setw;
......@@ -102,9 +157,9 @@ class PrecisionProfiler {
<< " " << setw(15) << left << "std_deviation"
<< " " << setw(15) << left << "ave_grow_rate*" << std::endl;
// write to file with path: `log_dir`
if (log_dir_ != "") {
FILE* fp = fopen(log_dir_.c_str(), "a");
// write to file with path: `summary_log_dir`
if (summary_log_dir_ != "") {
FILE* fp = fopen(summary_log_dir_.c_str(), "a");
std::string header_str{ss.str()};
fprintf(fp, "%s\n", header_str.c_str());
fclose(fp);
......@@ -112,6 +167,18 @@ class PrecisionProfiler {
return ss.str();
}
std::string GetSummaryTail() {
STL::stringstream ss;
ss << "[note]" << std::endl;
ss << "1. `ave_grow_rate`: show the sequence value of tensor when std_dev "
"& mean are same."
<< std::endl;
ss << "2. Enable write each output tensor to file: `export "
"PADDLELITE_PRECISION_WRITE_TO_FILE=1` on ADB command line."
<< std::endl;
return ss.str();
}
template <typename T>
double compute_mean(const T* in, const size_t length) {
double sum = 0.;
......@@ -157,6 +224,17 @@ class PrecisionProfiler {
return false;
}
std::string rename_out_for_mem_reuse_pass(const std::string& old_name) {
if (out_tensor_names_map.find(old_name) == out_tensor_names_map.end()) {
out_tensor_names_map[old_name] = 1;
} else {
++out_tensor_names_map[old_name];
}
std::string new_name =
old_name + "_" + std::to_string(out_tensor_names_map[old_name]);
return new_name;
}
void compute_tensor_precision_info(const Tensor* in,
TargetType target_type,
PrecisionType precision_type,
......@@ -180,7 +258,7 @@ class PrecisionProfiler {
*std_dev =
compute_standard_deviation<float>(ptr, in->numel(), true, *mean);
*ave_grow_rate = compute_average_grow_rate<float>(ptr, in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
write_result_to_file&& write_tensorfile<float>(in, name, log_dir_);
return;
}
case PRECISION(kAny): {
......@@ -189,7 +267,7 @@ class PrecisionProfiler {
*std_dev =
compute_standard_deviation<float>(ptr, in->numel(), true, *mean);
*ave_grow_rate = compute_average_grow_rate<float>(ptr, in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
write_result_to_file&& write_tensorfile<float>(in, name, log_dir_);
return;
}
case PRECISION(kInt8): {
......@@ -198,7 +276,7 @@ class PrecisionProfiler {
*std_dev =
compute_standard_deviation<int8_t>(ptr, in->numel(), true, *mean);
*ave_grow_rate = compute_average_grow_rate<int8_t>(ptr, in->numel());
write_result_to_file&& write_tensorfile<int8_t>(in, name);
write_result_to_file&& write_tensorfile<int8_t>(in, name, log_dir_);
return;
}
case PRECISION(kInt32): {
......@@ -207,7 +285,7 @@ class PrecisionProfiler {
*std_dev = compute_standard_deviation<int32_t>(
ptr, in->numel(), true, *mean);
*ave_grow_rate = compute_average_grow_rate<int32_t>(ptr, in->numel());
write_result_to_file&& write_tensorfile<int32_t>(in, name);
write_result_to_file&& write_tensorfile<int32_t>(in, name, log_dir_);
return;
}
case PRECISION(kInt64): {
......@@ -254,7 +332,14 @@ class PrecisionProfiler {
real_out_v.data(), in->numel(), true, *mean);
*ave_grow_rate = compute_average_grow_rate<float>(real_out_v.data(),
real_out_v.size());
write_result_to_file&& write_tensorfile<float>(in, name);
std::shared_ptr<lite::Tensor> real_out_t(new lite::Tensor);
real_out_t->Resize(in->dims());
float* real_out_data = real_out_t->mutable_data<float>();
memcpy(real_out_data,
real_out_v.data(),
real_out_v.size() * sizeof(float));
write_result_to_file&& write_tensorfile<float>(
real_out_t.get(), name, log_dir_);
return;
}
case DATALAYOUT(kNCHW): {
......@@ -269,7 +354,14 @@ class PrecisionProfiler {
in_data_v.data(), in->numel(), true, *mean);
*ave_grow_rate =
compute_average_grow_rate<float>(in_data_v.data(), in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
std::shared_ptr<lite::Tensor> real_out_t(new lite::Tensor);
real_out_t->Resize(in->dims());
float* real_out_data = real_out_t->mutable_data<float>();
memcpy(real_out_data,
in_data_v.data(),
in_data_v.size() * sizeof(float));
write_result_to_file&& write_tensorfile<float>(
real_out_t.get(), name, log_dir_);
return;
}
default:
......@@ -296,7 +388,7 @@ class PrecisionProfiler {
in_data_v.data(), in->numel(), true, *mean);
*ave_grow_rate =
compute_average_grow_rate<float>(in_data_v.data(), in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
write_result_to_file&& write_tensorfile<float>(in, name, log_dir_);
return;
}
case PRECISION(kInt32): {
......@@ -311,7 +403,7 @@ class PrecisionProfiler {
in_data_v.data(), in->numel(), true, *mean);
*ave_grow_rate =
compute_average_grow_rate<int>(in_data_v.data(), in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
write_result_to_file&& write_tensorfile<float>(in, name, log_dir_);
return;
}
case PRECISION(kInt64): {
......@@ -326,7 +418,7 @@ class PrecisionProfiler {
in_data_v.data(), in->numel(), true, *mean);
*ave_grow_rate =
compute_average_grow_rate<int64_t>(in_data_v.data(), in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
write_result_to_file&& write_tensorfile<float>(in, name, log_dir_);
return;
}
case PRECISION(kFP16): {
......@@ -347,7 +439,7 @@ class PrecisionProfiler {
in_data_v.data(), in->numel(), true, *mean);
*ave_grow_rate =
compute_average_grow_rate<float>(in_data_v.data(), in->numel());
write_result_to_file&& write_tensorfile<float>(in, name);
write_result_to_file&& write_tensorfile<float>(in, name, log_dir_);
return;
}
default:
......@@ -372,12 +464,12 @@ class PrecisionProfiler {
using std::left;
using std::fixed;
STL::stringstream ss;
bool write_result_to_file = false;
VLOG(1) << ">> Running kernel: " << inst->op()->op_info()->Repr()
<< " registered on " << TargetToStr(inst->kernel()->target()) << "/"
<< PrecisionToStr(inst->kernel()->precision()) << "/"
<< DataLayoutToStr(inst->kernel()->layout());
<< DataLayoutToStr(inst->kernel()->layout())
<< ", write_result_to_file_:" << write_result_to_file_;
std::string kernel_repr = inst->op()->op_info()->Repr();
std::string kernel_place = TargetToStr(inst->kernel()->target()) + "/" +
......@@ -404,6 +496,7 @@ class PrecisionProfiler {
std::string mean_str{"unused"};
std::string std_dev_str{"unused"};
std::string ave_grow_rate_str{"unused"};
std::string new_out_name = rename_out_for_mem_reuse_pass(out_name);
if (!is_unused(tout)) {
compute_tensor_precision_info(tout,
......@@ -413,14 +506,14 @@ class PrecisionProfiler {
&mean,
&std_dev,
&ave_grow_rate,
out_name,
write_result_to_file);
new_out_name,
write_result_to_file_);
mean_str = std::to_string(mean);
std_dev_str = std::to_string(std_dev);
ave_grow_rate_str = std::to_string(ave_grow_rate);
}
std::string kernel_info = op_name + ":" + kernel_place;
std::string output_arg_info = out_name + ":" +
std::string output_arg_info = new_out_name + ":" +
TargetToStr(type->target()) + "/" +
PrecisionToStr(type->precision()) +
"/" + DataLayoutToStr(type->layout());
......@@ -441,6 +534,7 @@ class PrecisionProfiler {
std::string mean_str{"unused"};
std::string std_dev_str{"unused"};
std::string ave_grow_rate_str{"unused"};
std::string new_out_name = rename_out_for_mem_reuse_pass(out_name);
if (!is_unused(tout)) {
compute_tensor_precision_info(tout,
......@@ -450,14 +544,14 @@ class PrecisionProfiler {
&mean,
&std_dev,
&ave_grow_rate,
out_name,
write_result_to_file);
new_out_name,
write_result_to_file_);
mean_str = std::to_string(mean);
std_dev_str = std::to_string(std_dev);
ave_grow_rate_str = std::to_string(ave_grow_rate);
}
std::string kernel_info = op_name + ":" + kernel_place;
std::string output_arg_info = out_name + ":" +
std::string output_arg_info = new_out_name + ":" +
TargetToStr(type->target()) + "/" +
PrecisionToStr(type->precision()) +
"/" + DataLayoutToStr(type->layout());
......@@ -471,12 +565,16 @@ class PrecisionProfiler {
}
}
}
write_precision_summary_tofile(ss.str(), log_dir_);
write_precision_summary_tofile(ss.str(), summary_log_dir_);
return ss.str();
}
private:
std::string log_dir_{"/storage/emulated/0/precision.log"};
std::string log_dir_{"/storage/emulated/0/PaddleLite_" + get_date_str() +
"/"};
std::string summary_log_dir_{log_dir_ + "precision_summary.log"};
std::map<std::string, size_t> out_tensor_names_map;
bool write_result_to_file_{false};
};
} // namespace profile
......
......@@ -302,7 +302,9 @@ void RuntimeProgram::Run() {
LOG(INFO) << "\n" << profiler_.Summary(profile::Type::kDispatch, false, 1);
#endif
#ifdef LITE_WITH_PRECISION_PROFILE
LOG(INFO) << "\n" << precision_profiler_summary;
LOG(INFO) << "\n"
<< precision_profiler_summary
<< inst_precision_profiler.GetSummaryTail();
#endif
}
......
......@@ -29,6 +29,21 @@ int64_t ShapeProduction(const shape_t& shape) {
return res;
}
std::string ShapePrint(const std::vector<shape_t>& shapes) {
std::string shapes_str{""};
for (size_t shape_idx = 0; shape_idx < shapes.size(); ++shape_idx) {
auto shape = shapes[shape_idx];
std::string shape_str;
for (auto i : shape) {
shape_str += std::to_string(i) + ",";
}
shapes_str += shape_str;
shapes_str +=
(shape_idx != 0 && shape_idx == shapes.size() - 1) ? "" : " : ";
}
return shapes_str;
}
std::string ShapePrint(const shape_t& shape) {
std::string shape_str{""};
for (auto i : shape) {
......@@ -37,6 +52,37 @@ std::string ShapePrint(const shape_t& shape) {
return shape_str;
}
std::vector<std::string> split_string(const std::string& str_in) {
std::vector<std::string> str_out;
std::string tmp_str = str_in;
while (!tmp_str.empty()) {
size_t next_offset = tmp_str.find(":");
str_out.push_back(tmp_str.substr(0, next_offset));
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return str_out;
}
std::vector<int64_t> get_shape(const std::string& str_shape) {
std::vector<int64_t> shape;
std::string tmp_str = str_shape;
while (!tmp_str.empty()) {
int dim = atoi(tmp_str.data());
shape.push_back(dim);
size_t next_offset = tmp_str.find(",");
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return shape;
}
template <typename T>
double compute_mean(const T* in, const size_t length) {
double sum = 0.;
......@@ -70,7 +116,7 @@ inline double GetCurrentUS() {
}
void RunModel(std::string model_dir,
const shape_t& input_shape,
const std::vector<shape_t>& input_shapes,
size_t repeats,
size_t warmup,
size_t print_output_elem,
......@@ -111,12 +157,19 @@ void RunModel(std::string model_dir,
CreatePaddlePredictor<MobileConfig>(config);
// 3. Prepare input data
std::unique_ptr<Tensor> input_tensor(std::move(predictor->GetInput(0)));
input_tensor->Resize(
{input_shape[0], input_shape[1], input_shape[2], input_shape[3]});
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < ShapeProduction(input_tensor->shape()); ++i) {
data[i] = 1;
std::cout << "input_shapes.size():" << input_shapes.size() << std::endl;
for (int j = 0; j < input_shapes.size(); ++j) {
auto input_tensor = predictor->GetInput(j);
input_tensor->Resize(input_shapes[j]);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (int i = 0; i < input_shapes[j].size(); ++i) {
input_num *= input_shapes[j][i];
}
for (int i = 0; i < input_num; ++i) {
input_data[i] = 1.f;
}
}
// 4. Run predictor
......@@ -142,7 +195,7 @@ void RunModel(std::string model_dir,
}
avg_duration = sum_duration / static_cast<float>(repeats);
std::cout << "\n======= benchmark summary =======\n"
<< "input_shape(NCHW):" << ShapePrint(input_shape) << "\n"
<< "input_shape(s) (NCHW):" << ShapePrint(input_shapes) << "\n"
<< "model_dir:" << model_dir << "\n"
<< "warmup:" << warmup << "\n"
<< "repeats:" << repeats << "\n"
......@@ -184,18 +237,19 @@ void RunModel(std::string model_dir,
}
int main(int argc, char** argv) {
shape_t input_shape{1, 3, 224, 224}; // shape_t ==> std::vector<int64_t>
std::vector<std::string> str_input_shapes;
std::vector<shape_t> input_shapes{
{1, 3, 224, 224}}; // shape_t ==> std::vector<int64_t>
int repeats = 10;
int warmup = 10;
int print_output_elem = 0;
if (argc > 2 && argc < 9) {
if (argc > 2 && argc < 6) {
std::cerr << "usage: ./" << argv[0] << "\n"
<< " <naive_buffer_model_dir>\n"
<< " <input_n>\n"
<< " <input_c>\n"
<< " <input_h>\n"
<< " <input_w>\n"
<< " <raw_input_shapes>, eg: 1,3,224,224 for 1 input; "
"1,3,224,224:1,5 for 2 inputs\n"
<< " <repeats>\n"
<< " <warmup>\n"
<< " <print_output>" << std::endl;
......@@ -203,14 +257,19 @@ int main(int argc, char** argv) {
}
std::string model_dir = argv[1];
if (argc >= 9) {
input_shape[0] = atoi(argv[2]);
input_shape[1] = atoi(argv[3]);
input_shape[2] = atoi(argv[4]);
input_shape[3] = atoi(argv[5]);
repeats = atoi(argv[6]);
warmup = atoi(argv[7]);
print_output_elem = atoi(argv[8]);
if (argc >= 6) {
input_shapes.clear();
std::string raw_input_shapes = argv[2];
std::cout << "raw_input_shapes: " << raw_input_shapes << std::endl;
str_input_shapes = split_string(raw_input_shapes);
for (size_t i = 0; i < str_input_shapes.size(); ++i) {
std::cout << "input shape: " << str_input_shapes[i] << std::endl;
input_shapes.push_back(get_shape(str_input_shapes[i]));
}
repeats = atoi(argv[3]);
warmup = atoi(argv[4]);
print_output_elem = atoi(argv[5]);
}
// set arm power mode:
// 0 for big cluster, high performance
......@@ -220,7 +279,7 @@ int main(int argc, char** argv) {
size_t power_mode = 0;
RunModel(
model_dir, input_shape, repeats, warmup, print_output_elem, power_mode);
model_dir, input_shapes, repeats, warmup, print_output_elem, power_mode);
return 0;
}
......@@ -35,7 +35,6 @@ void gen_log(STL::ostream& log_stream_,
const int kMaxLen) {
const int len = strlen(file);
std::string time_str;
struct tm tm_time; // Time of creation of LogMessage
time_t timestamp = time(NULL);
#if defined(_WIN32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册